jnkr36 committed
Commit ca22ec0 · 1 Parent(s): 4bb060d

Upload 43 files

Files changed (43)
  1. comfy_extras/chainner_models/__init__.py +0 -0
  2. comfy_extras/chainner_models/architecture/HAT.py +1277 -0
  3. comfy_extras/chainner_models/architecture/LICENSE-ESRGAN +201 -0
  4. comfy_extras/chainner_models/architecture/LICENSE-HAT +21 -0
  5. comfy_extras/chainner_models/architecture/LICENSE-RealESRGAN +29 -0
  6. comfy_extras/chainner_models/architecture/LICENSE-SPSR +201 -0
  7. comfy_extras/chainner_models/architecture/LICENSE-SwiftSRGAN +121 -0
  8. comfy_extras/chainner_models/architecture/LICENSE-Swin2SR +201 -0
  9. comfy_extras/chainner_models/architecture/LICENSE-SwinIR +201 -0
  10. comfy_extras/chainner_models/architecture/LICENSE-lama +201 -0
  11. comfy_extras/chainner_models/architecture/LICENSE-mat +161 -0
  12. comfy_extras/chainner_models/architecture/LaMa.py +694 -0
  13. comfy_extras/chainner_models/architecture/MAT.py +1636 -0
  14. comfy_extras/chainner_models/architecture/RRDB.py +281 -0
  15. comfy_extras/chainner_models/architecture/SPSR.py +384 -0
  16. comfy_extras/chainner_models/architecture/SRVGG.py +114 -0
  17. comfy_extras/chainner_models/architecture/SwiftSRGAN.py +161 -0
  18. comfy_extras/chainner_models/architecture/Swin2SR.py +1377 -0
  19. comfy_extras/chainner_models/architecture/SwinIR.py +1208 -0
  20. comfy_extras/chainner_models/architecture/__init__.py +0 -0
  21. comfy_extras/chainner_models/architecture/block.py +513 -0
  22. comfy_extras/chainner_models/architecture/face/LICENSE-GFPGAN +351 -0
  23. comfy_extras/chainner_models/architecture/face/LICENSE-RestoreFormer +351 -0
  24. comfy_extras/chainner_models/architecture/face/LICENSE-codeformer +35 -0
  25. comfy_extras/chainner_models/architecture/face/arcface_arch.py +265 -0
  26. comfy_extras/chainner_models/architecture/face/codeformer.py +790 -0
  27. comfy_extras/chainner_models/architecture/face/fused_act.py +81 -0
  28. comfy_extras/chainner_models/architecture/face/gfpgan_bilinear_arch.py +389 -0
  29. comfy_extras/chainner_models/architecture/face/gfpganv1_arch.py +566 -0
  30. comfy_extras/chainner_models/architecture/face/gfpganv1_clean_arch.py +370 -0
  31. comfy_extras/chainner_models/architecture/face/restoreformer_arch.py +776 -0
  32. comfy_extras/chainner_models/architecture/face/stylegan2_arch.py +865 -0
  33. comfy_extras/chainner_models/architecture/face/stylegan2_bilinear_arch.py +709 -0
  34. comfy_extras/chainner_models/architecture/face/stylegan2_clean_arch.py +453 -0
  35. comfy_extras/chainner_models/architecture/face/upfirdn2d.py +194 -0
  36. comfy_extras/chainner_models/architecture/mat/utils.py +698 -0
  37. comfy_extras/chainner_models/architecture/timm/LICENSE +201 -0
  38. comfy_extras/chainner_models/architecture/timm/drop.py +223 -0
  39. comfy_extras/chainner_models/architecture/timm/helpers.py +31 -0
  40. comfy_extras/chainner_models/architecture/timm/weight_init.py +128 -0
  41. comfy_extras/chainner_models/model_loading.py +89 -0
  42. comfy_extras/chainner_models/types.py +53 -0
  43. comfy_extras/nodes_upscale_model.py +48 -0
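
Usage sketch (not part of this commit): the architectures added below are normally consumed through model_loading.py, whose contents are not shown in this view, so the load_state_dict helper and the checkpoint path in the following sketch are assumptions.

import torch
from comfy_extras.chainner_models import model_loading

# Assumed API: model_loading.load_state_dict() inspects the checkpoint keys and
# returns the matching architecture (RRDB/ESRGAN, SwinIR, HAT, ...).
sd = torch.load("models/upscale/4x_example.pth", map_location="cpu")  # placeholder path
model = model_loading.load_state_dict(sd).eval()

with torch.no_grad():
    lr = torch.rand(1, 3, 64, 64)  # dummy low-resolution RGB tensor
    sr = model(lr)                 # upscaled output, e.g. 4x the spatial size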
comfy_extras/chainner_models/__init__.py ADDED
File without changes
comfy_extras/chainner_models/architecture/HAT.py ADDED
@@ -0,0 +1,1277 @@
1
+ # pylint: skip-file
2
+ # HAT from https://github.com/XPixelGroup/HAT/blob/main/hat/archs/hat_arch.py
3
+ import math
4
+ import re
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from .timm.helpers import to_2tuple
12
+ from .timm.weight_init import trunc_normal_
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
17
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
18
+ """
19
+ if drop_prob == 0.0 or not training:
20
+ return x
21
+ keep_prob = 1 - drop_prob
22
+ shape = (x.shape[0],) + (1,) * (
23
+ x.ndim - 1
24
+ ) # work with diff dim tensors, not just 2D ConvNets
25
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
26
+ random_tensor.floor_() # binarize
27
+ output = x.div(keep_prob) * random_tensor
28
+ return output
29
+
30
+
31
+ class DropPath(nn.Module):
32
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
33
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
34
+ """
35
+
36
+ def __init__(self, drop_prob=None):
37
+ super(DropPath, self).__init__()
38
+ self.drop_prob = drop_prob
39
+
40
+ def forward(self, x):
41
+ return drop_path(x, self.drop_prob, self.training) # type: ignore
42
+
43
+
44
+ class ChannelAttention(nn.Module):
45
+ """Channel attention used in RCAN.
46
+ Args:
47
+ num_feat (int): Channel number of intermediate features.
48
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
49
+ """
50
+
51
+ def __init__(self, num_feat, squeeze_factor=16):
52
+ super(ChannelAttention, self).__init__()
53
+ self.attention = nn.Sequential(
54
+ nn.AdaptiveAvgPool2d(1),
55
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
58
+ nn.Sigmoid(),
59
+ )
60
+
61
+ def forward(self, x):
62
+ y = self.attention(x)
63
+ return x * y
64
+
65
+
66
+ class CAB(nn.Module):
67
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
68
+ super(CAB, self).__init__()
69
+
70
+ self.cab = nn.Sequential(
71
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
72
+ nn.GELU(),
73
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
74
+ ChannelAttention(num_feat, squeeze_factor),
75
+ )
76
+
77
+ def forward(self, x):
78
+ return self.cab(x)
79
+
80
+
81
+ class Mlp(nn.Module):
82
+ def __init__(
83
+ self,
84
+ in_features,
85
+ hidden_features=None,
86
+ out_features=None,
87
+ act_layer=nn.GELU,
88
+ drop=0.0,
89
+ ):
90
+ super().__init__()
91
+ out_features = out_features or in_features
92
+ hidden_features = hidden_features or in_features
93
+ self.fc1 = nn.Linear(in_features, hidden_features)
94
+ self.act = act_layer()
95
+ self.fc2 = nn.Linear(hidden_features, out_features)
96
+ self.drop = nn.Dropout(drop)
97
+
98
+ def forward(self, x):
99
+ x = self.fc1(x)
100
+ x = self.act(x)
101
+ x = self.drop(x)
102
+ x = self.fc2(x)
103
+ x = self.drop(x)
104
+ return x
105
+
106
+
107
+ def window_partition(x, window_size):
108
+ """
109
+ Args:
110
+ x: (b, h, w, c)
111
+ window_size (int): window size
112
+ Returns:
113
+ windows: (num_windows*b, window_size, window_size, c)
114
+ """
115
+ b, h, w, c = x.shape
116
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
117
+ windows = (
118
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
119
+ )
120
+ return windows
121
+
122
+
123
+ def window_reverse(windows, window_size, h, w):
124
+ """
125
+ Args:
126
+ windows: (num_windows*b, window_size, window_size, c)
127
+ window_size (int): Window size
128
+ h (int): Height of image
129
+ w (int): Width of image
130
+ Returns:
131
+ x: (b, h, w, c)
132
+ """
133
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
134
+ x = windows.view(
135
+ b, h // window_size, w // window_size, window_size, window_size, -1
136
+ )
137
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
138
+ return x
139
+
140
+
141
+ class WindowAttention(nn.Module):
142
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
143
+ It supports both shifted and non-shifted windows.
144
+ Args:
145
+ dim (int): Number of input channels.
146
+ window_size (tuple[int]): The height and width of the window.
147
+ num_heads (int): Number of attention heads.
148
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
149
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
150
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
151
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ dim,
157
+ window_size,
158
+ num_heads,
159
+ qkv_bias=True,
160
+ qk_scale=None,
161
+ attn_drop=0.0,
162
+ proj_drop=0.0,
163
+ ):
164
+ super().__init__()
165
+ self.dim = dim
166
+ self.window_size = window_size # Wh, Ww
167
+ self.num_heads = num_heads
168
+ head_dim = dim // num_heads
169
+ self.scale = qk_scale or head_dim**-0.5
170
+
171
+ # define a parameter table of relative position bias
172
+ self.relative_position_bias_table = nn.Parameter( # type: ignore
173
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
174
+ ) # 2*Wh-1 * 2*Ww-1, nH
175
+
176
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
177
+ self.attn_drop = nn.Dropout(attn_drop)
178
+ self.proj = nn.Linear(dim, dim)
179
+
180
+ self.proj_drop = nn.Dropout(proj_drop)
181
+
182
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
183
+ self.softmax = nn.Softmax(dim=-1)
184
+
185
+ def forward(self, x, rpi, mask=None):
186
+ """
187
+ Args:
188
+ x: input features with shape of (num_windows*b, n, c)
189
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
190
+ """
191
+ b_, n, c = x.shape
192
+ qkv = (
193
+ self.qkv(x)
194
+ .reshape(b_, n, 3, self.num_heads, c // self.num_heads)
195
+ .permute(2, 0, 3, 1, 4)
196
+ )
197
+ q, k, v = (
198
+ qkv[0],
199
+ qkv[1],
200
+ qkv[2],
201
+ ) # make torchscript happy (cannot use tensor as tuple)
202
+
203
+ q = q * self.scale
204
+ attn = q @ k.transpose(-2, -1)
205
+
206
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
207
+ self.window_size[0] * self.window_size[1],
208
+ self.window_size[0] * self.window_size[1],
209
+ -1,
210
+ ) # Wh*Ww,Wh*Ww,nH
211
+ relative_position_bias = relative_position_bias.permute(
212
+ 2, 0, 1
213
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
214
+ attn = attn + relative_position_bias.unsqueeze(0)
215
+
216
+ if mask is not None:
217
+ nw = mask.shape[0]
218
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(
219
+ 1
220
+ ).unsqueeze(0)
221
+ attn = attn.view(-1, self.num_heads, n, n)
222
+ attn = self.softmax(attn)
223
+ else:
224
+ attn = self.softmax(attn)
225
+
226
+ attn = self.attn_drop(attn)
227
+
228
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
229
+ x = self.proj(x)
230
+ x = self.proj_drop(x)
231
+ return x
232
+
233
+
234
+ class HAB(nn.Module):
235
+ r"""Hybrid Attention Block.
236
+ Args:
237
+ dim (int): Number of input channels.
238
+ input_resolution (tuple[int]): Input resolution.
239
+ num_heads (int): Number of attention heads.
240
+ window_size (int): Window size.
241
+ shift_size (int): Shift size for SW-MSA.
242
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
243
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
244
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
245
+ drop (float, optional): Dropout rate. Default: 0.0
246
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
247
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
248
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
249
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ dim,
255
+ input_resolution,
256
+ num_heads,
257
+ window_size=7,
258
+ shift_size=0,
259
+ compress_ratio=3,
260
+ squeeze_factor=30,
261
+ conv_scale=0.01,
262
+ mlp_ratio=4.0,
263
+ qkv_bias=True,
264
+ qk_scale=None,
265
+ drop=0.0,
266
+ attn_drop=0.0,
267
+ drop_path=0.0,
268
+ act_layer=nn.GELU,
269
+ norm_layer=nn.LayerNorm,
270
+ ):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.input_resolution = input_resolution
274
+ self.num_heads = num_heads
275
+ self.window_size = window_size
276
+ self.shift_size = shift_size
277
+ self.mlp_ratio = mlp_ratio
278
+ if min(self.input_resolution) <= self.window_size:
279
+ # if window size is larger than input resolution, we don't partition windows
280
+ self.shift_size = 0
281
+ self.window_size = min(self.input_resolution)
282
+ assert (
283
+ 0 <= self.shift_size < self.window_size
284
+ ), "shift_size must be in 0-window_size"
285
+
286
+ self.norm1 = norm_layer(dim)
287
+ self.attn = WindowAttention(
288
+ dim,
289
+ window_size=to_2tuple(self.window_size),
290
+ num_heads=num_heads,
291
+ qkv_bias=qkv_bias,
292
+ qk_scale=qk_scale,
293
+ attn_drop=attn_drop,
294
+ proj_drop=drop,
295
+ )
296
+
297
+ self.conv_scale = conv_scale
298
+ self.conv_block = CAB(
299
+ num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor
300
+ )
301
+
302
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
303
+ self.norm2 = norm_layer(dim)
304
+ mlp_hidden_dim = int(dim * mlp_ratio)
305
+ self.mlp = Mlp(
306
+ in_features=dim,
307
+ hidden_features=mlp_hidden_dim,
308
+ act_layer=act_layer,
309
+ drop=drop,
310
+ )
311
+
312
+ def forward(self, x, x_size, rpi_sa, attn_mask):
313
+ h, w = x_size
314
+ b, _, c = x.shape
315
+ # assert seq_len == h * w, "input feature has wrong size"
316
+
317
+ shortcut = x
318
+ x = self.norm1(x)
319
+ x = x.view(b, h, w, c)
320
+
321
+ # Conv_X
322
+ conv_x = self.conv_block(x.permute(0, 3, 1, 2))
323
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
324
+
325
+ # cyclic shift
326
+ if self.shift_size > 0:
327
+ shifted_x = torch.roll(
328
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
329
+ )
330
+ attn_mask = attn_mask
331
+ else:
332
+ shifted_x = x
333
+ attn_mask = None
334
+
335
+ # partition windows
336
+ x_windows = window_partition(
337
+ shifted_x, self.window_size
338
+ ) # nw*b, window_size, window_size, c
339
+ x_windows = x_windows.view(
340
+ -1, self.window_size * self.window_size, c
341
+ ) # nw*b, window_size*window_size, c
342
+
343
+ # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size)
344
+ attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
345
+
346
+ # merge windows
347
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
348
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
349
+
350
+ # reverse cyclic shift
351
+ if self.shift_size > 0:
352
+ attn_x = torch.roll(
353
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
354
+ )
355
+ else:
356
+ attn_x = shifted_x
357
+ attn_x = attn_x.view(b, h * w, c)
358
+
359
+ # FFN
360
+ x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
361
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
362
+
363
+ return x
364
+
365
+
366
+ class PatchMerging(nn.Module):
367
+ r"""Patch Merging Layer.
368
+ Args:
369
+ input_resolution (tuple[int]): Resolution of input feature.
370
+ dim (int): Number of input channels.
371
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
372
+ """
373
+
374
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
375
+ super().__init__()
376
+ self.input_resolution = input_resolution
377
+ self.dim = dim
378
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
379
+ self.norm = norm_layer(4 * dim)
380
+
381
+ def forward(self, x):
382
+ """
383
+ x: b, h*w, c
384
+ """
385
+ h, w = self.input_resolution
386
+ b, seq_len, c = x.shape
387
+ assert seq_len == h * w, "input feature has wrong size"
388
+ assert h % 2 == 0 and w % 2 == 0, f"x size ({h}*{w}) are not even."
389
+
390
+ x = x.view(b, h, w, c)
391
+
392
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
393
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
394
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
395
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
396
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
397
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
398
+
399
+ x = self.norm(x)
400
+ x = self.reduction(x)
401
+
402
+ return x
403
+
404
+
405
+ class OCAB(nn.Module):
406
+ # overlapping cross-attention block
407
+
408
+ def __init__(
409
+ self,
410
+ dim,
411
+ input_resolution,
412
+ window_size,
413
+ overlap_ratio,
414
+ num_heads,
415
+ qkv_bias=True,
416
+ qk_scale=None,
417
+ mlp_ratio=2,
418
+ norm_layer=nn.LayerNorm,
419
+ ):
420
+ super().__init__()
421
+ self.dim = dim
422
+ self.input_resolution = input_resolution
423
+ self.window_size = window_size
424
+ self.num_heads = num_heads
425
+ head_dim = dim // num_heads
426
+ self.scale = qk_scale or head_dim**-0.5
427
+ self.overlap_win_size = int(window_size * overlap_ratio) + window_size
428
+
429
+ self.norm1 = norm_layer(dim)
430
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
431
+ self.unfold = nn.Unfold(
432
+ kernel_size=(self.overlap_win_size, self.overlap_win_size),
433
+ stride=window_size,
434
+ padding=(self.overlap_win_size - window_size) // 2,
435
+ )
436
+
437
+ # define a parameter table of relative position bias
438
+ self.relative_position_bias_table = nn.Parameter( # type: ignore
439
+ torch.zeros(
440
+ (window_size + self.overlap_win_size - 1)
441
+ * (window_size + self.overlap_win_size - 1),
442
+ num_heads,
443
+ )
444
+ ) # 2*Wh-1 * 2*Ww-1, nH
445
+
446
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
447
+ self.softmax = nn.Softmax(dim=-1)
448
+
449
+ self.proj = nn.Linear(dim, dim)
450
+
451
+ self.norm2 = norm_layer(dim)
452
+ mlp_hidden_dim = int(dim * mlp_ratio)
453
+ self.mlp = Mlp(
454
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU
455
+ )
456
+
457
+ def forward(self, x, x_size, rpi):
458
+ h, w = x_size
459
+ b, _, c = x.shape
460
+
461
+ shortcut = x
462
+ x = self.norm1(x)
463
+ x = x.view(b, h, w, c)
464
+
465
+ qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w
466
+ q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c
467
+ kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w
468
+
469
+ # partition windows
470
+ q_windows = window_partition(
471
+ q, self.window_size
472
+ ) # nw*b, window_size, window_size, c
473
+ q_windows = q_windows.view(
474
+ -1, self.window_size * self.window_size, c
475
+ ) # nw*b, window_size*window_size, c
476
+
477
+ kv_windows = self.unfold(kv) # b, c*w*w, nw
478
+ kv_windows = rearrange(
479
+ kv_windows,
480
+ "b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch",
481
+ nc=2,
482
+ ch=c,
483
+ owh=self.overlap_win_size,
484
+ oww=self.overlap_win_size,
485
+ ).contiguous() # 2, nw*b, ow*ow, c
486
+ # Do the above rearrangement without the rearrange function
487
+ # kv_windows = kv_windows.view(
488
+ # 2, b, self.overlap_win_size, self.overlap_win_size, c, -1
489
+ # )
490
+ # kv_windows = kv_windows.permute(0, 5, 1, 2, 3, 4).contiguous()
491
+ # kv_windows = kv_windows.view(
492
+ # 2, -1, self.overlap_win_size * self.overlap_win_size, c
493
+ # )
494
+
495
+ k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c
496
+
497
+ b_, nq, _ = q_windows.shape
498
+ _, n, _ = k_windows.shape
499
+ d = self.dim // self.num_heads
500
+ q = q_windows.reshape(b_, nq, self.num_heads, d).permute(
501
+ 0, 2, 1, 3
502
+ ) # nw*b, nH, nq, d
503
+ k = k_windows.reshape(b_, n, self.num_heads, d).permute(
504
+ 0, 2, 1, 3
505
+ ) # nw*b, nH, n, d
506
+ v = v_windows.reshape(b_, n, self.num_heads, d).permute(
507
+ 0, 2, 1, 3
508
+ ) # nw*b, nH, n, d
509
+
510
+ q = q * self.scale
511
+ attn = q @ k.transpose(-2, -1)
512
+
513
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
514
+ self.window_size * self.window_size,
515
+ self.overlap_win_size * self.overlap_win_size,
516
+ -1,
517
+ ) # ws*ws, wse*wse, nH
518
+ relative_position_bias = relative_position_bias.permute(
519
+ 2, 0, 1
520
+ ).contiguous() # nH, ws*ws, wse*wse
521
+ attn = attn + relative_position_bias.unsqueeze(0)
522
+
523
+ attn = self.softmax(attn)
524
+ attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
525
+
526
+ # merge windows
527
+ attn_windows = attn_windows.view(
528
+ -1, self.window_size, self.window_size, self.dim
529
+ )
530
+ x = window_reverse(attn_windows, self.window_size, h, w) # b h w c
531
+ x = x.view(b, h * w, self.dim)
532
+
533
+ x = self.proj(x) + shortcut
534
+
535
+ x = x + self.mlp(self.norm2(x))
536
+ return x
537
+
538
+
539
+ class AttenBlocks(nn.Module):
540
+ """A series of attention blocks for one RHAG.
541
+ Args:
542
+ dim (int): Number of input channels.
543
+ input_resolution (tuple[int]): Input resolution.
544
+ depth (int): Number of blocks.
545
+ num_heads (int): Number of attention heads.
546
+ window_size (int): Local window size.
547
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
548
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
549
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
550
+ drop (float, optional): Dropout rate. Default: 0.0
551
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
552
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
553
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
554
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
555
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
556
+ """
557
+
558
+ def __init__(
559
+ self,
560
+ dim,
561
+ input_resolution,
562
+ depth,
563
+ num_heads,
564
+ window_size,
565
+ compress_ratio,
566
+ squeeze_factor,
567
+ conv_scale,
568
+ overlap_ratio,
569
+ mlp_ratio=4.0,
570
+ qkv_bias=True,
571
+ qk_scale=None,
572
+ drop=0.0,
573
+ attn_drop=0.0,
574
+ drop_path=0.0,
575
+ norm_layer=nn.LayerNorm,
576
+ downsample=None,
577
+ use_checkpoint=False,
578
+ ):
579
+ super().__init__()
580
+ self.dim = dim
581
+ self.input_resolution = input_resolution
582
+ self.depth = depth
583
+ self.use_checkpoint = use_checkpoint
584
+
585
+ # build blocks
586
+ self.blocks = nn.ModuleList(
587
+ [
588
+ HAB(
589
+ dim=dim,
590
+ input_resolution=input_resolution,
591
+ num_heads=num_heads,
592
+ window_size=window_size,
593
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
594
+ compress_ratio=compress_ratio,
595
+ squeeze_factor=squeeze_factor,
596
+ conv_scale=conv_scale,
597
+ mlp_ratio=mlp_ratio,
598
+ qkv_bias=qkv_bias,
599
+ qk_scale=qk_scale,
600
+ drop=drop,
601
+ attn_drop=attn_drop,
602
+ drop_path=drop_path[i]
603
+ if isinstance(drop_path, list)
604
+ else drop_path,
605
+ norm_layer=norm_layer,
606
+ )
607
+ for i in range(depth)
608
+ ]
609
+ )
610
+
611
+ # OCAB
612
+ self.overlap_attn = OCAB(
613
+ dim=dim,
614
+ input_resolution=input_resolution,
615
+ window_size=window_size,
616
+ overlap_ratio=overlap_ratio,
617
+ num_heads=num_heads,
618
+ qkv_bias=qkv_bias,
619
+ qk_scale=qk_scale,
620
+ mlp_ratio=mlp_ratio, # type: ignore
621
+ norm_layer=norm_layer,
622
+ )
623
+
624
+ # patch merging layer
625
+ if downsample is not None:
626
+ self.downsample = downsample(
627
+ input_resolution, dim=dim, norm_layer=norm_layer
628
+ )
629
+ else:
630
+ self.downsample = None
631
+
632
+ def forward(self, x, x_size, params):
633
+ for blk in self.blocks:
634
+ x = blk(x, x_size, params["rpi_sa"], params["attn_mask"])
635
+
636
+ x = self.overlap_attn(x, x_size, params["rpi_oca"])
637
+
638
+ if self.downsample is not None:
639
+ x = self.downsample(x)
640
+ return x
641
+
642
+
643
+ class RHAG(nn.Module):
644
+ """Residual Hybrid Attention Group (RHAG).
645
+ Args:
646
+ dim (int): Number of input channels.
647
+ input_resolution (tuple[int]): Input resolution.
648
+ depth (int): Number of blocks.
649
+ num_heads (int): Number of attention heads.
650
+ window_size (int): Local window size.
651
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
652
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
653
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
654
+ drop (float, optional): Dropout rate. Default: 0.0
655
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
656
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
657
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
658
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
659
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
660
+ img_size: Input image size.
661
+ patch_size: Patch size.
662
+ resi_connection: The convolutional block before residual connection.
663
+ """
664
+
665
+ def __init__(
666
+ self,
667
+ dim,
668
+ input_resolution,
669
+ depth,
670
+ num_heads,
671
+ window_size,
672
+ compress_ratio,
673
+ squeeze_factor,
674
+ conv_scale,
675
+ overlap_ratio,
676
+ mlp_ratio=4.0,
677
+ qkv_bias=True,
678
+ qk_scale=None,
679
+ drop=0.0,
680
+ attn_drop=0.0,
681
+ drop_path=0.0,
682
+ norm_layer=nn.LayerNorm,
683
+ downsample=None,
684
+ use_checkpoint=False,
685
+ img_size=224,
686
+ patch_size=4,
687
+ resi_connection="1conv",
688
+ ):
689
+ super(RHAG, self).__init__()
690
+
691
+ self.dim = dim
692
+ self.input_resolution = input_resolution
693
+
694
+ self.residual_group = AttenBlocks(
695
+ dim=dim,
696
+ input_resolution=input_resolution,
697
+ depth=depth,
698
+ num_heads=num_heads,
699
+ window_size=window_size,
700
+ compress_ratio=compress_ratio,
701
+ squeeze_factor=squeeze_factor,
702
+ conv_scale=conv_scale,
703
+ overlap_ratio=overlap_ratio,
704
+ mlp_ratio=mlp_ratio,
705
+ qkv_bias=qkv_bias,
706
+ qk_scale=qk_scale,
707
+ drop=drop,
708
+ attn_drop=attn_drop,
709
+ drop_path=drop_path,
710
+ norm_layer=norm_layer,
711
+ downsample=downsample,
712
+ use_checkpoint=use_checkpoint,
713
+ )
714
+
715
+ if resi_connection == "1conv":
716
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
717
+ elif resi_connection == "identity":
718
+ self.conv = nn.Identity()
719
+
720
+ self.patch_embed = PatchEmbed(
721
+ img_size=img_size,
722
+ patch_size=patch_size,
723
+ in_chans=0,
724
+ embed_dim=dim,
725
+ norm_layer=None,
726
+ )
727
+
728
+ self.patch_unembed = PatchUnEmbed(
729
+ img_size=img_size,
730
+ patch_size=patch_size,
731
+ in_chans=0,
732
+ embed_dim=dim,
733
+ norm_layer=None,
734
+ )
735
+
736
+ def forward(self, x, x_size, params):
737
+ return (
738
+ self.patch_embed(
739
+ self.conv(
740
+ self.patch_unembed(self.residual_group(x, x_size, params), x_size)
741
+ )
742
+ )
743
+ + x
744
+ )
745
+
746
+
747
+ class PatchEmbed(nn.Module):
748
+ r"""Image to Patch Embedding
749
+ Args:
750
+ img_size (int): Image size. Default: 224.
751
+ patch_size (int): Patch token size. Default: 4.
752
+ in_chans (int): Number of input image channels. Default: 3.
753
+ embed_dim (int): Number of linear projection output channels. Default: 96.
754
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
755
+ """
756
+
757
+ def __init__(
758
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
759
+ ):
760
+ super().__init__()
761
+ img_size = to_2tuple(img_size)
762
+ patch_size = to_2tuple(patch_size)
763
+ patches_resolution = [
764
+ img_size[0] // patch_size[0], # type: ignore
765
+ img_size[1] // patch_size[1], # type: ignore
766
+ ]
767
+ self.img_size = img_size
768
+ self.patch_size = patch_size
769
+ self.patches_resolution = patches_resolution
770
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
771
+
772
+ self.in_chans = in_chans
773
+ self.embed_dim = embed_dim
774
+
775
+ if norm_layer is not None:
776
+ self.norm = norm_layer(embed_dim)
777
+ else:
778
+ self.norm = None
779
+
780
+ def forward(self, x):
781
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
782
+ if self.norm is not None:
783
+ x = self.norm(x)
784
+ return x
785
+
786
+
787
+ class PatchUnEmbed(nn.Module):
788
+ r"""Image to Patch Unembedding
789
+ Args:
790
+ img_size (int): Image size. Default: 224.
791
+ patch_size (int): Patch token size. Default: 4.
792
+ in_chans (int): Number of input image channels. Default: 3.
793
+ embed_dim (int): Number of linear projection output channels. Default: 96.
794
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
795
+ """
796
+
797
+ def __init__(
798
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
799
+ ):
800
+ super().__init__()
801
+ img_size = to_2tuple(img_size)
802
+ patch_size = to_2tuple(patch_size)
803
+ patches_resolution = [
804
+ img_size[0] // patch_size[0], # type: ignore
805
+ img_size[1] // patch_size[1], # type: ignore
806
+ ]
807
+ self.img_size = img_size
808
+ self.patch_size = patch_size
809
+ self.patches_resolution = patches_resolution
810
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
811
+
812
+ self.in_chans = in_chans
813
+ self.embed_dim = embed_dim
814
+
815
+ def forward(self, x, x_size):
816
+ x = (
817
+ x.transpose(1, 2)
818
+ .contiguous()
819
+ .view(x.shape[0], self.embed_dim, x_size[0], x_size[1])
820
+ ) # b Ph*Pw c
821
+ return x
822
+
823
+
824
+ class Upsample(nn.Sequential):
825
+ """Upsample module.
826
+ Args:
827
+ scale (int): Scale factor. Supported scales: 2^n and 3.
828
+ num_feat (int): Channel number of intermediate features.
829
+ """
830
+
831
+ def __init__(self, scale, num_feat):
832
+ m = []
833
+ if (scale & (scale - 1)) == 0: # scale = 2^n
834
+ for _ in range(int(math.log(scale, 2))):
835
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
836
+ m.append(nn.PixelShuffle(2))
837
+ elif scale == 3:
838
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
839
+ m.append(nn.PixelShuffle(3))
840
+ else:
841
+ raise ValueError(
842
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
843
+ )
844
+ super(Upsample, self).__init__(*m)
845
+
846
+
847
+ class HAT(nn.Module):
848
+ r"""Hybrid Attention Transformer
849
+ A PyTorch implementation of: `Activating More Pixels in Image Super-Resolution Transformer`.
850
+ Some codes are based on SwinIR.
851
+ Args:
852
+ img_size (int | tuple(int)): Input image size. Default 64
853
+ patch_size (int | tuple(int)): Patch size. Default: 1
854
+ in_chans (int): Number of input image channels. Default: 3
855
+ embed_dim (int): Patch embedding dimension. Default: 96
856
+ depths (tuple(int)): Depth of each Swin Transformer layer.
857
+ num_heads (tuple(int)): Number of attention heads in different layers.
858
+ window_size (int): Window size. Default: 7
859
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
860
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
861
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
862
+ drop_rate (float): Dropout rate. Default: 0
863
+ attn_drop_rate (float): Attention dropout rate. Default: 0
864
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
865
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
866
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
867
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
868
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
869
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
870
+ img_range: Image range. 1. or 255.
871
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
872
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
873
+ """
874
+
875
+ def __init__(
876
+ self,
877
+ state_dict,
878
+ **kwargs,
879
+ ):
880
+ super(HAT, self).__init__()
881
+
882
+ # Defaults
883
+ img_size = 64
884
+ patch_size = 1
885
+ in_chans = 3
886
+ embed_dim = 96
887
+ depths = (6, 6, 6, 6)
888
+ num_heads = (6, 6, 6, 6)
889
+ window_size = 7
890
+ compress_ratio = 3
891
+ squeeze_factor = 30
892
+ conv_scale = 0.01
893
+ overlap_ratio = 0.5
894
+ mlp_ratio = 4.0
895
+ qkv_bias = True
896
+ qk_scale = None
897
+ drop_rate = 0.0
898
+ attn_drop_rate = 0.0
899
+ drop_path_rate = 0.1
900
+ norm_layer = nn.LayerNorm
901
+ ape = False
902
+ patch_norm = True
903
+ use_checkpoint = False
904
+ upscale = 2
905
+ img_range = 1.0
906
+ upsampler = ""
907
+ resi_connection = "1conv"
908
+
909
+ self.state = state_dict
910
+ self.model_arch = "HAT"
911
+ self.sub_type = "SR"
912
+ self.supports_fp16 = False
913
+ self.support_bf16 = True
914
+ self.min_size_restriction = 16
915
+
916
+ state_keys = list(state_dict.keys())
917
+
918
+ num_feat = state_dict["conv_last.weight"].shape[1]
919
+ in_chans = state_dict["conv_first.weight"].shape[1]
920
+ num_out_ch = state_dict["conv_last.weight"].shape[0]
921
+ embed_dim = state_dict["conv_first.weight"].shape[0]
922
+
923
+ if "conv_before_upsample.0.weight" in state_keys:
924
+ if "conv_up1.weight" in state_keys:
925
+ upsampler = "nearest+conv"
926
+ else:
927
+ upsampler = "pixelshuffle"
928
+ supports_fp16 = False
929
+ elif "upsample.0.weight" in state_keys:
930
+ upsampler = "pixelshuffledirect"
931
+ else:
932
+ upsampler = ""
933
+ upscale = 1
934
+ if upsampler == "nearest+conv":
935
+ upsample_keys = [
936
+ x for x in state_keys if "conv_up" in x and "bias" not in x
937
+ ]
938
+
939
+ for upsample_key in upsample_keys:
940
+ upscale *= 2
941
+ elif upsampler == "pixelshuffle":
942
+ upsample_keys = [
943
+ x
944
+ for x in state_keys
945
+ if "upsample" in x and "conv" not in x and "bias" not in x
946
+ ]
947
+ for upsample_key in upsample_keys:
948
+ shape = self.state[upsample_key].shape[0]
949
+ upscale *= math.sqrt(shape // num_feat)
950
+ upscale = int(upscale)
951
+ elif upsampler == "pixelshuffledirect":
952
+ upscale = int(
953
+ math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
954
+ )
955
+
956
+ max_layer_num = 0
957
+ max_block_num = 0
958
+ for key in state_keys:
959
+ result = re.match(
960
+ r"layers.(\d*).residual_group.blocks.(\d*).conv_block.cab.0.weight", key
961
+ )
962
+ if result:
963
+ layer_num, block_num = result.groups()
964
+ max_layer_num = max(max_layer_num, int(layer_num))
965
+ max_block_num = max(max_block_num, int(block_num))
966
+
967
+ depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
968
+
969
+ if (
970
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
971
+ in state_keys
972
+ ):
973
+ num_heads_num = self.state[
974
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
975
+ ].shape[-1]
976
+ num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
977
+ else:
978
+ num_heads = depths
979
+
980
+ mlp_ratio = float(
981
+ self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
982
+ / embed_dim
983
+ )
984
+
985
+ # TODO: could actually count the layers, but this should do
986
+ if "layers.0.conv.4.weight" in state_keys:
987
+ resi_connection = "3conv"
988
+ else:
989
+ resi_connection = "1conv"
990
+
991
+ window_size = int(math.sqrt(self.state["relative_position_index_SA"].shape[0]))
992
+
993
+ # Not sure if this is needed or used at all anywhere in HAT's config
994
+ if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
995
+ img_size = int(
996
+ math.sqrt(
997
+ self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
998
+ )
999
+ * window_size
1000
+ )
1001
+
1002
+ self.window_size = window_size
1003
+ self.shift_size = window_size // 2
1004
+ self.overlap_ratio = overlap_ratio
1005
+
1006
+ self.in_nc = in_chans
1007
+ self.out_nc = num_out_ch
1008
+ self.num_feat = num_feat
1009
+ self.embed_dim = embed_dim
1010
+ self.num_heads = num_heads
1011
+ self.depths = depths
1012
+ self.window_size = window_size
1013
+ self.mlp_ratio = mlp_ratio
1014
+ self.scale = upscale
1015
+ self.upsampler = upsampler
1016
+ self.img_size = img_size
1017
+ self.img_range = img_range
1018
+ self.resi_connection = resi_connection
1019
+
1020
+ num_in_ch = in_chans
1021
+ # num_out_ch = in_chans
1022
+ # num_feat = 64
1023
+ self.img_range = img_range
1024
+ if in_chans == 3:
1025
+ rgb_mean = (0.4488, 0.4371, 0.4040)
1026
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
1027
+ else:
1028
+ self.mean = torch.zeros(1, 1, 1, 1)
1029
+ self.upscale = upscale
1030
+ self.upsampler = upsampler
1031
+
1032
+ # relative position index
1033
+ relative_position_index_SA = self.calculate_rpi_sa()
1034
+ relative_position_index_OCA = self.calculate_rpi_oca()
1035
+ self.register_buffer("relative_position_index_SA", relative_position_index_SA)
1036
+ self.register_buffer("relative_position_index_OCA", relative_position_index_OCA)
1037
+
1038
+ # ------------------------- 1, shallow feature extraction ------------------------- #
1039
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
1040
+
1041
+ # ------------------------- 2, deep feature extraction ------------------------- #
1042
+ self.num_layers = len(depths)
1043
+ self.embed_dim = embed_dim
1044
+ self.ape = ape
1045
+ self.patch_norm = patch_norm
1046
+ self.num_features = embed_dim
1047
+ self.mlp_ratio = mlp_ratio
1048
+
1049
+ # split image into non-overlapping patches
1050
+ self.patch_embed = PatchEmbed(
1051
+ img_size=img_size,
1052
+ patch_size=patch_size,
1053
+ in_chans=embed_dim,
1054
+ embed_dim=embed_dim,
1055
+ norm_layer=norm_layer if self.patch_norm else None,
1056
+ )
1057
+ num_patches = self.patch_embed.num_patches
1058
+ patches_resolution = self.patch_embed.patches_resolution
1059
+ self.patches_resolution = patches_resolution
1060
+
1061
+ # merge non-overlapping patches into image
1062
+ self.patch_unembed = PatchUnEmbed(
1063
+ img_size=img_size,
1064
+ patch_size=patch_size,
1065
+ in_chans=embed_dim,
1066
+ embed_dim=embed_dim,
1067
+ norm_layer=norm_layer if self.patch_norm else None,
1068
+ )
1069
+
1070
+ # absolute position embedding
1071
+ if self.ape:
1072
+ self.absolute_pos_embed = nn.Parameter( # type: ignore[arg-type]
1073
+ torch.zeros(1, num_patches, embed_dim)
1074
+ )
1075
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
1076
+
1077
+ self.pos_drop = nn.Dropout(p=drop_rate)
1078
+
1079
+ # stochastic depth
1080
+ dpr = [
1081
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
1082
+ ] # stochastic depth decay rule
1083
+
1084
+ # build Residual Hybrid Attention Groups (RHAG)
1085
+ self.layers = nn.ModuleList()
1086
+ for i_layer in range(self.num_layers):
1087
+ layer = RHAG(
1088
+ dim=embed_dim,
1089
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1090
+ depth=depths[i_layer],
1091
+ num_heads=num_heads[i_layer],
1092
+ window_size=window_size,
1093
+ compress_ratio=compress_ratio,
1094
+ squeeze_factor=squeeze_factor,
1095
+ conv_scale=conv_scale,
1096
+ overlap_ratio=overlap_ratio,
1097
+ mlp_ratio=self.mlp_ratio,
1098
+ qkv_bias=qkv_bias,
1099
+ qk_scale=qk_scale,
1100
+ drop=drop_rate,
1101
+ attn_drop=attn_drop_rate,
1102
+ drop_path=dpr[
1103
+ sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) # type: ignore
1104
+ ], # no impact on SR results
1105
+ norm_layer=norm_layer,
1106
+ downsample=None,
1107
+ use_checkpoint=use_checkpoint,
1108
+ img_size=img_size,
1109
+ patch_size=patch_size,
1110
+ resi_connection=resi_connection,
1111
+ )
1112
+ self.layers.append(layer)
1113
+ self.norm = norm_layer(self.num_features)
1114
+
1115
+ # build the last conv layer in deep feature extraction
1116
+ if resi_connection == "1conv":
1117
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1118
+ elif resi_connection == "identity":
1119
+ self.conv_after_body = nn.Identity()
1120
+
1121
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
1122
+ if self.upsampler == "pixelshuffle":
1123
+ # for classical SR
1124
+ self.conv_before_upsample = nn.Sequential(
1125
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1126
+ )
1127
+ self.upsample = Upsample(upscale, num_feat)
1128
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1129
+
1130
+ self.apply(self._init_weights)
1131
+ self.load_state_dict(self.state, strict=False)
1132
+
1133
+ def _init_weights(self, m):
1134
+ if isinstance(m, nn.Linear):
1135
+ trunc_normal_(m.weight, std=0.02)
1136
+ if isinstance(m, nn.Linear) and m.bias is not None:
1137
+ nn.init.constant_(m.bias, 0)
1138
+ elif isinstance(m, nn.LayerNorm):
1139
+ nn.init.constant_(m.bias, 0)
1140
+ nn.init.constant_(m.weight, 1.0)
1141
+
1142
+ def calculate_rpi_sa(self):
1143
+ # calculate relative position index for SA
1144
+ coords_h = torch.arange(self.window_size)
1145
+ coords_w = torch.arange(self.window_size)
1146
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1147
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1148
+ relative_coords = (
1149
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
1150
+ ) # 2, Wh*Ww, Wh*Ww
1151
+ relative_coords = relative_coords.permute(
1152
+ 1, 2, 0
1153
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
1154
+ relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
1155
+ relative_coords[:, :, 1] += self.window_size - 1
1156
+ relative_coords[:, :, 0] *= 2 * self.window_size - 1
1157
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1158
+ return relative_position_index
1159
+
1160
+ def calculate_rpi_oca(self):
1161
+ # calculate relative position index for OCA
1162
+ window_size_ori = self.window_size
1163
+ window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
1164
+
1165
+ coords_h = torch.arange(window_size_ori)
1166
+ coords_w = torch.arange(window_size_ori)
1167
+ coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
1168
+ coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
1169
+
1170
+ coords_h = torch.arange(window_size_ext)
1171
+ coords_w = torch.arange(window_size_ext)
1172
+ coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
1173
+ coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
1174
+
1175
+ relative_coords = (
1176
+ coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]
1177
+ ) # 2, ws*ws, wse*wse
1178
+
1179
+ relative_coords = relative_coords.permute(
1180
+ 1, 2, 0
1181
+ ).contiguous() # ws*ws, wse*wse, 2
1182
+ relative_coords[:, :, 0] += (
1183
+ window_size_ori - window_size_ext + 1
1184
+ ) # shift to start from 0
1185
+ relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
1186
+
1187
+ relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
1188
+ relative_position_index = relative_coords.sum(-1)
1189
+ return relative_position_index
1190
+
1191
+ def calculate_mask(self, x_size):
1192
+ # calculate attention mask for SW-MSA
1193
+ h, w = x_size
1194
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
1195
+ h_slices = (
1196
+ slice(0, -self.window_size),
1197
+ slice(-self.window_size, -self.shift_size),
1198
+ slice(-self.shift_size, None),
1199
+ )
1200
+ w_slices = (
1201
+ slice(0, -self.window_size),
1202
+ slice(-self.window_size, -self.shift_size),
1203
+ slice(-self.shift_size, None),
1204
+ )
1205
+ cnt = 0
1206
+ for h in h_slices:
1207
+ for w in w_slices:
1208
+ img_mask[:, h, w, :] = cnt
1209
+ cnt += 1
1210
+
1211
+ mask_windows = window_partition(
1212
+ img_mask, self.window_size
1213
+ ) # nw, window_size, window_size, 1
1214
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1215
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1216
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
1217
+ attn_mask == 0, float(0.0)
1218
+ )
1219
+
1220
+ return attn_mask
1221
+
1222
+ @torch.jit.ignore # type: ignore
1223
+ def no_weight_decay(self):
1224
+ return {"absolute_pos_embed"}
1225
+
1226
+ @torch.jit.ignore # type: ignore
1227
+ def no_weight_decay_keywords(self):
1228
+ return {"relative_position_bias_table"}
1229
+
1230
+ def check_image_size(self, x):
1231
+ _, _, h, w = x.size()
1232
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
1233
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
1234
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1235
+ return x
1236
+
1237
+ def forward_features(self, x):
1238
+ x_size = (x.shape[2], x.shape[3])
1239
+
1240
+ # Calculate attention mask and relative position index in advance to speed up inference.
1241
+ # The original code is very time-consuming for large window sizes.
1242
+ attn_mask = self.calculate_mask(x_size).to(x.device)
1243
+ params = {
1244
+ "attn_mask": attn_mask,
1245
+ "rpi_sa": self.relative_position_index_SA,
1246
+ "rpi_oca": self.relative_position_index_OCA,
1247
+ }
1248
+
1249
+ x = self.patch_embed(x)
1250
+ if self.ape:
1251
+ x = x + self.absolute_pos_embed
1252
+ x = self.pos_drop(x)
1253
+
1254
+ for layer in self.layers:
1255
+ x = layer(x, x_size, params)
1256
+
1257
+ x = self.norm(x) # b seq_len c
1258
+ x = self.patch_unembed(x, x_size)
1259
+
1260
+ return x
1261
+
1262
+ def forward(self, x):
1263
+ H, W = x.shape[2:]
1264
+ self.mean = self.mean.type_as(x)
1265
+ x = (x - self.mean) * self.img_range
1266
+ x = self.check_image_size(x)
1267
+
1268
+ if self.upsampler == "pixelshuffle":
1269
+ # for classical SR
1270
+ x = self.conv_first(x)
1271
+ x = self.conv_after_body(self.forward_features(x)) + x
1272
+ x = self.conv_before_upsample(x)
1273
+ x = self.conv_last(self.upsample(x))
1274
+
1275
+ x = x / self.img_range + self.mean
1276
+
1277
+ return x[:, :, : H * self.upscale, : W * self.upscale]
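
A minimal construction sketch for the HAT class above: unlike the upstream HAT, this port takes the checkpoint itself and infers embed_dim, depths, window size and upscale factor from tensor shapes (conv_first/conv_last, relative_position_index_SA, the upsample keys). The checkpoint filename is a placeholder, and the sketch assumes a flat state dict with keys like "conv_last.weight".

import torch
from comfy_extras.chainner_models.architecture.HAT import HAT

state_dict = torch.load("HAT_SRx4.pth", map_location="cpu")  # placeholder checkpoint
net = HAT(state_dict)  # hyperparameters are read from the state-dict shapes
net.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 64, 64)  # forward() reflect-pads to a window-size multiple
    y = net(x)                    # shape (1, 3, 64 * net.upscale, 64 * net.upscale)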
comfy_extras/chainner_models/architecture/LICENSE-ESRGAN ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
comfy_extras/chainner_models/architecture/LICENSE-HAT ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Xiangyu Chen
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
comfy_extras/chainner_models/architecture/LICENSE-RealESRGAN ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2021, Xintao Wang
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
comfy_extras/chainner_models/architecture/LICENSE-SPSR ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2018-2022 BasicSR Authors
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
comfy_extras/chainner_models/architecture/LICENSE-SwiftSRGAN ADDED
@@ -0,0 +1,121 @@
+ Creative Commons Legal Code
+
+ CC0 1.0 Universal
+
+ CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
+ LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
+ ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
+ INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
+ REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
+ PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
+ THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
+ HEREUNDER.
+
+ Statement of Purpose
+
+ The laws of most jurisdictions throughout the world automatically confer
+ exclusive Copyright and Related Rights (defined below) upon the creator
+ and subsequent owner(s) (each and all, an "owner") of an original work of
+ authorship and/or a database (each, a "Work").
+
+ Certain owners wish to permanently relinquish those rights to a Work for
+ the purpose of contributing to a commons of creative, cultural and
+ scientific works ("Commons") that the public can reliably and without fear
+ of later claims of infringement build upon, modify, incorporate in other
+ works, reuse and redistribute as freely as possible in any form whatsoever
+ and for any purposes, including without limitation commercial purposes.
+ These owners may contribute to the Commons to promote the ideal of a free
+ culture and the further production of creative, cultural and scientific
+ works, or to gain reputation or greater distribution for their Work in
+ part through the use and efforts of others.
+
+ For these and/or other purposes and motivations, and without any
+ expectation of additional consideration or compensation, the person
+ associating CC0 with a Work (the "Affirmer"), to the extent that he or she
+ is an owner of Copyright and Related Rights in the Work, voluntarily
+ elects to apply CC0 to the Work and publicly distribute the Work under its
+ terms, with knowledge of his or her Copyright and Related Rights in the
+ Work and the meaning and intended legal effect of CC0 on those rights.
+
+ 1. Copyright and Related Rights. A Work made available under CC0 may be
+ protected by copyright and related or neighboring rights ("Copyright and
+ Related Rights"). Copyright and Related Rights include, but are not
+ limited to, the following:
+
+ i. the right to reproduce, adapt, distribute, perform, display,
+ communicate, and translate a Work;
+ ii. moral rights retained by the original author(s) and/or performer(s);
+ iii. publicity and privacy rights pertaining to a person's image or
+ likeness depicted in a Work;
+ iv. rights protecting against unfair competition in regards to a Work,
+ subject to the limitations in paragraph 4(a), below;
+ v. rights protecting the extraction, dissemination, use and reuse of data
+ in a Work;
+ vi. database rights (such as those arising under Directive 96/9/EC of the
+ European Parliament and of the Council of 11 March 1996 on the legal
+ protection of databases, and under any national implementation
+ thereof, including any amended or successor version of such
+ directive); and
+ vii. other similar, equivalent or corresponding rights throughout the
+ world based on applicable law or treaty, and any national
+ implementations thereof.
+
+ 2. Waiver. To the greatest extent permitted by, but not in contravention
+ of, applicable law, Affirmer hereby overtly, fully, permanently,
+ irrevocably and unconditionally waives, abandons, and surrenders all of
+ Affirmer's Copyright and Related Rights and associated claims and causes
+ of action, whether now known or unknown (including existing as well as
+ future claims and causes of action), in the Work (i) in all territories
+ worldwide, (ii) for the maximum duration provided by applicable law or
+ treaty (including future time extensions), (iii) in any current or future
+ medium and for any number of copies, and (iv) for any purpose whatsoever,
+ including without limitation commercial, advertising or promotional
+ purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
+ member of the public at large and to the detriment of Affirmer's heirs and
+ successors, fully intending that such Waiver shall not be subject to
+ revocation, rescission, cancellation, termination, or any other legal or
+ equitable action to disrupt the quiet enjoyment of the Work by the public
+ as contemplated by Affirmer's express Statement of Purpose.
+
+ 3. Public License Fallback. Should any part of the Waiver for any reason
+ be judged legally invalid or ineffective under applicable law, then the
+ Waiver shall be preserved to the maximum extent permitted taking into
+ account Affirmer's express Statement of Purpose. In addition, to the
+ extent the Waiver is so judged Affirmer hereby grants to each affected
+ person a royalty-free, non transferable, non sublicensable, non exclusive,
+ irrevocable and unconditional license to exercise Affirmer's Copyright and
+ Related Rights in the Work (i) in all territories worldwide, (ii) for the
+ maximum duration provided by applicable law or treaty (including future
+ time extensions), (iii) in any current or future medium and for any number
+ of copies, and (iv) for any purpose whatsoever, including without
+ limitation commercial, advertising or promotional purposes (the
+ "License"). The License shall be deemed effective as of the date CC0 was
+ applied by Affirmer to the Work. Should any part of the License for any
+ reason be judged legally invalid or ineffective under applicable law, such
+ partial invalidity or ineffectiveness shall not invalidate the remainder
+ of the License, and in such case Affirmer hereby affirms that he or she
+ will not (i) exercise any of his or her remaining Copyright and Related
+ Rights in the Work or (ii) assert any associated claims and causes of
+ action with respect to the Work, in either case contrary to Affirmer's
+ express Statement of Purpose.
+
+ 4. Limitations and Disclaimers.
+
+ a. No trademark or patent rights held by Affirmer are waived, abandoned,
+ surrendered, licensed or otherwise affected by this document.
+ b. Affirmer offers the Work as-is and makes no representations or
+ warranties of any kind concerning the Work, express, implied,
+ statutory or otherwise, including without limitation warranties of
+ title, merchantability, fitness for a particular purpose, non
+ infringement, or the absence of latent or other defects, accuracy, or
+ the present or absence of errors, whether or not discoverable, all to
+ the greatest extent permissible under applicable law.
+ c. Affirmer disclaims responsibility for clearing rights of other persons
+ that may apply to the Work or any use thereof, including without
+ limitation any person's Copyright and Related Rights in the Work.
+ Further, Affirmer disclaims responsibility for obtaining any necessary
+ consents, permissions or other rights required for any use of the
+ Work.
+ d. Affirmer understands and acknowledges that Creative Commons is not a
+ party to this document and has no duty or obligation with respect to
+ this CC0 or use of the Work.
comfy_extras/chainner_models/architecture/LICENSE-Swin2SR ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2021] [SwinIR Authors]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
comfy_extras/chainner_models/architecture/LICENSE-SwinIR ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2021] [SwinIR Authors]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
comfy_extras/chainner_models/architecture/LICENSE-lama ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2021] Samsung Research
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
comfy_extras/chainner_models/architecture/LICENSE-mat ADDED
@@ -0,0 +1,161 @@
1
+ ## creative commons
2
+
3
+ # Attribution-NonCommercial 4.0 International
4
+
5
+ Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6
+
7
+ ### Using Creative Commons Public Licenses
8
+
9
+ Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10
+
11
+ * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12
+
13
+ * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14
+
15
+ ## Creative Commons Attribution-NonCommercial 4.0 International Public License
16
+
17
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18
+
19
+ ### Section 1 – Definitions.
20
+
21
+ a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22
+
23
+ b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24
+
25
+ c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
26
+
27
+ d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
28
+
29
+ e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
30
+
31
+ f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
32
+
33
+ g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
34
+
35
+ h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
36
+
37
+ i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
38
+
39
+ j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
40
+
41
+ k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
42
+
43
+ l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
44
+
45
+ ### Section 2 – Scope.
46
+
47
+ a. ___License grant.___
48
+
49
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
50
+
51
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
52
+
53
+ B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
54
+
55
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
56
+
57
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
58
+
59
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
60
+
61
+ 5. __Downstream recipients.__
62
+
63
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
64
+
65
+ B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
66
+
67
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
68
+
69
+ b. ___Other rights.___
70
+
71
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
72
+
73
+ 2. Patent and trademark rights are not licensed under this Public License.
74
+
75
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
76
+
77
+ ### Section 3 – License Conditions.
78
+
79
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
80
+
81
+ a. ___Attribution.___
82
+
83
+ 1. If You Share the Licensed Material (including in modified form), You must:
84
+
85
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
86
+
87
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
88
+
89
+ ii. a copyright notice;
90
+
91
+ iii. a notice that refers to this Public License;
92
+
93
+ iv. a notice that refers to the disclaimer of warranties;
94
+
95
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
96
+
97
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
98
+
99
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
100
+
101
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
102
+
103
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
104
+
105
+ 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
106
+
107
+ ### Section 4 – Sui Generis Database Rights.
108
+
109
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
110
+
111
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
112
+
113
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
114
+
115
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
116
+
117
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
118
+
119
+ ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
120
+
121
+ a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
122
+
123
+ b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
124
+
125
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
126
+
127
+ ### Section 6 – Term and Termination.
128
+
129
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
130
+
131
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
132
+
133
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
134
+
135
+ 2. upon express reinstatement by the Licensor.
136
+
137
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
138
+
139
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
140
+
141
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
142
+
143
+ ### Section 7 – Other Terms and Conditions.
144
+
145
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
146
+
147
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
148
+
149
+ ### Section 8 – Interpretation.
150
+
151
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
152
+
153
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
154
+
155
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
156
+
157
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
158
+
159
+ > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
160
+ >
161
+ > Creative Commons may be contacted at creativecommons.org
comfy_extras/chainner_models/architecture/LaMa.py ADDED
@@ -0,0 +1,694 @@
1
+ # pylint: skip-file
2
+ """
3
+ Model adapted from advimman's lama project: https://github.com/advimman/lama
4
+ """
5
+
6
+ # Fast Fourier Convolution NeurIPS 2020
7
+ # original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
8
+ # paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
9
+
10
+ from typing import List
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torchvision.transforms.functional import InterpolationMode, rotate
16
+
17
+
18
+ class LearnableSpatialTransformWrapper(nn.Module):
19
+ def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
20
+ super().__init__()
21
+ self.impl = impl
22
+ self.angle = torch.rand(1) * angle_init_range
23
+ if train_angle:
24
+ self.angle = nn.Parameter(self.angle, requires_grad=True)
25
+ self.pad_coef = pad_coef
26
+
27
+ def forward(self, x):
28
+ if torch.is_tensor(x):
29
+ return self.inverse_transform(self.impl(self.transform(x)), x)
30
+ elif isinstance(x, tuple):
31
+ x_trans = tuple(self.transform(elem) for elem in x)
32
+ y_trans = self.impl(x_trans)
33
+ return tuple(
34
+ self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)
35
+ )
36
+ else:
37
+ raise ValueError(f"Unexpected input type {type(x)}")
38
+
39
+ def transform(self, x):
40
+ height, width = x.shape[2:]
41
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
42
+ x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode="reflect")
43
+ x_padded_rotated = rotate(
44
+ x_padded, self.angle.to(x_padded), InterpolationMode.BILINEAR, fill=0
45
+ )
46
+
47
+ return x_padded_rotated
48
+
49
+ def inverse_transform(self, y_padded_rotated, orig_x):
50
+ height, width = orig_x.shape[2:]
51
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
52
+
53
+ y_padded = rotate(
54
+ y_padded_rotated,
55
+ -self.angle.to(y_padded_rotated),
56
+ InterpolationMode.BILINEAR,
57
+ fill=0,
58
+ )
59
+ y_height, y_width = y_padded.shape[2:]
60
+ y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
61
+ return y
62
+
63
+
64
+ class SELayer(nn.Module):
65
+ def __init__(self, channel, reduction=16):
66
+ super(SELayer, self).__init__()
67
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
68
+ self.fc = nn.Sequential(
69
+ nn.Linear(channel, channel // reduction, bias=False),
70
+ nn.ReLU(inplace=True),
71
+ nn.Linear(channel // reduction, channel, bias=False),
72
+ nn.Sigmoid(),
73
+ )
74
+
75
+ def forward(self, x):
76
+ b, c, _, _ = x.size()
77
+ y = self.avg_pool(x).view(b, c)
78
+ y = self.fc(y).view(b, c, 1, 1)
79
+ res = x * y.expand_as(x)
80
+ return res
81
+
82
+
83
+ class FourierUnit(nn.Module):
84
+ def __init__(
85
+ self,
86
+ in_channels,
87
+ out_channels,
88
+ groups=1,
89
+ spatial_scale_factor=None,
90
+ spatial_scale_mode="bilinear",
91
+ spectral_pos_encoding=False,
92
+ use_se=False,
93
+ se_kwargs=None,
94
+ ffc3d=False,
95
+ fft_norm="ortho",
96
+ ):
97
+ # bn_layer not used
98
+ super(FourierUnit, self).__init__()
99
+ self.groups = groups
100
+
101
+ self.conv_layer = torch.nn.Conv2d(
102
+ in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
103
+ out_channels=out_channels * 2,
104
+ kernel_size=1,
105
+ stride=1,
106
+ padding=0,
107
+ groups=self.groups,
108
+ bias=False,
109
+ )
110
+ self.bn = torch.nn.BatchNorm2d(out_channels * 2)
111
+ self.relu = torch.nn.ReLU(inplace=True)
112
+
113
+ # squeeze and excitation block
114
+ self.use_se = use_se
115
+ if use_se:
116
+ if se_kwargs is None:
117
+ se_kwargs = {}
118
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
119
+
120
+ self.spatial_scale_factor = spatial_scale_factor
121
+ self.spatial_scale_mode = spatial_scale_mode
122
+ self.spectral_pos_encoding = spectral_pos_encoding
123
+ self.ffc3d = ffc3d
124
+ self.fft_norm = fft_norm
125
+
126
+ def forward(self, x):
127
+ half_check = False
128
+ if x.type() == "torch.cuda.HalfTensor":
129
+ # half only works on gpu anyway
130
+ half_check = True
131
+
132
+ batch = x.shape[0]
133
+
134
+ if self.spatial_scale_factor is not None:
135
+ orig_size = x.shape[-2:]
136
+ x = F.interpolate(
137
+ x,
138
+ scale_factor=self.spatial_scale_factor,
139
+ mode=self.spatial_scale_mode,
140
+ align_corners=False,
141
+ )
142
+
143
+ # (batch, c, h, w/2+1, 2)
144
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
145
+ if half_check:
146
+ ffted = torch.fft.rfftn(
147
+ x.float(), dim=fft_dim, norm=self.fft_norm
148
+ ) # .type(torch.cuda.HalfTensor)
149
+ else:
150
+ ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
151
+
152
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
153
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
154
+ ffted = ffted.view(
155
+ (
156
+ batch,
157
+ -1,
158
+ )
159
+ + ffted.size()[3:]
160
+ )
161
+
162
+ if self.spectral_pos_encoding:
163
+ height, width = ffted.shape[-2:]
164
+ coords_vert = (
165
+ torch.linspace(0, 1, height)[None, None, :, None]
166
+ .expand(batch, 1, height, width)
167
+ .to(ffted)
168
+ )
169
+ coords_hor = (
170
+ torch.linspace(0, 1, width)[None, None, None, :]
171
+ .expand(batch, 1, height, width)
172
+ .to(ffted)
173
+ )
174
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
175
+
176
+ if self.use_se:
177
+ ffted = self.se(ffted)
178
+
179
+ if half_check:
180
+ ffted = self.conv_layer(ffted.half()) # (batch, c*2, h, w/2+1)
181
+ else:
182
+ ffted = self.conv_layer(
183
+ ffted
184
+ ) # .type(torch.cuda.FloatTensor) # (batch, c*2, h, w/2+1)
185
+
186
+ ffted = self.relu(self.bn(ffted))
187
+ # forcing to be always float
188
+ ffted = ffted.float()
189
+
190
+ ffted = (
191
+ ffted.view(
192
+ (
193
+ batch,
194
+ -1,
195
+ 2,
196
+ )
197
+ + ffted.size()[2:]
198
+ )
199
+ .permute(0, 1, 3, 4, 2)
200
+ .contiguous()
201
+ ) # (batch,c, t, h, w/2+1, 2)
202
+
203
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
204
+
205
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
206
+ output = torch.fft.irfftn(
207
+ ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
208
+ )
209
+
210
+ if half_check:
211
+ output = output.half()
212
+
213
+ if self.spatial_scale_factor is not None:
214
+ output = F.interpolate(
215
+ output,
216
+ size=orig_size,
217
+ mode=self.spatial_scale_mode,
218
+ align_corners=False,
219
+ )
220
+
221
+ return output
222
+
223
+
224
+ class SpectralTransform(nn.Module):
225
+ def __init__(
226
+ self,
227
+ in_channels,
228
+ out_channels,
229
+ stride=1,
230
+ groups=1,
231
+ enable_lfu=True,
232
+ separable_fu=False,
233
+ **fu_kwargs,
234
+ ):
235
+ # bn_layer not used
236
+ super(SpectralTransform, self).__init__()
237
+ self.enable_lfu = enable_lfu
238
+ if stride == 2:
239
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
240
+ else:
241
+ self.downsample = nn.Identity()
242
+
243
+ self.stride = stride
244
+ self.conv1 = nn.Sequential(
245
+ nn.Conv2d(
246
+ in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
247
+ ),
248
+ nn.BatchNorm2d(out_channels // 2),
249
+ nn.ReLU(inplace=True),
250
+ )
251
+ fu_class = FourierUnit
252
+ self.fu = fu_class(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
253
+ if self.enable_lfu:
254
+ self.lfu = fu_class(out_channels // 2, out_channels // 2, groups)
255
+ self.conv2 = torch.nn.Conv2d(
256
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
257
+ )
258
+
259
+ def forward(self, x):
260
+ x = self.downsample(x)
261
+ x = self.conv1(x)
262
+ output = self.fu(x)
263
+
264
+ if self.enable_lfu:
265
+ _, c, h, _ = x.shape
266
+ split_no = 2
267
+ split_s = h // split_no
268
+ xs = torch.cat(
269
+ torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
270
+ ).contiguous()
271
+ xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
272
+ xs = self.lfu(xs)
273
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
274
+ else:
275
+ xs = 0
276
+
277
+ output = self.conv2(x + output + xs)
278
+
279
+ return output
280
+
281
+
282
+ class FFC(nn.Module):
283
+ def __init__(
284
+ self,
285
+ in_channels,
286
+ out_channels,
287
+ kernel_size,
288
+ ratio_gin,
289
+ ratio_gout,
290
+ stride=1,
291
+ padding=0,
292
+ dilation=1,
293
+ groups=1,
294
+ bias=False,
295
+ enable_lfu=True,
296
+ padding_type="reflect",
297
+ gated=False,
298
+ **spectral_kwargs,
299
+ ):
300
+ super(FFC, self).__init__()
301
+
302
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
303
+ self.stride = stride
304
+
305
+ in_cg = int(in_channels * ratio_gin)
306
+ in_cl = in_channels - in_cg
307
+ out_cg = int(out_channels * ratio_gout)
308
+ out_cl = out_channels - out_cg
309
+ # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
310
+ # groups_l = 1 if groups == 1 else groups - groups_g
311
+
312
+ self.ratio_gin = ratio_gin
313
+ self.ratio_gout = ratio_gout
314
+ self.global_in_num = in_cg
315
+
316
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
317
+ self.convl2l = module(
318
+ in_cl,
319
+ out_cl,
320
+ kernel_size,
321
+ stride,
322
+ padding,
323
+ dilation,
324
+ groups,
325
+ bias,
326
+ padding_mode=padding_type,
327
+ )
328
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
329
+ self.convl2g = module(
330
+ in_cl,
331
+ out_cg,
332
+ kernel_size,
333
+ stride,
334
+ padding,
335
+ dilation,
336
+ groups,
337
+ bias,
338
+ padding_mode=padding_type,
339
+ )
340
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
341
+ self.convg2l = module(
342
+ in_cg,
343
+ out_cl,
344
+ kernel_size,
345
+ stride,
346
+ padding,
347
+ dilation,
348
+ groups,
349
+ bias,
350
+ padding_mode=padding_type,
351
+ )
352
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
353
+ self.convg2g = module(
354
+ in_cg,
355
+ out_cg,
356
+ stride,
357
+ 1 if groups == 1 else groups // 2,
358
+ enable_lfu,
359
+ **spectral_kwargs,
360
+ )
361
+
362
+ self.gated = gated
363
+ module = (
364
+ nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
365
+ )
366
+ self.gate = module(in_channels, 2, 1)
367
+
368
+ def forward(self, x):
369
+ x_l, x_g = x if type(x) is tuple else (x, 0)
370
+ out_xl, out_xg = 0, 0
371
+
372
+ if self.gated:
373
+ total_input_parts = [x_l]
374
+ if torch.is_tensor(x_g):
375
+ total_input_parts.append(x_g)
376
+ total_input = torch.cat(total_input_parts, dim=1)
377
+
378
+ gates = torch.sigmoid(self.gate(total_input))
379
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
380
+ else:
381
+ g2l_gate, l2g_gate = 1, 1
382
+
383
+ if self.ratio_gout != 1:
384
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
385
+ if self.ratio_gout != 0:
386
+ out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
387
+
388
+ return out_xl, out_xg
389
+
390
+
391
+ class FFC_BN_ACT(nn.Module):
392
+ def __init__(
393
+ self,
394
+ in_channels,
395
+ out_channels,
396
+ kernel_size,
397
+ ratio_gin,
398
+ ratio_gout,
399
+ stride=1,
400
+ padding=0,
401
+ dilation=1,
402
+ groups=1,
403
+ bias=False,
404
+ norm_layer=nn.BatchNorm2d,
405
+ activation_layer=nn.Identity,
406
+ padding_type="reflect",
407
+ enable_lfu=True,
408
+ **kwargs,
409
+ ):
410
+ super(FFC_BN_ACT, self).__init__()
411
+ self.ffc = FFC(
412
+ in_channels,
413
+ out_channels,
414
+ kernel_size,
415
+ ratio_gin,
416
+ ratio_gout,
417
+ stride,
418
+ padding,
419
+ dilation,
420
+ groups,
421
+ bias,
422
+ enable_lfu,
423
+ padding_type=padding_type,
424
+ **kwargs,
425
+ )
426
+ lnorm = nn.Identity if ratio_gout == 1 else norm_layer
427
+ gnorm = nn.Identity if ratio_gout == 0 else norm_layer
428
+ global_channels = int(out_channels * ratio_gout)
429
+ self.bn_l = lnorm(out_channels - global_channels)
430
+ self.bn_g = gnorm(global_channels)
431
+
432
+ lact = nn.Identity if ratio_gout == 1 else activation_layer
433
+ gact = nn.Identity if ratio_gout == 0 else activation_layer
434
+ self.act_l = lact(inplace=True)
435
+ self.act_g = gact(inplace=True)
436
+
437
+ def forward(self, x):
438
+ x_l, x_g = self.ffc(x)
439
+ x_l = self.act_l(self.bn_l(x_l))
440
+ x_g = self.act_g(self.bn_g(x_g))
441
+ return x_l, x_g
442
+
443
+
444
+ class FFCResnetBlock(nn.Module):
445
+ def __init__(
446
+ self,
447
+ dim,
448
+ padding_type,
449
+ norm_layer,
450
+ activation_layer=nn.ReLU,
451
+ dilation=1,
452
+ spatial_transform_kwargs=None,
453
+ inline=False,
454
+ **conv_kwargs,
455
+ ):
456
+ super().__init__()
457
+ self.conv1 = FFC_BN_ACT(
458
+ dim,
459
+ dim,
460
+ kernel_size=3,
461
+ padding=dilation,
462
+ dilation=dilation,
463
+ norm_layer=norm_layer,
464
+ activation_layer=activation_layer,
465
+ padding_type=padding_type,
466
+ **conv_kwargs,
467
+ )
468
+ self.conv2 = FFC_BN_ACT(
469
+ dim,
470
+ dim,
471
+ kernel_size=3,
472
+ padding=dilation,
473
+ dilation=dilation,
474
+ norm_layer=norm_layer,
475
+ activation_layer=activation_layer,
476
+ padding_type=padding_type,
477
+ **conv_kwargs,
478
+ )
479
+ if spatial_transform_kwargs is not None:
480
+ self.conv1 = LearnableSpatialTransformWrapper(
481
+ self.conv1, **spatial_transform_kwargs
482
+ )
483
+ self.conv2 = LearnableSpatialTransformWrapper(
484
+ self.conv2, **spatial_transform_kwargs
485
+ )
486
+ self.inline = inline
487
+
488
+ def forward(self, x):
489
+ if self.inline:
490
+ x_l, x_g = (
491
+ x[:, : -self.conv1.ffc.global_in_num],
492
+ x[:, -self.conv1.ffc.global_in_num :],
493
+ )
494
+ else:
495
+ x_l, x_g = x if type(x) is tuple else (x, 0)
496
+
497
+ id_l, id_g = x_l, x_g
498
+
499
+ x_l, x_g = self.conv1((x_l, x_g))
500
+ x_l, x_g = self.conv2((x_l, x_g))
501
+
502
+ x_l, x_g = id_l + x_l, id_g + x_g
503
+ out = x_l, x_g
504
+ if self.inline:
505
+ out = torch.cat(out, dim=1)
506
+ return out
507
+
508
+
509
+ class ConcatTupleLayer(nn.Module):
510
+ def forward(self, x):
511
+ assert isinstance(x, tuple)
512
+ x_l, x_g = x
513
+ assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
514
+ if not torch.is_tensor(x_g):
515
+ return x_l
516
+ return torch.cat(x, dim=1)
517
+
518
+
519
+ class FFCResNetGenerator(nn.Module):
520
+ def __init__(
521
+ self,
522
+ input_nc,
523
+ output_nc,
524
+ ngf=64,
525
+ n_downsampling=3,
526
+ n_blocks=18,
527
+ norm_layer=nn.BatchNorm2d,
528
+ padding_type="reflect",
529
+ activation_layer=nn.ReLU,
530
+ up_norm_layer=nn.BatchNorm2d,
531
+ up_activation=nn.ReLU(True),
532
+ init_conv_kwargs={},
533
+ downsample_conv_kwargs={},
534
+ resnet_conv_kwargs={},
535
+ spatial_transform_layers=None,
536
+ spatial_transform_kwargs={},
537
+ max_features=1024,
538
+ out_ffc=False,
539
+ out_ffc_kwargs={},
540
+ ):
541
+ assert n_blocks >= 0
542
+ super().__init__()
543
+ """
544
+ init_conv_kwargs = {'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False}
545
+ downsample_conv_kwargs = {'ratio_gin': '${generator.init_conv_kwargs.ratio_gout}', 'ratio_gout': '${generator.downsample_conv_kwargs.ratio_gin}', 'enable_lfu': False}
546
+ resnet_conv_kwargs = {'ratio_gin': 0.75, 'ratio_gout': '${generator.resnet_conv_kwargs.ratio_gin}', 'enable_lfu': False}
547
+ spatial_transform_kwargs = {}
548
+ out_ffc_kwargs = {}
549
+ """
550
+ """
551
+ print(input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer,
552
+ padding_type, activation_layer,
553
+ up_norm_layer, up_activation,
554
+ spatial_transform_layers,
555
+ add_out_act, max_features, out_ffc, file=sys.stderr)
556
+
557
+ 4 3 64 3 18 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
558
+ reflect <class 'torch.nn.modules.activation.ReLU'>
559
+ <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
560
+ ReLU(inplace=True)
561
+ None sigmoid 1024 False
562
+ """
563
+ init_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
564
+ downsample_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
565
+ resnet_conv_kwargs = {
566
+ "ratio_gin": 0.75,
567
+ "ratio_gout": 0.75,
568
+ "enable_lfu": False,
569
+ }
570
+ spatial_transform_kwargs = {}
571
+ out_ffc_kwargs = {}
572
+
573
+ model = [
574
+ nn.ReflectionPad2d(3),
575
+ FFC_BN_ACT(
576
+ input_nc,
577
+ ngf,
578
+ kernel_size=7,
579
+ padding=0,
580
+ norm_layer=norm_layer,
581
+ activation_layer=activation_layer,
582
+ **init_conv_kwargs,
583
+ ),
584
+ ]
585
+
586
+ ### downsample
587
+ for i in range(n_downsampling):
588
+ mult = 2**i
589
+ if i == n_downsampling - 1:
590
+ cur_conv_kwargs = dict(downsample_conv_kwargs)
591
+ cur_conv_kwargs["ratio_gout"] = resnet_conv_kwargs.get("ratio_gin", 0)
592
+ else:
593
+ cur_conv_kwargs = downsample_conv_kwargs
594
+ model += [
595
+ FFC_BN_ACT(
596
+ min(max_features, ngf * mult),
597
+ min(max_features, ngf * mult * 2),
598
+ kernel_size=3,
599
+ stride=2,
600
+ padding=1,
601
+ norm_layer=norm_layer,
602
+ activation_layer=activation_layer,
603
+ **cur_conv_kwargs,
604
+ )
605
+ ]
606
+
607
+ mult = 2**n_downsampling
608
+ feats_num_bottleneck = min(max_features, ngf * mult)
609
+
610
+ ### resnet blocks
611
+ for i in range(n_blocks):
612
+ cur_resblock = FFCResnetBlock(
613
+ feats_num_bottleneck,
614
+ padding_type=padding_type,
615
+ activation_layer=activation_layer,
616
+ norm_layer=norm_layer,
617
+ **resnet_conv_kwargs,
618
+ )
619
+ if spatial_transform_layers is not None and i in spatial_transform_layers:
620
+ cur_resblock = LearnableSpatialTransformWrapper(
621
+ cur_resblock, **spatial_transform_kwargs
622
+ )
623
+ model += [cur_resblock]
624
+
625
+ model += [ConcatTupleLayer()]
626
+
627
+ ### upsample
628
+ for i in range(n_downsampling):
629
+ mult = 2 ** (n_downsampling - i)
630
+ model += [
631
+ nn.ConvTranspose2d(
632
+ min(max_features, ngf * mult),
633
+ min(max_features, int(ngf * mult / 2)),
634
+ kernel_size=3,
635
+ stride=2,
636
+ padding=1,
637
+ output_padding=1,
638
+ ),
639
+ up_norm_layer(min(max_features, int(ngf * mult / 2))),
640
+ up_activation,
641
+ ]
642
+
643
+ if out_ffc:
644
+ model += [
645
+ FFCResnetBlock(
646
+ ngf,
647
+ padding_type=padding_type,
648
+ activation_layer=activation_layer,
649
+ norm_layer=norm_layer,
650
+ inline=True,
651
+ **out_ffc_kwargs,
652
+ )
653
+ ]
654
+
655
+ model += [
656
+ nn.ReflectionPad2d(3),
657
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
658
+ ]
659
+ model.append(nn.Sigmoid())
660
+ self.model = nn.Sequential(*model)
661
+
662
+ def forward(self, image, mask):
663
+ return self.model(torch.cat([image, mask], dim=1))
664
+
665
+
666
+ class LaMa(nn.Module):
667
+ def __init__(self, state_dict) -> None:
668
+ super(LaMa, self).__init__()
669
+ self.model_arch = "LaMa"
670
+ self.sub_type = "Inpaint"
671
+ self.in_nc = 4
672
+ self.out_nc = 3
673
+ self.scale = 1
674
+
675
+ self.min_size = None
676
+ self.pad_mod = 8
677
+ self.pad_to_square = False
678
+
679
+ self.model = FFCResNetGenerator(self.in_nc, self.out_nc)
680
+ self.state = {
681
+ k.replace("generator.model", "model.model"): v
682
+ for k, v in state_dict.items()
683
+ }
684
+
685
+ self.supports_fp16 = False
686
+ self.support_bf16 = True
687
+
688
+ self.load_state_dict(self.state, strict=False)
689
+
690
+ def forward(self, img, mask):
691
+ masked_img = img * (1 - mask)
692
+ inpainted_mask = mask * self.model.forward(masked_img, mask)
693
+ result = inpainted_mask + (1 - mask) * img
694
+ return result
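For orientation, the following is a minimal, hedged sketch of how the `LaMa` wrapper defined above can be exercised end to end; it is not part of the uploaded file. The empty state dict (random weights) and the 256×256 dummy tensors are assumptions for illustration only; in practice `state_dict` would come from a real LaMa checkpoint, and the spatial size must be a multiple of `pad_mod` (8).

```python
import torch

# LaMa is the wrapper class defined in LaMa.py above.
# An empty state dict keeps the sketch self-contained (strict=False load);
# a real run would pass the state dict of an actual LaMa checkpoint instead.
model = LaMa({}).eval()

# img: RGB image in [0, 1], shape (1, 3, H, W); mask: 1 marks the hole, shape (1, 1, H, W).
# H and W are multiples of model.pad_mod (8) here, matching the wrapper's padding contract.
img = torch.rand(1, 3, 256, 256)
mask = (torch.rand(1, 1, 256, 256) > 0.8).float()

with torch.no_grad():
    out = model(img, mask)  # holes are filled by the generator, known pixels pass through

print(out.shape)  # torch.Size([1, 3, 256, 256])
```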
comfy_extras/chainner_models/architecture/MAT.py ADDED
@@ -0,0 +1,1636 @@
1
+ # pylint: skip-file
2
+ """Original MAT project is copyright of fenglingwb: https://github.com/fenglinglwb/MAT
3
+ Code used for this implementation of MAT is modified from lama-cleaner,
4
+ copyright of Sanster: https://github.com/Sanster/lama-cleaner"""
5
+
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint as checkpoint
13
+
14
+ from .mat.utils import (
15
+ Conv2dLayer,
16
+ FullyConnectedLayer,
17
+ activation_funcs,
18
+ bias_act,
19
+ conv2d_resample,
20
+ normalize_2nd_moment,
21
+ setup_filter,
22
+ to_2tuple,
23
+ upsample2d,
24
+ )
25
+
26
+
27
+ class ModulatedConv2d(nn.Module):
28
+ def __init__(
29
+ self,
30
+ in_channels, # Number of input channels.
31
+ out_channels, # Number of output channels.
32
+ kernel_size, # Width and height of the convolution kernel.
33
+ style_dim, # dimension of the style code
34
+ demodulate=True, # perform demodulation
35
+ up=1, # Integer upsampling factor.
36
+ down=1, # Integer downsampling factor.
37
+ resample_filter=[
38
+ 1,
39
+ 3,
40
+ 3,
41
+ 1,
42
+ ], # Low-pass filter to apply when resampling activations.
43
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
44
+ ):
45
+ super().__init__()
46
+ self.demodulate = demodulate
47
+
48
+ self.weight = torch.nn.Parameter(
49
+ torch.randn([1, out_channels, in_channels, kernel_size, kernel_size])
50
+ )
51
+ self.out_channels = out_channels
52
+ self.kernel_size = kernel_size
53
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
54
+ self.padding = self.kernel_size // 2
55
+ self.up = up
56
+ self.down = down
57
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
58
+ self.conv_clamp = conv_clamp
59
+
60
+ self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
61
+
62
+ def forward(self, x, style):
63
+ batch, in_channels, height, width = x.shape
64
+ style = self.affine(style).view(batch, 1, in_channels, 1, 1).to(x.device)
65
+ weight = self.weight.to(x.device) * self.weight_gain * style
66
+
67
+ if self.demodulate:
68
+ decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
69
+ weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
70
+
71
+ weight = weight.view(
72
+ batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size
73
+ )
74
+ x = x.view(1, batch * in_channels, height, width)
75
+ x = conv2d_resample(
76
+ x=x,
77
+ w=weight,
78
+ f=self.resample_filter,
79
+ up=self.up,
80
+ down=self.down,
81
+ padding=self.padding,
82
+ groups=batch,
83
+ )
84
+ out = x.view(batch, self.out_channels, *x.shape[2:])
85
+
86
+ return out
87
+
88
+
89
+ class StyleConv(torch.nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels, # Number of input channels.
93
+ out_channels, # Number of output channels.
94
+ style_dim, # Intermediate latent (W) dimensionality.
95
+ resolution, # Resolution of this layer.
96
+ kernel_size=3, # Convolution kernel size.
97
+ up=1, # Integer upsampling factor.
98
+ use_noise=False, # Enable noise input?
99
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
100
+ resample_filter=[
101
+ 1,
102
+ 3,
103
+ 3,
104
+ 1,
105
+ ], # Low-pass filter to apply when resampling activations.
106
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
107
+ demodulate=True, # perform demodulation
108
+ ):
109
+ super().__init__()
110
+
111
+ self.conv = ModulatedConv2d(
112
+ in_channels=in_channels,
113
+ out_channels=out_channels,
114
+ kernel_size=kernel_size,
115
+ style_dim=style_dim,
116
+ demodulate=demodulate,
117
+ up=up,
118
+ resample_filter=resample_filter,
119
+ conv_clamp=conv_clamp,
120
+ )
121
+
122
+ self.use_noise = use_noise
123
+ self.resolution = resolution
124
+ if use_noise:
125
+ self.register_buffer("noise_const", torch.randn([resolution, resolution]))
126
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
127
+
128
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
129
+ self.activation = activation
130
+ self.act_gain = activation_funcs[activation].def_gain
131
+ self.conv_clamp = conv_clamp
132
+
133
+ def forward(self, x, style, noise_mode="random", gain=1):
134
+ x = self.conv(x, style)
135
+
136
+ assert noise_mode in ["random", "const", "none"]
137
+
138
+ if self.use_noise:
139
+ if noise_mode == "random":
140
+ xh, xw = x.size()[-2:]
141
+ noise = (
142
+ torch.randn([x.shape[0], 1, xh, xw], device=x.device)
143
+ * self.noise_strength
144
+ )
145
+ if noise_mode == "const":
146
+ noise = self.noise_const * self.noise_strength
147
+ x = x + noise
148
+
149
+ act_gain = self.act_gain * gain
150
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
151
+ out = bias_act(
152
+ x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
153
+ )
154
+
155
+ return out
156
+
157
+
158
+ class ToRGB(torch.nn.Module):
159
+ def __init__(
160
+ self,
161
+ in_channels,
162
+ out_channels,
163
+ style_dim,
164
+ kernel_size=1,
165
+ resample_filter=[1, 3, 3, 1],
166
+ conv_clamp=None,
167
+ demodulate=False,
168
+ ):
169
+ super().__init__()
170
+
171
+ self.conv = ModulatedConv2d(
172
+ in_channels=in_channels,
173
+ out_channels=out_channels,
174
+ kernel_size=kernel_size,
175
+ style_dim=style_dim,
176
+ demodulate=demodulate,
177
+ resample_filter=resample_filter,
178
+ conv_clamp=conv_clamp,
179
+ )
180
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
181
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
182
+ self.conv_clamp = conv_clamp
183
+
184
+ def forward(self, x, style, skip=None):
185
+ x = self.conv(x, style)
186
+ out = bias_act(x, self.bias, clamp=self.conv_clamp)
187
+
188
+ if skip is not None:
189
+ if skip.shape != out.shape:
190
+ skip = upsample2d(skip, self.resample_filter)
191
+ out = out + skip
192
+
193
+ return out
194
+
195
+
196
+ def get_style_code(a, b):
197
+ return torch.cat([a, b.to(a.device)], dim=1)
198
+
199
+
200
+ class DecBlockFirst(nn.Module):
201
+ def __init__(
202
+ self,
203
+ in_channels,
204
+ out_channels,
205
+ activation,
206
+ style_dim,
207
+ use_noise,
208
+ demodulate,
209
+ img_channels,
210
+ ):
211
+ super().__init__()
212
+ self.fc = FullyConnectedLayer(
213
+ in_features=in_channels * 2,
214
+ out_features=in_channels * 4**2,
215
+ activation=activation,
216
+ )
217
+ self.conv = StyleConv(
218
+ in_channels=in_channels,
219
+ out_channels=out_channels,
220
+ style_dim=style_dim,
221
+ resolution=4,
222
+ kernel_size=3,
223
+ use_noise=use_noise,
224
+ activation=activation,
225
+ demodulate=demodulate,
226
+ )
227
+ self.toRGB = ToRGB(
228
+ in_channels=out_channels,
229
+ out_channels=img_channels,
230
+ style_dim=style_dim,
231
+ kernel_size=1,
232
+ demodulate=False,
233
+ )
234
+
235
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
236
+ x = self.fc(x).view(x.shape[0], -1, 4, 4)
237
+ x = x + E_features[2]
238
+ style = get_style_code(ws[:, 0], gs)
239
+ x = self.conv(x, style, noise_mode=noise_mode)
240
+ style = get_style_code(ws[:, 1], gs)
241
+ img = self.toRGB(x, style, skip=None)
242
+
243
+ return x, img
244
+
245
+
246
+ class MappingNet(torch.nn.Module):
247
+ def __init__(
248
+ self,
249
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
250
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
251
+ w_dim, # Intermediate latent (W) dimensionality.
252
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
253
+ num_layers=8, # Number of mapping layers.
254
+ embed_features=None, # Label embedding dimensionality, None = same as w_dim.
255
+ layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
256
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
257
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
258
+ w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
259
+ ):
260
+ super().__init__()
261
+ self.z_dim = z_dim
262
+ self.c_dim = c_dim
263
+ self.w_dim = w_dim
264
+ self.num_ws = num_ws
265
+ self.num_layers = num_layers
266
+ self.w_avg_beta = w_avg_beta
267
+
268
+ if embed_features is None:
269
+ embed_features = w_dim
270
+ if c_dim == 0:
271
+ embed_features = 0
272
+ if layer_features is None:
273
+ layer_features = w_dim
274
+ features_list = (
275
+ [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
276
+ )
277
+
278
+ if c_dim > 0:
279
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
280
+ for idx in range(num_layers):
281
+ in_features = features_list[idx]
282
+ out_features = features_list[idx + 1]
283
+ layer = FullyConnectedLayer(
284
+ in_features,
285
+ out_features,
286
+ activation=activation,
287
+ lr_multiplier=lr_multiplier,
288
+ )
289
+ setattr(self, f"fc{idx}", layer)
290
+
291
+ if num_ws is not None and w_avg_beta is not None:
292
+ self.register_buffer("w_avg", torch.zeros([w_dim]))
293
+
294
+ def forward(
295
+ self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
296
+ ):
297
+ # Embed, normalize, and concat inputs.
298
+ x = None
299
+ with torch.autograd.profiler.record_function("input"):
300
+ if self.z_dim > 0:
301
+ x = normalize_2nd_moment(z.to(torch.float32))
302
+ if self.c_dim > 0:
303
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
304
+ x = torch.cat([x, y], dim=1) if x is not None else y
305
+
306
+ # Main layers.
307
+ for idx in range(self.num_layers):
308
+ layer = getattr(self, f"fc{idx}")
309
+ x = layer(x)
310
+
311
+ # Update moving average of W.
312
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
313
+ with torch.autograd.profiler.record_function("update_w_avg"):
314
+ self.w_avg.copy_(
315
+ x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
316
+ )
317
+
318
+ # Broadcast.
319
+ if self.num_ws is not None:
320
+ with torch.autograd.profiler.record_function("broadcast"):
321
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
322
+
323
+ # Apply truncation.
324
+ if truncation_psi != 1:
325
+ with torch.autograd.profiler.record_function("truncate"):
326
+ assert self.w_avg_beta is not None
327
+ if self.num_ws is None or truncation_cutoff is None:
328
+ x = self.w_avg.lerp(x, truncation_psi)
329
+ else:
330
+ x[:, :truncation_cutoff] = self.w_avg.lerp(
331
+ x[:, :truncation_cutoff], truncation_psi
332
+ )
333
+
334
+ return x
335
+
336
+
337
+ class DisFromRGB(nn.Module):
338
+ def __init__(
339
+ self, in_channels, out_channels, activation
340
+ ): # res = 2, ..., resolution_log2
341
+ super().__init__()
342
+ self.conv = Conv2dLayer(
343
+ in_channels=in_channels,
344
+ out_channels=out_channels,
345
+ kernel_size=1,
346
+ activation=activation,
347
+ )
348
+
349
+ def forward(self, x):
350
+ return self.conv(x)
351
+
352
+
353
+ class DisBlock(nn.Module):
354
+ def __init__(
355
+ self, in_channels, out_channels, activation
356
+ ): # res = 2, ..., resolution_log2
357
+ super().__init__()
358
+ self.conv0 = Conv2dLayer(
359
+ in_channels=in_channels,
360
+ out_channels=in_channels,
361
+ kernel_size=3,
362
+ activation=activation,
363
+ )
364
+ self.conv1 = Conv2dLayer(
365
+ in_channels=in_channels,
366
+ out_channels=out_channels,
367
+ kernel_size=3,
368
+ down=2,
369
+ activation=activation,
370
+ )
371
+ self.skip = Conv2dLayer(
372
+ in_channels=in_channels,
373
+ out_channels=out_channels,
374
+ kernel_size=1,
375
+ down=2,
376
+ bias=False,
377
+ )
378
+
379
+ def forward(self, x):
380
+ skip = self.skip(x, gain=np.sqrt(0.5))
381
+ x = self.conv0(x)
382
+ x = self.conv1(x, gain=np.sqrt(0.5))
383
+ out = skip + x
384
+
385
+ return out
386
+
387
+
388
+ def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
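+ # Note: channel_base / channel_decay / channel_max are accepted but unused here;
+ # channel widths come from the fixed NF lookup table below.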
389
+ NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
390
+ return NF[2**stage]
391
+
392
+
393
+ class Mlp(nn.Module):
394
+ def __init__(
395
+ self,
396
+ in_features,
397
+ hidden_features=None,
398
+ out_features=None,
399
+ act_layer=nn.GELU,
400
+ drop=0.0,
401
+ ):
402
+ super().__init__()
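+ # Note: act_layer and drop are accepted for API parity but are not used;
+ # fc1 applies a fixed lrelu activation and no dropout is applied.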
403
+ out_features = out_features or in_features
404
+ hidden_features = hidden_features or in_features
405
+ self.fc1 = FullyConnectedLayer(
406
+ in_features=in_features, out_features=hidden_features, activation="lrelu"
407
+ )
408
+ self.fc2 = FullyConnectedLayer(
409
+ in_features=hidden_features, out_features=out_features
410
+ )
411
+
412
+ def forward(self, x):
413
+ x = self.fc1(x)
414
+ x = self.fc2(x)
415
+ return x
416
+
417
+
418
+ def window_partition(x, window_size):
419
+ """
420
+ Args:
421
+ x: (B, H, W, C)
422
+ window_size (int): window size
423
+ Returns:
424
+ windows: (num_windows*B, window_size, window_size, C)
425
+ """
426
+ B, H, W, C = x.shape
427
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
428
+ windows = (
429
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
430
+ )
431
+ return windows
432
+
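+ # Shape sketch (assuming B=1, H=W=64, C=180, window_size=8):
+ # x (1, 64, 64, 180) -> windows (64, 8, 8, 180), i.e. (64/8) * (64/8) = 64 windows.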
433
+
434
+ def window_reverse(windows, window_size: int, H: int, W: int):
435
+ """
436
+ Args:
437
+ windows: (num_windows*B, window_size, window_size, C)
438
+ window_size (int): Window size
439
+ H (int): Height of image
440
+ W (int): Width of image
441
+ Returns:
442
+ x: (B, H, W, C)
443
+ """
444
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
445
+ # B = windows.shape[0] / (H * W / window_size / window_size)
446
+ x = windows.view(
447
+ B, H // window_size, W // window_size, window_size, window_size, -1
448
+ )
449
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
450
+ return x
451
+
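+ # Inverse of window_partition: with the same assumed shapes,
+ # windows (64, 8, 8, 180) with H=W=64 -> x (1, 64, 64, 180).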
452
+
453
+ class Conv2dLayerPartial(nn.Module):
454
+ def __init__(
455
+ self,
456
+ in_channels, # Number of input channels.
457
+ out_channels, # Number of output channels.
458
+ kernel_size, # Width and height of the convolution kernel.
459
+ bias=True, # Apply additive bias before the activation function?
460
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
461
+ up=1, # Integer upsampling factor.
462
+ down=1, # Integer downsampling factor.
463
+ resample_filter=[
464
+ 1,
465
+ 3,
466
+ 3,
467
+ 1,
468
+ ], # Low-pass filter to apply when resampling activations.
469
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
470
+ trainable=True, # Update the weights of this layer during training?
471
+ ):
472
+ super().__init__()
473
+ self.conv = Conv2dLayer(
474
+ in_channels,
475
+ out_channels,
476
+ kernel_size,
477
+ bias,
478
+ activation,
479
+ up,
480
+ down,
481
+ resample_filter,
482
+ conv_clamp,
483
+ trainable,
484
+ )
485
+
486
+ self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
487
+ self.slide_winsize = kernel_size**2
488
+ self.stride = down
489
+ self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
490
+
491
+ def forward(self, x, mask=None):
492
+ if mask is not None:
493
+ with torch.no_grad():
494
+ if self.weight_maskUpdater.type() != x.type():
495
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
496
+ update_mask = F.conv2d(
497
+ mask,
498
+ self.weight_maskUpdater,
499
+ bias=None,
500
+ stride=self.stride,
501
+ padding=self.padding,
502
+ )
503
+ mask_ratio = self.slide_winsize / (update_mask + 1e-8)
504
+ update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
505
+ mask_ratio = torch.mul(mask_ratio, update_mask)
506
+ x = self.conv(x)
507
+ x = torch.mul(x, mask_ratio)
508
+ return x, update_mask
509
+ else:
510
+ x = self.conv(x)
511
+ return x, None
512
+
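+ # Partial-convolution style renormalization: update_mask counts the valid mask
+ # pixels under the kernel, the output is rescaled by slide_winsize / count, and
+ # the returned mask marks every position the kernel touched as valid (0 or 1).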
513
+
514
+ class WindowAttention(nn.Module):
515
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
516
+ It supports both shifted and non-shifted windows.
517
+ Args:
518
+ dim (int): Number of input channels.
519
+ window_size (tuple[int]): The height and width of the window.
520
+ num_heads (int): Number of attention heads.
521
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
522
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
523
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
524
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
525
+ """
526
+
527
+ def __init__(
528
+ self,
529
+ dim,
530
+ window_size,
531
+ num_heads,
532
+ down_ratio=1,
533
+ qkv_bias=True,
534
+ qk_scale=None,
535
+ attn_drop=0.0,
536
+ proj_drop=0.0,
537
+ ):
538
+ super().__init__()
539
+ self.dim = dim
540
+ self.window_size = window_size # Wh, Ww
541
+ self.num_heads = num_heads
542
+ head_dim = dim // num_heads
543
+ self.scale = qk_scale or head_dim**-0.5
544
+
545
+ self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
546
+ self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
547
+ self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
548
+ self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
549
+
550
+ self.softmax = nn.Softmax(dim=-1)
551
+
552
+ def forward(self, x, mask_windows=None, mask=None):
553
+ """
554
+ Args:
555
+ x: input features with shape of (num_windows*B, N, C)
556
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
557
+ """
558
+ B_, N, C = x.shape
559
+ norm_x = F.normalize(x, p=2.0, dim=-1)
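+ # queries and keys are computed from L2-normalized tokens (cosine-style
+ # attention); values are computed from the raw tokens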
560
+ q = (
561
+ self.q(norm_x)
562
+ .reshape(B_, N, self.num_heads, C // self.num_heads)
563
+ .permute(0, 2, 1, 3)
564
+ )
565
+ k = (
566
+ self.k(norm_x)
567
+ .view(B_, -1, self.num_heads, C // self.num_heads)
568
+ .permute(0, 2, 3, 1)
569
+ )
570
+ v = (
571
+ self.v(x)
572
+ .view(B_, -1, self.num_heads, C // self.num_heads)
573
+ .permute(0, 2, 1, 3)
574
+ )
575
+
576
+ attn = (q @ k) * self.scale
577
+
578
+ if mask is not None:
579
+ nW = mask.shape[0]
580
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
581
+ 1
582
+ ).unsqueeze(0)
583
+ attn = attn.view(-1, self.num_heads, N, N)
584
+
585
+ if mask_windows is not None:
586
+ attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
587
+ attn = attn + attn_mask_windows.masked_fill(
588
+ attn_mask_windows == 0, float(-100.0)
589
+ ).masked_fill(attn_mask_windows == 1, float(0.0))
590
+ with torch.no_grad():
591
+ mask_windows = torch.clamp(
592
+ torch.sum(mask_windows, dim=1, keepdim=True), 0, 1
593
+ ).repeat(1, N, 1)
594
+
595
+ attn = self.softmax(attn)
596
+
597
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
598
+ x = self.proj(x)
599
+ return x, mask_windows
600
+
601
+
602
+ class SwinTransformerBlock(nn.Module):
603
+ r"""Swin Transformer Block.
604
+ Args:
605
+ dim (int): Number of input channels.
606
+ input_resolution (tuple[int]): Input resolution.
607
+ num_heads (int): Number of attention heads.
608
+ window_size (int): Window size.
609
+ shift_size (int): Shift size for SW-MSA.
610
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
611
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
612
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
613
+ drop (float, optional): Dropout rate. Default: 0.0
614
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
615
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
616
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
617
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
618
+ """
619
+
620
+ def __init__(
621
+ self,
622
+ dim,
623
+ input_resolution,
624
+ num_heads,
625
+ down_ratio=1,
626
+ window_size=7,
627
+ shift_size=0,
628
+ mlp_ratio=4.0,
629
+ qkv_bias=True,
630
+ qk_scale=None,
631
+ drop=0.0,
632
+ attn_drop=0.0,
633
+ drop_path=0.0,
634
+ act_layer=nn.GELU,
635
+ norm_layer=nn.LayerNorm,
636
+ ):
637
+ super().__init__()
638
+ self.dim = dim
639
+ self.input_resolution = input_resolution
640
+ self.num_heads = num_heads
641
+ self.window_size = window_size
642
+ self.shift_size = shift_size
643
+ self.mlp_ratio = mlp_ratio
644
+ if min(self.input_resolution) <= self.window_size:
645
+ # if window size is larger than input resolution, we don't partition windows
646
+ self.shift_size = 0
647
+ self.window_size = min(self.input_resolution)
648
+ assert (
649
+ 0 <= self.shift_size < self.window_size
650
+ ), "shift_size must in 0-window_size"
651
+
652
+ if self.shift_size > 0:
653
+ down_ratio = 1
654
+ self.attn = WindowAttention(
655
+ dim,
656
+ window_size=to_2tuple(self.window_size),
657
+ num_heads=num_heads,
658
+ down_ratio=down_ratio,
659
+ qkv_bias=qkv_bias,
660
+ qk_scale=qk_scale,
661
+ attn_drop=attn_drop,
662
+ proj_drop=drop,
663
+ )
664
+
665
+ self.fuse = FullyConnectedLayer(
666
+ in_features=dim * 2, out_features=dim, activation="lrelu"
667
+ )
668
+
669
+ mlp_hidden_dim = int(dim * mlp_ratio)
670
+ self.mlp = Mlp(
671
+ in_features=dim,
672
+ hidden_features=mlp_hidden_dim,
673
+ act_layer=act_layer,
674
+ drop=drop,
675
+ )
676
+
677
+ if self.shift_size > 0:
678
+ attn_mask = self.calculate_mask(self.input_resolution)
679
+ else:
680
+ attn_mask = None
681
+
682
+ self.register_buffer("attn_mask", attn_mask)
683
+
684
+ def calculate_mask(self, x_size):
685
+ # calculate attention mask for SW-MSA
686
+ H, W = x_size
687
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
688
+ h_slices = (
689
+ slice(0, -self.window_size),
690
+ slice(-self.window_size, -self.shift_size),
691
+ slice(-self.shift_size, None),
692
+ )
693
+ w_slices = (
694
+ slice(0, -self.window_size),
695
+ slice(-self.window_size, -self.shift_size),
696
+ slice(-self.shift_size, None),
697
+ )
698
+ cnt = 0
699
+ for h in h_slices:
700
+ for w in w_slices:
701
+ img_mask[:, h, w, :] = cnt
702
+ cnt += 1
703
+
704
+ mask_windows = window_partition(
705
+ img_mask, self.window_size
706
+ ) # nW, window_size, window_size, 1
707
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
708
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
709
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
710
+ attn_mask == 0, float(0.0)
711
+ )
712
+
713
+ return attn_mask
714
+
715
+ def forward(self, x, x_size, mask=None):
716
+ # H, W = self.input_resolution
717
+ H, W = x_size
718
+ B, _, C = x.shape
719
+ # assert L == H * W, "input feature has wrong size"
720
+
721
+ shortcut = x
722
+ x = x.view(B, H, W, C)
723
+ if mask is not None:
724
+ mask = mask.view(B, H, W, 1)
725
+
726
+ # cyclic shift
727
+ if self.shift_size > 0:
728
+ shifted_x = torch.roll(
729
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
730
+ )
731
+ if mask is not None:
732
+ shifted_mask = torch.roll(
733
+ mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
734
+ )
735
+ else:
736
+ shifted_x = x
737
+ if mask is not None:
738
+ shifted_mask = mask
739
+
740
+ # partition windows
741
+ x_windows = window_partition(
742
+ shifted_x, self.window_size
743
+ ) # nW*B, window_size, window_size, C
744
+ x_windows = x_windows.view(
745
+ -1, self.window_size * self.window_size, C
746
+ ) # nW*B, window_size*window_size, C
747
+ if mask is not None:
748
+ mask_windows = window_partition(shifted_mask, self.window_size)
749
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
750
+ else:
751
+ mask_windows = None
752
+
753
+ # W-MSA/SW-MSA (kept compatible with testing on images whose shapes are multiples of the window size)
754
+ if self.input_resolution == x_size:
755
+ attn_windows, mask_windows = self.attn(
756
+ x_windows, mask_windows, mask=self.attn_mask
757
+ ) # nW*B, window_size*window_size, C
758
+ else:
759
+ attn_windows, mask_windows = self.attn(
760
+ x_windows, mask_windows, mask=self.calculate_mask(x_size).to(x.device)
761
+ ) # nW*B, window_size*window_size, C
762
+
763
+ # merge windows
764
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
765
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
766
+ if mask is not None:
767
+ mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
768
+ shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
769
+
770
+ # reverse cyclic shift
771
+ if self.shift_size > 0:
772
+ x = torch.roll(
773
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
774
+ )
775
+ if mask is not None:
776
+ mask = torch.roll(
777
+ shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
778
+ )
779
+ else:
780
+ x = shifted_x
781
+ if mask is not None:
782
+ mask = shifted_mask
783
+ x = x.view(B, H * W, C)
784
+ if mask is not None:
785
+ mask = mask.view(B, H * W, 1)
786
+
787
+ # FFN
788
+ x = self.fuse(torch.cat([shortcut, x], dim=-1))
789
+ x = self.mlp(x)
790
+
791
+ return x, mask
792
+
793
+
794
+ class PatchMerging(nn.Module):
795
+ def __init__(self, in_channels, out_channels, down=2):
796
+ super().__init__()
797
+ self.conv = Conv2dLayerPartial(
798
+ in_channels=in_channels,
799
+ out_channels=out_channels,
800
+ kernel_size=3,
801
+ activation="lrelu",
802
+ down=down,
803
+ )
804
+ self.down = down
805
+
806
+ def forward(self, x, x_size, mask=None):
807
+ x = token2feature(x, x_size)
808
+ if mask is not None:
809
+ mask = token2feature(mask, x_size)
810
+ x, mask = self.conv(x, mask)
811
+ if self.down != 1:
812
+ ratio = 1 / self.down
813
+ x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
814
+ x = feature2token(x)
815
+ if mask is not None:
816
+ mask = feature2token(mask)
817
+ return x, x_size, mask
818
+
819
+
820
+ class PatchUpsampling(nn.Module):
821
+ def __init__(self, in_channels, out_channels, up=2):
822
+ super().__init__()
823
+ self.conv = Conv2dLayerPartial(
824
+ in_channels=in_channels,
825
+ out_channels=out_channels,
826
+ kernel_size=3,
827
+ activation="lrelu",
828
+ up=up,
829
+ )
830
+ self.up = up
831
+
832
+ def forward(self, x, x_size, mask=None):
833
+ x = token2feature(x, x_size)
834
+ if mask is not None:
835
+ mask = token2feature(mask, x_size)
836
+ x, mask = self.conv(x, mask)
837
+ if self.up != 1:
838
+ x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
839
+ x = feature2token(x)
840
+ if mask is not None:
841
+ mask = feature2token(mask)
842
+ return x, x_size, mask
843
+
844
+
845
+ class BasicLayer(nn.Module):
846
+ """A basic Swin Transformer layer for one stage.
847
+ Args:
848
+ dim (int): Number of input channels.
849
+ input_resolution (tuple[int]): Input resolution.
850
+ depth (int): Number of blocks.
851
+ num_heads (int): Number of attention heads.
852
+ window_size (int): Local window size.
853
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
854
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
855
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
856
+ drop (float, optional): Dropout rate. Default: 0.0
857
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
858
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
859
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
860
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
861
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
862
+ """
863
+
864
+ def __init__(
865
+ self,
866
+ dim,
867
+ input_resolution,
868
+ depth,
869
+ num_heads,
870
+ window_size,
871
+ down_ratio=1,
872
+ mlp_ratio=2.0,
873
+ qkv_bias=True,
874
+ qk_scale=None,
875
+ drop=0.0,
876
+ attn_drop=0.0,
877
+ drop_path=0.0,
878
+ norm_layer=nn.LayerNorm,
879
+ downsample=None,
880
+ use_checkpoint=False,
881
+ ):
882
+ super().__init__()
883
+ self.dim = dim
884
+ self.input_resolution = input_resolution
885
+ self.depth = depth
886
+ self.use_checkpoint = use_checkpoint
887
+
888
+ # patch merging layer
889
+ if downsample is not None:
890
+ # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
891
+ self.downsample = downsample
892
+ else:
893
+ self.downsample = None
894
+
895
+ # build blocks
896
+ self.blocks = nn.ModuleList(
897
+ [
898
+ SwinTransformerBlock(
899
+ dim=dim,
900
+ input_resolution=input_resolution,
901
+ num_heads=num_heads,
902
+ down_ratio=down_ratio,
903
+ window_size=window_size,
904
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
905
+ mlp_ratio=mlp_ratio,
906
+ qkv_bias=qkv_bias,
907
+ qk_scale=qk_scale,
908
+ drop=drop,
909
+ attn_drop=attn_drop,
910
+ drop_path=drop_path[i]
911
+ if isinstance(drop_path, list)
912
+ else drop_path,
913
+ norm_layer=norm_layer,
914
+ )
915
+ for i in range(depth)
916
+ ]
917
+ )
918
+
919
+ self.conv = Conv2dLayerPartial(
920
+ in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu"
921
+ )
922
+
923
+ def forward(self, x, x_size, mask=None):
924
+ if self.downsample is not None:
925
+ x, x_size, mask = self.downsample(x, x_size, mask)
926
+ identity = x
927
+ for blk in self.blocks:
928
+ if self.use_checkpoint:
929
+ x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
930
+ else:
931
+ x, mask = blk(x, x_size, mask)
932
+ if mask is not None:
933
+ mask = token2feature(mask, x_size)
934
+ x, mask = self.conv(token2feature(x, x_size), mask)
935
+ x = feature2token(x) + identity
936
+ if mask is not None:
937
+ mask = feature2token(mask)
938
+ return x, x_size, mask
939
+
940
+
941
+ class ToToken(nn.Module):
942
+ def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
943
+ super().__init__()
944
+
945
+ self.proj = Conv2dLayerPartial(
946
+ in_channels=in_channels,
947
+ out_channels=dim,
948
+ kernel_size=kernel_size,
949
+ activation="lrelu",
950
+ )
951
+
952
+ def forward(self, x, mask):
953
+ x, mask = self.proj(x, mask)
954
+
955
+ return x, mask
956
+
957
+
958
+ class EncFromRGB(nn.Module):
959
+ def __init__(
960
+ self, in_channels, out_channels, activation
961
+ ): # res = 2, ..., resolution_log2
962
+ super().__init__()
963
+ self.conv0 = Conv2dLayer(
964
+ in_channels=in_channels,
965
+ out_channels=out_channels,
966
+ kernel_size=1,
967
+ activation=activation,
968
+ )
969
+ self.conv1 = Conv2dLayer(
970
+ in_channels=out_channels,
971
+ out_channels=out_channels,
972
+ kernel_size=3,
973
+ activation=activation,
974
+ )
975
+
976
+ def forward(self, x):
977
+ x = self.conv0(x)
978
+ x = self.conv1(x)
979
+
980
+ return x
981
+
982
+
983
+ class ConvBlockDown(nn.Module):
984
+ def __init__(
985
+ self, in_channels, out_channels, activation
986
+ ): # res = 2, ..., resolution_log
987
+ super().__init__()
988
+
989
+ self.conv0 = Conv2dLayer(
990
+ in_channels=in_channels,
991
+ out_channels=out_channels,
992
+ kernel_size=3,
993
+ activation=activation,
994
+ down=2,
995
+ )
996
+ self.conv1 = Conv2dLayer(
997
+ in_channels=out_channels,
998
+ out_channels=out_channels,
999
+ kernel_size=3,
1000
+ activation=activation,
1001
+ )
1002
+
1003
+ def forward(self, x):
1004
+ x = self.conv0(x)
1005
+ x = self.conv1(x)
1006
+
1007
+ return x
1008
+
1009
+
1010
+ def token2feature(x, x_size):
1011
+ B, _, C = x.shape
1012
+ h, w = x_size
1013
+ x = x.permute(0, 2, 1).reshape(B, C, h, w)
1014
+ return x
1015
+
1016
+
1017
+ def feature2token(x):
1018
+ B, C, _, _ = x.shape
1019
+ x = x.view(B, C, -1).transpose(1, 2)
1020
+ return x
1021
+
1022
+
1023
+ class Encoder(nn.Module):
1024
+ def __init__(
1025
+ self,
1026
+ res_log2,
1027
+ img_channels,
1028
+ activation,
1029
+ patch_size=5,
1030
+ channels=16,
1031
+ drop_path_rate=0.1,
1032
+ ):
1033
+ super().__init__()
1034
+
1035
+ self.resolution = []
1036
+
1037
+ for i in range(res_log2, 3, -1): # from input size to 16x16
1038
+ res = 2**i
1039
+ self.resolution.append(res)
1040
+ if i == res_log2:
1041
+ block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
1042
+ else:
1043
+ block = ConvBlockDown(nf(i + 1), nf(i), activation)
1044
+ setattr(self, "EncConv_Block_%dx%d" % (res, res), block)
1045
+
1046
+ def forward(self, x):
1047
+ out = {}
1048
+ for res in self.resolution:
1049
+ res_log2 = int(np.log2(res))
1050
+ x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x)
1051
+ out[res_log2] = x
1052
+
1053
+ return out
1054
+
1055
+
1056
+ class ToStyle(nn.Module):
1057
+ def __init__(self, in_channels, out_channels, activation, drop_rate):
1058
+ super().__init__()
1059
+ self.conv = nn.Sequential(
1060
+ Conv2dLayer(
1061
+ in_channels=in_channels,
1062
+ out_channels=in_channels,
1063
+ kernel_size=3,
1064
+ activation=activation,
1065
+ down=2,
1066
+ ),
1067
+ Conv2dLayer(
1068
+ in_channels=in_channels,
1069
+ out_channels=in_channels,
1070
+ kernel_size=3,
1071
+ activation=activation,
1072
+ down=2,
1073
+ ),
1074
+ Conv2dLayer(
1075
+ in_channels=in_channels,
1076
+ out_channels=in_channels,
1077
+ kernel_size=3,
1078
+ activation=activation,
1079
+ down=2,
1080
+ ),
1081
+ )
1082
+
1083
+ self.pool = nn.AdaptiveAvgPool2d(1)
1084
+ self.fc = FullyConnectedLayer(
1085
+ in_features=in_channels, out_features=out_channels, activation=activation
1086
+ )
1087
+ # self.dropout = nn.Dropout(drop_rate)
1088
+
1089
+ def forward(self, x):
1090
+ x = self.conv(x)
1091
+ x = self.pool(x)
1092
+ x = self.fc(x.flatten(start_dim=1))
1093
+ # x = self.dropout(x)
1094
+
1095
+ return x
1096
+
1097
+
1098
+ class DecBlockFirstV2(nn.Module):
1099
+ def __init__(
1100
+ self,
1101
+ res,
1102
+ in_channels,
1103
+ out_channels,
1104
+ activation,
1105
+ style_dim,
1106
+ use_noise,
1107
+ demodulate,
1108
+ img_channels,
1109
+ ):
1110
+ super().__init__()
1111
+ self.res = res
1112
+
1113
+ self.conv0 = Conv2dLayer(
1114
+ in_channels=in_channels,
1115
+ out_channels=in_channels,
1116
+ kernel_size=3,
1117
+ activation=activation,
1118
+ )
1119
+ self.conv1 = StyleConv(
1120
+ in_channels=in_channels,
1121
+ out_channels=out_channels,
1122
+ style_dim=style_dim,
1123
+ resolution=2**res,
1124
+ kernel_size=3,
1125
+ use_noise=use_noise,
1126
+ activation=activation,
1127
+ demodulate=demodulate,
1128
+ )
1129
+ self.toRGB = ToRGB(
1130
+ in_channels=out_channels,
1131
+ out_channels=img_channels,
1132
+ style_dim=style_dim,
1133
+ kernel_size=1,
1134
+ demodulate=False,
1135
+ )
1136
+
1137
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
1138
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
1139
+ x = self.conv0(x)
1140
+ x = x + E_features[self.res]
1141
+ style = get_style_code(ws[:, 0], gs)
1142
+ x = self.conv1(x, style, noise_mode=noise_mode)
1143
+ style = get_style_code(ws[:, 1], gs)
1144
+ img = self.toRGB(x, style, skip=None)
1145
+
1146
+ return x, img
1147
+
1148
+
1149
+ class DecBlock(nn.Module):
1150
+ def __init__(
1151
+ self,
1152
+ res,
1153
+ in_channels,
1154
+ out_channels,
1155
+ activation,
1156
+ style_dim,
1157
+ use_noise,
1158
+ demodulate,
1159
+ img_channels,
1160
+ ): # res = 4, ..., resolution_log2
1161
+ super().__init__()
1162
+ self.res = res
1163
+
1164
+ self.conv0 = StyleConv(
1165
+ in_channels=in_channels,
1166
+ out_channels=out_channels,
1167
+ style_dim=style_dim,
1168
+ resolution=2**res,
1169
+ kernel_size=3,
1170
+ up=2,
1171
+ use_noise=use_noise,
1172
+ activation=activation,
1173
+ demodulate=demodulate,
1174
+ )
1175
+ self.conv1 = StyleConv(
1176
+ in_channels=out_channels,
1177
+ out_channels=out_channels,
1178
+ style_dim=style_dim,
1179
+ resolution=2**res,
1180
+ kernel_size=3,
1181
+ use_noise=use_noise,
1182
+ activation=activation,
1183
+ demodulate=demodulate,
1184
+ )
1185
+ self.toRGB = ToRGB(
1186
+ in_channels=out_channels,
1187
+ out_channels=img_channels,
1188
+ style_dim=style_dim,
1189
+ kernel_size=1,
1190
+ demodulate=False,
1191
+ )
1192
+
1193
+ def forward(self, x, img, ws, gs, E_features, noise_mode="random"):
1194
+ style = get_style_code(ws[:, self.res * 2 - 9], gs)
1195
+ x = self.conv0(x, style, noise_mode=noise_mode)
1196
+ x = x + E_features[self.res]
1197
+ style = get_style_code(ws[:, self.res * 2 - 8], gs)
1198
+ x = self.conv1(x, style, noise_mode=noise_mode)
1199
+ style = get_style_code(ws[:, self.res * 2 - 7], gs)
1200
+ img = self.toRGB(x, style, skip=img)
1201
+
1202
+ return x, img
1203
+
1204
+
1205
+ class Decoder(nn.Module):
1206
+ def __init__(
1207
+ self, res_log2, activation, style_dim, use_noise, demodulate, img_channels
1208
+ ):
1209
+ super().__init__()
1210
+ self.Dec_16x16 = DecBlockFirstV2(
1211
+ 4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels
1212
+ )
1213
+ for res in range(5, res_log2 + 1):
1214
+ setattr(
1215
+ self,
1216
+ "Dec_%dx%d" % (2**res, 2**res),
1217
+ DecBlock(
1218
+ res,
1219
+ nf(res - 1),
1220
+ nf(res),
1221
+ activation,
1222
+ style_dim,
1223
+ use_noise,
1224
+ demodulate,
1225
+ img_channels,
1226
+ ),
1227
+ )
1228
+ self.res_log2 = res_log2
1229
+
1230
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
1231
+ x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
1232
+ for res in range(5, self.res_log2 + 1):
1233
+ block = getattr(self, "Dec_%dx%d" % (2**res, 2**res))
1234
+ x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
1235
+
1236
+ return img
1237
+
1238
+
1239
+ class DecStyleBlock(nn.Module):
1240
+ def __init__(
1241
+ self,
1242
+ res,
1243
+ in_channels,
1244
+ out_channels,
1245
+ activation,
1246
+ style_dim,
1247
+ use_noise,
1248
+ demodulate,
1249
+ img_channels,
1250
+ ):
1251
+ super().__init__()
1252
+ self.res = res
1253
+
1254
+ self.conv0 = StyleConv(
1255
+ in_channels=in_channels,
1256
+ out_channels=out_channels,
1257
+ style_dim=style_dim,
1258
+ resolution=2**res,
1259
+ kernel_size=3,
1260
+ up=2,
1261
+ use_noise=use_noise,
1262
+ activation=activation,
1263
+ demodulate=demodulate,
1264
+ )
1265
+ self.conv1 = StyleConv(
1266
+ in_channels=out_channels,
1267
+ out_channels=out_channels,
1268
+ style_dim=style_dim,
1269
+ resolution=2**res,
1270
+ kernel_size=3,
1271
+ use_noise=use_noise,
1272
+ activation=activation,
1273
+ demodulate=demodulate,
1274
+ )
1275
+ self.toRGB = ToRGB(
1276
+ in_channels=out_channels,
1277
+ out_channels=img_channels,
1278
+ style_dim=style_dim,
1279
+ kernel_size=1,
1280
+ demodulate=False,
1281
+ )
1282
+
1283
+ def forward(self, x, img, style, skip, noise_mode="random"):
1284
+ x = self.conv0(x, style, noise_mode=noise_mode)
1285
+ x = x + skip
1286
+ x = self.conv1(x, style, noise_mode=noise_mode)
1287
+ img = self.toRGB(x, style, skip=img)
1288
+
1289
+ return x, img
1290
+
1291
+
1292
+ class FirstStage(nn.Module):
1293
+ def __init__(
1294
+ self,
1295
+ img_channels,
1296
+ img_resolution=256,
1297
+ dim=180,
1298
+ w_dim=512,
1299
+ use_noise=False,
1300
+ demodulate=True,
1301
+ activation="lrelu",
1302
+ ):
1303
+ super().__init__()
1304
+ res = 64
1305
+
1306
+ self.conv_first = Conv2dLayerPartial(
1307
+ in_channels=img_channels + 1,
1308
+ out_channels=dim,
1309
+ kernel_size=3,
1310
+ activation=activation,
1311
+ )
1312
+ self.enc_conv = nn.ModuleList()
1313
+ down_time = int(np.log2(img_resolution // res))
1314
+ # build the Swin Transformer stages according to the input image size
1315
+ for i in range(down_time): # from input size to 64
1316
+ self.enc_conv.append(
1317
+ Conv2dLayerPartial(
1318
+ in_channels=dim,
1319
+ out_channels=dim,
1320
+ kernel_size=3,
1321
+ down=2,
1322
+ activation=activation,
1323
+ )
1324
+ )
1325
+
1326
+ # from 64 -> 16 -> 64
1327
+ depths = [2, 3, 4, 3, 2]
1328
+ ratios = [1, 1 / 2, 1 / 2, 2, 2]
1329
+ num_heads = 6
1330
+ window_sizes = [8, 16, 16, 16, 8]
1331
+ drop_path_rate = 0.1
1332
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1333
+
1334
+ self.tran = nn.ModuleList()
1335
+ for i, depth in enumerate(depths):
1336
+ res = int(res * ratios[i])
1337
+ if ratios[i] < 1:
1338
+ merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
1339
+ elif ratios[i] > 1:
1340
+ merge = PatchUpsampling(dim, dim, up=ratios[i])
1341
+ else:
1342
+ merge = None
1343
+ self.tran.append(
1344
+ BasicLayer(
1345
+ dim=dim,
1346
+ input_resolution=[res, res],
1347
+ depth=depth,
1348
+ num_heads=num_heads,
1349
+ window_size=window_sizes[i],
1350
+ drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
1351
+ downsample=merge,
1352
+ )
1353
+ )
1354
+
1355
+ # global style
1356
+ down_conv = []
1357
+ for i in range(int(np.log2(16))):
1358
+ down_conv.append(
1359
+ Conv2dLayer(
1360
+ in_channels=dim,
1361
+ out_channels=dim,
1362
+ kernel_size=3,
1363
+ down=2,
1364
+ activation=activation,
1365
+ )
1366
+ )
1367
+ down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
1368
+ self.down_conv = nn.Sequential(*down_conv)
1369
+ self.to_style = FullyConnectedLayer(
1370
+ in_features=dim, out_features=dim * 2, activation=activation
1371
+ )
1372
+ self.ws_style = FullyConnectedLayer(
1373
+ in_features=w_dim, out_features=dim, activation=activation
1374
+ )
1375
+ self.to_square = FullyConnectedLayer(
1376
+ in_features=dim, out_features=16 * 16, activation=activation
1377
+ )
1378
+
1379
+ style_dim = dim * 3
1380
+ self.dec_conv = nn.ModuleList()
1381
+ for i in range(down_time): # from 64 to input size
1382
+ res = res * 2
1383
+ self.dec_conv.append(
1384
+ DecStyleBlock(
1385
+ res,
1386
+ dim,
1387
+ dim,
1388
+ activation,
1389
+ style_dim,
1390
+ use_noise,
1391
+ demodulate,
1392
+ img_channels,
1393
+ )
1394
+ )
1395
+
1396
+ def forward(self, images_in, masks_in, ws, noise_mode="random"):
1397
+ x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
1398
+
1399
+ skips = []
1400
+ x, mask = self.conv_first(x, masks_in) # input size
1401
+ skips.append(x)
1402
+ for i, block in enumerate(self.enc_conv): # input size to 64
1403
+ x, mask = block(x, mask)
1404
+ if i != len(self.enc_conv) - 1:
1405
+ skips.append(x)
1406
+
1407
+ x_size = x.size()[-2:]
1408
+ x = feature2token(x)
1409
+ mask = feature2token(mask)
1410
+ mid = len(self.tran) // 2
1411
+ for i, block in enumerate(self.tran): # 64 to 16
1412
+ if i < mid:
1413
+ x, x_size, mask = block(x, x_size, mask)
1414
+ skips.append(x)
1415
+ elif i > mid:
1416
+ x, x_size, mask = block(x, x_size, None)
1417
+ x = x + skips[mid - i]
1418
+ else:
1419
+ x, x_size, mask = block(x, x_size, None)
1420
+
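+ # Stochastic feature mixing: with the default p=0.5, F.dropout turns the
+ # constant 0.5 map into a random 0/1 gate, blending the encoder tokens x with
+ # the style-derived map add_n below.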
1421
+ mul_map = torch.ones_like(x) * 0.5
1422
+ mul_map = F.dropout(mul_map, training=True).to(x.device)
1423
+ ws = self.ws_style(ws[:, -1]).to(x.device)
1424
+ add_n = self.to_square(ws).unsqueeze(1).to(x.device)
1425
+ add_n = (
1426
+ F.interpolate(
1427
+ add_n, size=x.size(1), mode="linear", align_corners=False
1428
+ )
1429
+ .squeeze(1)
1430
+ .unsqueeze(-1)
1431
+ ).to(x.device)
1432
+ x = x * mul_map + add_n * (1 - mul_map)
1433
+ gs = self.to_style(
1434
+ self.down_conv(token2feature(x, x_size)).flatten(start_dim=1)
1435
+ ).to(x.device)
1436
+ style = torch.cat([gs, ws], dim=1)
1437
+
1438
+ x = token2feature(x, x_size).contiguous()
1439
+ img = None
1440
+ for i, block in enumerate(self.dec_conv):
1441
+ x, img = block(
1442
+ x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode
1443
+ )
1444
+
1445
+ # ensemble
1446
+ img = img * (1 - masks_in) + images_in * masks_in
1447
+
1448
+ return img
1449
+
1450
+
1451
+ class SynthesisNet(nn.Module):
1452
+ def __init__(
1453
+ self,
1454
+ w_dim, # Intermediate latent (W) dimensionality.
1455
+ img_resolution, # Output image resolution.
1456
+ img_channels=3, # Number of color channels.
1457
+ channel_base=32768, # Overall multiplier for the number of channels.
1458
+ channel_decay=1.0,
1459
+ channel_max=512, # Maximum number of channels in any layer.
1460
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
1461
+ drop_rate=0.5,
1462
+ use_noise=False,
1463
+ demodulate=True,
1464
+ ):
1465
+ super().__init__()
1466
+ resolution_log2 = int(np.log2(img_resolution))
1467
+ assert img_resolution == 2**resolution_log2 and img_resolution >= 4
1468
+
1469
+ self.num_layers = resolution_log2 * 2 - 3 * 2
1470
+ self.img_resolution = img_resolution
1471
+ self.resolution_log2 = resolution_log2
1472
+
1473
+ # first stage
1474
+ self.first_stage = FirstStage(
1475
+ img_channels,
1476
+ img_resolution=img_resolution,
1477
+ w_dim=w_dim,
1478
+ use_noise=False,
1479
+ demodulate=demodulate,
1480
+ )
1481
+
1482
+ # second stage
1483
+ self.enc = Encoder(
1484
+ resolution_log2, img_channels, activation, patch_size=5, channels=16
1485
+ )
1486
+ self.to_square = FullyConnectedLayer(
1487
+ in_features=w_dim, out_features=16 * 16, activation=activation
1488
+ )
1489
+ self.to_style = ToStyle(
1490
+ in_channels=nf(4),
1491
+ out_channels=nf(2) * 2,
1492
+ activation=activation,
1493
+ drop_rate=drop_rate,
1494
+ )
1495
+ style_dim = w_dim + nf(2) * 2
1496
+ self.dec = Decoder(
1497
+ resolution_log2, activation, style_dim, use_noise, demodulate, img_channels
1498
+ )
1499
+
1500
+ def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False):
1501
+ out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
1502
+
1503
+ # encoder
1504
+ x = images_in * masks_in + out_stg1 * (1 - masks_in)
1505
+ x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
1506
+ E_features = self.enc(x)
1507
+
1508
+ fea_16 = E_features[4].to(x.device)
1509
+ mul_map = torch.ones_like(fea_16) * 0.5
1510
+ mul_map = F.dropout(mul_map, training=True).to(x.device)
1511
+ add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
1512
+ add_n = F.interpolate(
1513
+ add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False
1514
+ ).to(x.device)
1515
+ fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
1516
+ E_features[4] = fea_16
1517
+
1518
+ # style
1519
+ gs = self.to_style(fea_16).to(x.device)
1520
+
1521
+ # decoder
1522
+ img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode).to(x.device)
1523
+
1524
+ # ensemble
1525
+ img = img * (1 - masks_in) + images_in * masks_in
1526
+
1527
+ if not return_stg1:
1528
+ return img
1529
+ else:
1530
+ return img, out_stg1
1531
+
1532
+
1533
+ class Generator(nn.Module):
1534
+ def __init__(
1535
+ self,
1536
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
1537
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
1538
+ w_dim, # Intermediate latent (W) dimensionality.
1539
+ img_resolution, # resolution of generated image
1540
+ img_channels, # Number of input color channels.
1541
+ synthesis_kwargs={}, # Arguments for SynthesisNetwork.
1542
+ mapping_kwargs={}, # Arguments for MappingNetwork.
1543
+ ):
1544
+ super().__init__()
1545
+ self.z_dim = z_dim
1546
+ self.c_dim = c_dim
1547
+ self.w_dim = w_dim
1548
+ self.img_resolution = img_resolution
1549
+ self.img_channels = img_channels
1550
+
1551
+ self.synthesis = SynthesisNet(
1552
+ w_dim=w_dim,
1553
+ img_resolution=img_resolution,
1554
+ img_channels=img_channels,
1555
+ **synthesis_kwargs,
1556
+ )
1557
+ self.mapping = MappingNet(
1558
+ z_dim=z_dim,
1559
+ c_dim=c_dim,
1560
+ w_dim=w_dim,
1561
+ num_ws=self.synthesis.num_layers,
1562
+ **mapping_kwargs,
1563
+ )
1564
+
1565
+ def forward(
1566
+ self,
1567
+ images_in,
1568
+ masks_in,
1569
+ z,
1570
+ c,
1571
+ truncation_psi=1,
1572
+ truncation_cutoff=None,
1573
+ skip_w_avg_update=False,
1574
+ noise_mode="none",
1575
+ return_stg1=False,
1576
+ ):
1577
+ ws = self.mapping(
1578
+ z,
1579
+ c,
1580
+ truncation_psi=truncation_psi,
1581
+ truncation_cutoff=truncation_cutoff,
1582
+ skip_w_avg_update=skip_w_avg_update,
1583
+ )
1584
+ img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
1585
+ return img
1586
+
1587
+
1588
+ class MAT(nn.Module):
1589
+ def __init__(self, state_dict):
1590
+ super(MAT, self).__init__()
1591
+ self.model_arch = "MAT"
1592
+ self.sub_type = "Inpaint"
1593
+ self.in_nc = 3
1594
+ self.out_nc = 3
1595
+ self.scale = 1
1596
+
1597
+ self.supports_fp16 = False
1598
+ self.supports_bf16 = True
1599
+
1600
+ self.min_size = 512
1601
+ self.pad_mod = 512
1602
+ self.pad_to_square = True
1603
+
1604
+ seed = 240 # arbitrary fixed seed for reproducibility
1605
+ random.seed(seed)
1606
+ np.random.seed(seed)
1607
+ torch.manual_seed(seed)
1608
+
1609
+ self.model = Generator(
1610
+ z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3
1611
+ )
1612
+ self.z = torch.from_numpy(np.random.randn(1, self.model.z_dim)) # [1., 512]
1613
+ self.label = torch.zeros([1, self.model.c_dim])
1614
+ self.state = {
1615
+ k.replace("synthesis", "model.synthesis").replace(
1616
+ "mapping", "model.mapping"
1617
+ ): v
1618
+ for k, v in state_dict.items()
1619
+ }
1620
+ self.load_state_dict(self.state, strict=False)
1621
+
1622
+ def forward(self, image, mask):
1623
+ """Input images and output images have same size
1624
+ images: [H, W, C] RGB
1625
+ masks: [H, W] mask area == 255
1626
+ return: BGR IMAGE
1627
+ """
1628
+
1629
+ image = image * 2 - 1 # [0, 1] -> [-1, 1]
1630
+ mask = 1 - mask
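+ # the synthesis network treats mask == 1 as known pixels (images_in * masks_in),
+ # so the inpainting mask is inverted here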
1631
+
1632
+ output = self.model(
1633
+ image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"
1634
+ )
1635
+
1636
+ return output * 0.5 + 0.5
comfy_extras/chainner_models/architecture/RRDB.py ADDED
@@ -0,0 +1,281 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import functools
5
+ import math
6
+ import re
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from . import block as B
14
+
15
+
16
+ # Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
17
+ # which improved on the implementation that was already here
18
+ class RRDBNet(nn.Module):
19
+ def __init__(
20
+ self,
21
+ state_dict,
22
+ norm=None,
23
+ act: str = "leakyrelu",
24
+ upsampler: str = "upconv",
25
+ mode: B.ConvMode = "CNA",
26
+ ) -> None:
27
+ """
28
+ ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
29
+ By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
30
+ and Chen Change Loy.
31
+ This is the old-arch Residual in Residual Dense Block Network and is not
32
+ the newest revision that's available at github.com/xinntao/ESRGAN.
33
+ This is intentional: the newest network has severely limited the
34
+ potential use of the network, with no benefits.
35
+ This network supports model files from both new and old-arch.
36
+ Args:
37
+ norm: Normalization layer
38
+ act: Activation layer
39
+ upsampler: Upsample layer. upconv, pixel_shuffle
40
+ mode: Convolution mode
41
+ """
42
+ super(RRDBNet, self).__init__()
43
+ self.model_arch = "ESRGAN"
44
+ self.sub_type = "SR"
45
+
46
+ self.state = state_dict
47
+ self.norm = norm
48
+ self.act = act
49
+ self.upsampler = upsampler
50
+ self.mode = mode
51
+
52
+ self.state_map = {
53
+ # currently supports old, new, and newer RRDBNet arch models
54
+ # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
55
+ "model.0.weight": ("conv_first.weight",),
56
+ "model.0.bias": ("conv_first.bias",),
57
+ "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
58
+ "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
59
+ r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
60
+ r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
61
+ r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
62
+ ),
63
+ }
64
+ if "params_ema" in self.state:
65
+ self.state = self.state["params_ema"]
66
+ # self.model_arch = "RealESRGAN"
67
+ self.num_blocks = self.get_num_blocks()
68
+ self.plus = any("conv1x1" in k for k in self.state.keys())
69
+ if self.plus:
70
+ self.model_arch = "ESRGAN+"
71
+
72
+ self.state = self.new_to_old_arch(self.state)
73
+
74
+ self.key_arr = list(self.state.keys())
75
+
76
+ self.in_nc: int = self.state[self.key_arr[0]].shape[1]
77
+ self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
78
+
79
+ self.scale: int = self.get_scale()
80
+ self.num_filters: int = self.state[self.key_arr[0]].shape[0]
81
+
82
+ self.supports_fp16 = True
83
+ self.supports_bfp16 = True
84
+ self.min_size_restriction = None
85
+
86
+ # Detect if pixelunshuffle was used (Real-ESRGAN)
87
+ if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
88
+ self.in_nc / 4,
89
+ self.in_nc / 16,
90
+ ):
91
+ self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
92
+ else:
93
+ self.shuffle_factor = None
94
+
95
+ upsample_block = {
96
+ "upconv": B.upconv_block,
97
+ "pixel_shuffle": B.pixelshuffle_block,
98
+ }.get(self.upsampler)
99
+ if upsample_block is None:
100
+ raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
101
+
102
+ if self.scale == 3:
103
+ upsample_blocks = upsample_block(
104
+ in_nc=self.num_filters,
105
+ out_nc=self.num_filters,
106
+ upscale_factor=3,
107
+ act_type=self.act,
108
+ )
109
+ else:
110
+ upsample_blocks = [
111
+ upsample_block(
112
+ in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
113
+ )
114
+ for _ in range(int(math.log(self.scale, 2)))
115
+ ]
116
+
117
+ self.model = B.sequential(
118
+ # fea conv
119
+ B.conv_block(
120
+ in_nc=self.in_nc,
121
+ out_nc=self.num_filters,
122
+ kernel_size=3,
123
+ norm_type=None,
124
+ act_type=None,
125
+ ),
126
+ B.ShortcutBlock(
127
+ B.sequential(
128
+ # rrdb blocks
129
+ *[
130
+ B.RRDB(
131
+ nf=self.num_filters,
132
+ kernel_size=3,
133
+ gc=32,
134
+ stride=1,
135
+ bias=True,
136
+ pad_type="zero",
137
+ norm_type=self.norm,
138
+ act_type=self.act,
139
+ mode="CNA",
140
+ plus=self.plus,
141
+ )
142
+ for _ in range(self.num_blocks)
143
+ ],
144
+ # lr conv
145
+ B.conv_block(
146
+ in_nc=self.num_filters,
147
+ out_nc=self.num_filters,
148
+ kernel_size=3,
149
+ norm_type=self.norm,
150
+ act_type=None,
151
+ mode=self.mode,
152
+ ),
153
+ )
154
+ ),
155
+ *upsample_blocks,
156
+ # hr_conv0
157
+ B.conv_block(
158
+ in_nc=self.num_filters,
159
+ out_nc=self.num_filters,
160
+ kernel_size=3,
161
+ norm_type=None,
162
+ act_type=self.act,
163
+ ),
164
+ # hr_conv1
165
+ B.conv_block(
166
+ in_nc=self.num_filters,
167
+ out_nc=self.out_nc,
168
+ kernel_size=3,
169
+ norm_type=None,
170
+ act_type=None,
171
+ ),
172
+ )
173
+
174
+ # Adjust these properties for calculations outside of the model
175
+ if self.shuffle_factor:
176
+ self.in_nc //= self.shuffle_factor**2
177
+ self.scale //= self.shuffle_factor
178
+
179
+ self.load_state_dict(self.state, strict=False)
180
+
181
+ def new_to_old_arch(self, state):
182
+ """Convert a new-arch model state dictionary to an old-arch dictionary."""
183
+ if "params_ema" in state:
184
+ state = state["params_ema"]
185
+
186
+ if "conv_first.weight" not in state:
187
+ # model is already old arch, this is a loose check, but should be sufficient
188
+ return state
189
+
190
+ # add nb to state keys
191
+ for kind in ("weight", "bias"):
192
+ self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
193
+ f"model.1.sub./NB/.{kind}"
194
+ ]
195
+ del self.state_map[f"model.1.sub./NB/.{kind}"]
196
+
197
+ old_state = OrderedDict()
198
+ for old_key, new_keys in self.state_map.items():
199
+ for new_key in new_keys:
200
+ if r"\1" in old_key:
201
+ for k, v in state.items():
202
+ sub = re.sub(new_key, old_key, k)
203
+ if sub != k:
204
+ old_state[sub] = v
205
+ else:
206
+ if new_key in state:
207
+ old_state[old_key] = state[new_key]
208
+
209
+ # upconv layers
210
+ max_upconv = 0
211
+ for key in state.keys():
212
+ match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
213
+ if match is not None:
214
+ _, key_num, key_type = match.groups()
215
+ old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
216
+ max_upconv = max(max_upconv, int(key_num) * 3)
217
+
218
+ # final layers
219
+ for key in state.keys():
220
+ if key in ("HRconv.weight", "conv_hr.weight"):
221
+ old_state[f"model.{max_upconv + 2}.weight"] = state[key]
222
+ elif key in ("HRconv.bias", "conv_hr.bias"):
223
+ old_state[f"model.{max_upconv + 2}.bias"] = state[key]
224
+ elif key in ("conv_last.weight",):
225
+ old_state[f"model.{max_upconv + 4}.weight"] = state[key]
226
+ elif key in ("conv_last.bias",):
227
+ old_state[f"model.{max_upconv + 4}.bias"] = state[key]
228
+
229
+ # Sort by first numeric value of each layer
230
+ def compare(item1, item2):
231
+ parts1 = item1.split(".")
232
+ parts2 = item2.split(".")
233
+ int1 = int(parts1[1])
234
+ int2 = int(parts2[1])
235
+ return int1 - int2
236
+
237
+ sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
238
+
239
+ # Rebuild the output dict in the right order
240
+ out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
241
+
242
+ return out_dict
243
+
244
+ def get_scale(self, min_part: int = 6) -> int:
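+ # Heuristic: in the flattened old-arch layout, the weight layers past index
+ # `min_part` happen to number log2(scale) for the usual 2x/4x/8x models.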
245
+ n = 0
246
+ for part in list(self.state):
247
+ parts = part.split(".")[1:]
248
+ if len(parts) == 2:
249
+ part_num = int(parts[0])
250
+ if part_num > min_part and parts[1] == "weight":
251
+ n += 1
252
+ return 2**n
253
+
254
+ def get_num_blocks(self) -> int:
255
+ nbs = []
256
+ state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
257
+ r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
258
+ )
259
+ for state_key in state_keys:
260
+ for k in self.state:
261
+ m = re.search(state_key, k)
262
+ if m:
263
+ nbs.append(int(m.group(1)))
264
+ if nbs:
265
+ break
266
+ return max(*nbs) + 1
267
+
268
+ def forward(self, x):
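+ # Real-ESRGAN models that use pixel_unshuffle need the input reflect-padded to
+ # a multiple of shuffle_factor; the output is cropped back to (h * scale, w * scale).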
269
+ if self.shuffle_factor:
270
+ _, _, h, w = x.size()
271
+ mod_pad_h = (
272
+ self.shuffle_factor - h % self.shuffle_factor
273
+ ) % self.shuffle_factor
274
+ mod_pad_w = (
275
+ self.shuffle_factor - w % self.shuffle_factor
276
+ ) % self.shuffle_factor
277
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
278
+ x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
279
+ x = self.model(x)
280
+ return x[:, :, : h * self.scale, : w * self.scale]
281
+ return self.model(x)
comfy_extras/chainner_models/architecture/SPSR.py ADDED
@@ -0,0 +1,384 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from . import block as B
11
+
12
+
13
+ class Get_gradient_nopadding(nn.Module):
14
+ def __init__(self):
15
+ super(Get_gradient_nopadding, self).__init__()
16
+ kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
17
+ kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
18
+ kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
19
+ kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
20
+ self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) # type: ignore
21
+
22
+ self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) # type: ignore
23
+
24
+ def forward(self, x):
25
+ x_list = []
26
+ for i in range(x.shape[1]):
27
+ x_i = x[:, i]
28
+ x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
29
+ x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
30
+ x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
31
+ x_list.append(x_i)
32
+
33
+ x = torch.cat(x_list, dim=1)
34
+
35
+ return x
36
+
37
+
38
+ class SPSRNet(nn.Module):
39
+ def __init__(
40
+ self,
41
+ state_dict,
42
+ norm=None,
43
+ act: str = "leakyrelu",
44
+ upsampler: str = "upconv",
45
+ mode: B.ConvMode = "CNA",
46
+ ):
47
+ super(SPSRNet, self).__init__()
48
+ self.model_arch = "SPSR"
49
+ self.sub_type = "SR"
50
+
51
+ self.state = state_dict
52
+ self.norm = norm
53
+ self.act = act
54
+ self.upsampler = upsampler
55
+ self.mode = mode
56
+
57
+ self.num_blocks = self.get_num_blocks()
58
+
59
+ self.in_nc: int = self.state["model.0.weight"].shape[1]
60
+ self.out_nc: int = self.state["f_HR_conv1.0.bias"].shape[0]
61
+
62
+ self.scale = self.get_scale(4)
63
+ print(self.scale)
64
+ self.num_filters: int = self.state["model.0.weight"].shape[0]
65
+
66
+ self.supports_fp16 = True
67
+ self.supports_bfp16 = True
68
+ self.min_size_restriction = None
69
+
70
+ n_upscale = int(math.log(self.scale, 2))
71
+ if self.scale == 3:
72
+ n_upscale = 1
73
+
74
+ fea_conv = B.conv_block(
75
+ self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
76
+ )
77
+ rb_blocks = [
78
+ B.RRDB(
79
+ self.num_filters,
80
+ kernel_size=3,
81
+ gc=32,
82
+ stride=1,
83
+ bias=True,
84
+ pad_type="zero",
85
+ norm_type=norm,
86
+ act_type=act,
87
+ mode="CNA",
88
+ )
89
+ for _ in range(self.num_blocks)
90
+ ]
91
+ LR_conv = B.conv_block(
92
+ self.num_filters,
93
+ self.num_filters,
94
+ kernel_size=3,
95
+ norm_type=norm,
96
+ act_type=None,
97
+ mode=mode,
98
+ )
99
+
100
+ if upsampler == "upconv":
101
+ upsample_block = B.upconv_block
102
+ elif upsampler == "pixelshuffle":
103
+ upsample_block = B.pixelshuffle_block
104
+ else:
105
+ raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
106
+ if self.scale == 3:
107
+ a_upsampler = upsample_block(
108
+ self.num_filters, self.num_filters, 3, act_type=act
109
+ )
110
+ else:
111
+ a_upsampler = [
112
+ upsample_block(self.num_filters, self.num_filters, act_type=act)
113
+ for _ in range(n_upscale)
114
+ ]
115
+ self.HR_conv0_new = B.conv_block(
116
+ self.num_filters,
117
+ self.num_filters,
118
+ kernel_size=3,
119
+ norm_type=None,
120
+ act_type=act,
121
+ )
122
+ self.HR_conv1_new = B.conv_block(
123
+ self.num_filters,
124
+ self.num_filters,
125
+ kernel_size=3,
126
+ norm_type=None,
127
+ act_type=None,
128
+ )
129
+
130
+ self.model = B.sequential(
131
+ fea_conv,
132
+ B.ShortcutBlockSPSR(B.sequential(*rb_blocks, LR_conv)),
133
+ *a_upsampler,
134
+ self.HR_conv0_new,
135
+ )
136
+
137
+ self.get_g_nopadding = Get_gradient_nopadding()
138
+
139
+ self.b_fea_conv = B.conv_block(
140
+ self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
141
+ )
142
+
143
+ self.b_concat_1 = B.conv_block(
144
+ 2 * self.num_filters,
145
+ self.num_filters,
146
+ kernel_size=3,
147
+ norm_type=None,
148
+ act_type=None,
149
+ )
150
+ self.b_block_1 = B.RRDB(
151
+ self.num_filters * 2,
152
+ kernel_size=3,
153
+ gc=32,
154
+ stride=1,
155
+ bias=True,
156
+ pad_type="zero",
157
+ norm_type=norm,
158
+ act_type=act,
159
+ mode="CNA",
160
+ )
161
+
162
+ self.b_concat_2 = B.conv_block(
163
+ 2 * self.num_filters,
164
+ self.num_filters,
165
+ kernel_size=3,
166
+ norm_type=None,
167
+ act_type=None,
168
+ )
169
+ self.b_block_2 = B.RRDB(
170
+ self.num_filters * 2,
171
+ kernel_size=3,
172
+ gc=32,
173
+ stride=1,
174
+ bias=True,
175
+ pad_type="zero",
176
+ norm_type=norm,
177
+ act_type=act,
178
+ mode="CNA",
179
+ )
180
+
181
+ self.b_concat_3 = B.conv_block(
182
+ 2 * self.num_filters,
183
+ self.num_filters,
184
+ kernel_size=3,
185
+ norm_type=None,
186
+ act_type=None,
187
+ )
188
+ self.b_block_3 = B.RRDB(
189
+ self.num_filters * 2,
190
+ kernel_size=3,
191
+ gc=32,
192
+ stride=1,
193
+ bias=True,
194
+ pad_type="zero",
195
+ norm_type=norm,
196
+ act_type=act,
197
+ mode="CNA",
198
+ )
199
+
200
+ self.b_concat_4 = B.conv_block(
201
+ 2 * self.num_filters,
202
+ self.num_filters,
203
+ kernel_size=3,
204
+ norm_type=None,
205
+ act_type=None,
206
+ )
207
+ self.b_block_4 = B.RRDB(
208
+ self.num_filters * 2,
209
+ kernel_size=3,
210
+ gc=32,
211
+ stride=1,
212
+ bias=True,
213
+ pad_type="zero",
214
+ norm_type=norm,
215
+ act_type=act,
216
+ mode="CNA",
217
+ )
218
+
219
+ self.b_LR_conv = B.conv_block(
220
+ self.num_filters,
221
+ self.num_filters,
222
+ kernel_size=3,
223
+ norm_type=norm,
224
+ act_type=None,
225
+ mode=mode,
226
+ )
227
+
228
+ if upsampler == "upconv":
229
+ upsample_block = B.upconv_block
230
+ elif upsampler == "pixelshuffle":
231
+ upsample_block = B.pixelshuffle_block
232
+ else:
233
+ raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
234
+ if self.scale == 3:
235
+ b_upsampler = upsample_block(
236
+ self.num_filters, self.num_filters, 3, act_type=act
237
+ )
238
+ else:
239
+ b_upsampler = [
240
+ upsample_block(self.num_filters, self.num_filters, act_type=act)
241
+ for _ in range(n_upscale)
242
+ ]
243
+
244
+ b_HR_conv0 = B.conv_block(
245
+ self.num_filters,
246
+ self.num_filters,
247
+ kernel_size=3,
248
+ norm_type=None,
249
+ act_type=act,
250
+ )
251
+ b_HR_conv1 = B.conv_block(
252
+ self.num_filters,
253
+ self.num_filters,
254
+ kernel_size=3,
255
+ norm_type=None,
256
+ act_type=None,
257
+ )
258
+
259
+ self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
260
+
261
+ self.conv_w = B.conv_block(
262
+ self.num_filters, self.out_nc, kernel_size=1, norm_type=None, act_type=None
263
+ )
264
+
265
+ self.f_concat = B.conv_block(
266
+ self.num_filters * 2,
267
+ self.num_filters,
268
+ kernel_size=3,
269
+ norm_type=None,
270
+ act_type=None,
271
+ )
272
+
273
+ self.f_block = B.RRDB(
274
+ self.num_filters * 2,
275
+ kernel_size=3,
276
+ gc=32,
277
+ stride=1,
278
+ bias=True,
279
+ pad_type="zero",
280
+ norm_type=norm,
281
+ act_type=act,
282
+ mode="CNA",
283
+ )
284
+
285
+ self.f_HR_conv0 = B.conv_block(
286
+ self.num_filters,
287
+ self.num_filters,
288
+ kernel_size=3,
289
+ norm_type=None,
290
+ act_type=act,
291
+ )
292
+ self.f_HR_conv1 = B.conv_block(
293
+ self.num_filters, self.out_nc, kernel_size=3, norm_type=None, act_type=None
294
+ )
295
+
296
+ self.load_state_dict(self.state, strict=False)
297
+
298
+ def get_scale(self, min_part: int = 4) -> int:
299
+ n = 0
300
+ for part in list(self.state):
301
+ parts = part.split(".")
302
+ if len(parts) == 3:
303
+ part_num = int(parts[1])
304
+ if part_num > min_part and parts[0] == "model" and parts[2] == "weight":
305
+ n += 1
306
+ return 2**n
307
+
308
+ def get_num_blocks(self) -> int:
309
+ nb = 0
310
+ for part in list(self.state):
311
+ parts = part.split(".")
312
+ n_parts = len(parts)
313
+ if n_parts == 5 and parts[2] == "sub":
314
+ nb = int(parts[3])
315
+ return nb
316
+
317
+ def forward(self, x):
318
+ x_grad = self.get_g_nopadding(x)
319
+ x = self.model[0](x)
320
+
321
+ x, block_list = self.model[1](x)
322
+
323
+ x_ori = x
324
+ for i in range(5):
325
+ x = block_list[i](x)
326
+ x_fea1 = x
327
+
328
+ for i in range(5):
329
+ x = block_list[i + 5](x)
330
+ x_fea2 = x
331
+
332
+ for i in range(5):
333
+ x = block_list[i + 10](x)
334
+ x_fea3 = x
335
+
336
+ for i in range(5):
337
+ x = block_list[i + 15](x)
338
+ x_fea4 = x
339
+
340
+ x = block_list[20:](x)
341
+ # short cut
342
+ x = x_ori + x
343
+ x = self.model[2:](x)
344
+ x = self.HR_conv1_new(x)
345
+
346
+ x_b_fea = self.b_fea_conv(x_grad)
347
+ x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
348
+
349
+ x_cat_1 = self.b_block_1(x_cat_1)
350
+ x_cat_1 = self.b_concat_1(x_cat_1)
351
+
352
+ x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
353
+
354
+ x_cat_2 = self.b_block_2(x_cat_2)
355
+ x_cat_2 = self.b_concat_2(x_cat_2)
356
+
357
+ x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
358
+
359
+ x_cat_3 = self.b_block_3(x_cat_3)
360
+ x_cat_3 = self.b_concat_3(x_cat_3)
361
+
362
+ x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
363
+
364
+ x_cat_4 = self.b_block_4(x_cat_4)
365
+ x_cat_4 = self.b_concat_4(x_cat_4)
366
+
367
+ x_cat_4 = self.b_LR_conv(x_cat_4)
368
+
369
+ # short cut
370
+ x_cat_4 = x_cat_4 + x_b_fea
371
+ x_branch = self.b_module(x_cat_4)
372
+
373
+ # x_out_branch = self.conv_w(x_branch)
374
+ ########
375
+ x_branch_d = x_branch
376
+ x_f_cat = torch.cat([x_branch_d, x], dim=1)
377
+ x_f_cat = self.f_block(x_f_cat)
378
+ x_out = self.f_concat(x_f_cat)
379
+ x_out = self.f_HR_conv0(x_out)
380
+ x_out = self.f_HR_conv1(x_out)
381
+
382
+ #########
383
+ # return x_out_branch, x_out, x_grad
384
+ return x_out
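For reference, a minimal sketch (not part of the uploaded file) of the key-counting rule that the get_scale helper above relies on: every "model.<N>.weight" entry with N greater than min_part is treated as one 2x upsampling stage. The key names below are hypothetical, chosen only to illustrate the rule.

def infer_scale(state_keys, min_part: int = 4) -> int:
    # Mirrors get_scale() above: count "model.<N>.weight" keys with N > min_part,
    # then treat each one as a 2x upsampling stage.
    n = 0
    for key in state_keys:
        parts = key.split(".")
        if (
            len(parts) == 3
            and parts[0] == "model"
            and parts[2] == "weight"
            and int(parts[1]) > min_part
        ):
            n += 1
    return 2**n

# Hypothetical keys: two qualifying entries -> scale 4.
print(infer_scale(["model.0.weight", "model.6.weight", "model.9.weight"]))  # 4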
comfy_extras/chainner_models/architecture/SRVGG.py ADDED
@@ -0,0 +1,114 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import math
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class SRVGGNetCompact(nn.Module):
11
+ """A compact VGG-style network structure for super-resolution.
12
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
13
+ conducted on the HR feature space.
14
+ Args:
15
+ num_in_ch (int): Channel number of inputs. Default: 3.
16
+ num_out_ch (int): Channel number of outputs. Default: 3.
17
+ num_feat (int): Channel number of intermediate features. Default: 64.
18
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
19
+ upscale (int): Upsampling factor. Default: 4.
20
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ state_dict,
26
+ act_type: str = "prelu",
27
+ ):
28
+ super(SRVGGNetCompact, self).__init__()
29
+ self.model_arch = "SRVGG (RealESRGAN)"
30
+ self.sub_type = "SR"
31
+
32
+ self.act_type = act_type
33
+
34
+ self.state = state_dict
35
+
36
+ if "params" in self.state:
37
+ self.state = self.state["params"]
38
+
39
+ self.key_arr = list(self.state.keys())
40
+
41
+ self.in_nc = self.get_in_nc()
42
+ self.num_feat = self.get_num_feats()
43
+ self.num_conv = self.get_num_conv()
44
+ self.out_nc = self.in_nc # :(
45
+ self.pixelshuffle_shape = None # Defined in get_scale()
46
+ self.scale = self.get_scale()
47
+
48
+ self.supports_fp16 = True
49
+ self.supports_bfp16 = True
50
+ self.min_size_restriction = None
51
+
52
+ self.body = nn.ModuleList()
53
+ # the first conv
54
+ self.body.append(nn.Conv2d(self.in_nc, self.num_feat, 3, 1, 1))
55
+ # the first activation
56
+ if act_type == "relu":
57
+ activation = nn.ReLU(inplace=True)
58
+ elif act_type == "prelu":
59
+ activation = nn.PReLU(num_parameters=self.num_feat)
60
+ elif act_type == "leakyrelu":
61
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
62
+ self.body.append(activation) # type: ignore
63
+
64
+ # the body structure
65
+ for _ in range(self.num_conv):
66
+ self.body.append(nn.Conv2d(self.num_feat, self.num_feat, 3, 1, 1))
67
+ # activation
68
+ if act_type == "relu":
69
+ activation = nn.ReLU(inplace=True)
70
+ elif act_type == "prelu":
71
+ activation = nn.PReLU(num_parameters=self.num_feat)
72
+ elif act_type == "leakyrelu":
73
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
74
+ self.body.append(activation) # type: ignore
75
+
76
+ # the last conv
77
+ self.body.append(nn.Conv2d(self.num_feat, self.pixelshuffle_shape, 3, 1, 1)) # type: ignore
78
+ # upsample
79
+ self.upsampler = nn.PixelShuffle(self.scale)
80
+
81
+ self.load_state_dict(self.state, strict=False)
82
+
83
+ def get_num_conv(self) -> int:
84
+ return (int(self.key_arr[-1].split(".")[1]) - 2) // 2
85
+
86
+ def get_num_feats(self) -> int:
87
+ return self.state[self.key_arr[0]].shape[0]
88
+
89
+ def get_in_nc(self) -> int:
90
+ return self.state[self.key_arr[0]].shape[1]
91
+
92
+ def get_scale(self) -> int:
93
+ self.pixelshuffle_shape = self.state[self.key_arr[-1]].shape[0]
94
+ # Assume out_nc is the same as in_nc
95
+ # I can't think of a better way to do that
96
+ self.out_nc = self.in_nc
97
+ scale = math.sqrt(self.pixelshuffle_shape / self.out_nc)
98
+ if scale - int(scale) > 0:
99
+ print(
100
+ "out_nc is probably different than in_nc, scale calculation might be wrong"
101
+ )
102
+ scale = int(scale)
103
+ return scale
104
+
105
+ def forward(self, x):
106
+ out = x
107
+ for i in range(0, len(self.body)):
108
+ out = self.body[i](out)
109
+
110
+ out = self.upsampler(out)
111
+ # add the nearest upsampled image, so that the network learns the residual
112
+ base = F.interpolate(x, scale_factor=self.scale, mode="nearest")
113
+ out += base
114
+ return out
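As a worked example of the scale inference in get_scale above (the numbers are hypothetical, chosen only to illustrate the arithmetic): a compact 4x RGB model ends in a conv with 3 * 4**2 = 48 output channels, which PixelShuffle(4) folds back into 3 channels.

import math

pixelshuffle_shape = 48  # shape[0] of the last conv weight in a hypothetical 4x model
out_nc = 3               # assumed equal to in_nc, exactly as the class above assumes
scale = math.sqrt(pixelshuffle_shape / out_nc)
assert scale == int(scale)  # a non-integer here would mean out_nc != in_nc
print(int(scale))  # 4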
comfy_extras/chainner_models/architecture/SwiftSRGAN.py ADDED
@@ -0,0 +1,161 @@
1
+ # From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class SeperableConv2d(nn.Module):
8
+ def __init__(
9
+ self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
10
+ ):
11
+ super(SeperableConv2d, self).__init__()
12
+ self.depthwise = nn.Conv2d(
13
+ in_channels,
14
+ in_channels,
15
+ kernel_size=kernel_size,
16
+ stride=stride,
17
+ groups=in_channels,
18
+ bias=bias,
19
+ padding=padding,
20
+ )
21
+ self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
22
+
23
+ def forward(self, x):
24
+ return self.pointwise(self.depthwise(x))
25
+
26
+
27
+ class ConvBlock(nn.Module):
28
+ def __init__(
29
+ self,
30
+ in_channels,
31
+ out_channels,
32
+ use_act=True,
33
+ use_bn=True,
34
+ discriminator=False,
35
+ **kwargs,
36
+ ):
37
+ super(ConvBlock, self).__init__()
38
+
39
+ self.use_act = use_act
40
+ self.cnn = SeperableConv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
41
+ self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
42
+ self.act = (
43
+ nn.LeakyReLU(0.2, inplace=True)
44
+ if discriminator
45
+ else nn.PReLU(num_parameters=out_channels)
46
+ )
47
+
48
+ def forward(self, x):
49
+ return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))
50
+
51
+
52
+ class UpsampleBlock(nn.Module):
53
+ def __init__(self, in_channels, scale_factor):
54
+ super(UpsampleBlock, self).__init__()
55
+
56
+ self.conv = SeperableConv2d(
57
+ in_channels,
58
+ in_channels * scale_factor**2,
59
+ kernel_size=3,
60
+ stride=1,
61
+ padding=1,
62
+ )
63
+ self.ps = nn.PixelShuffle(
64
+ scale_factor
65
+ ) # (in_channels * 4, H, W) -> (in_channels, H*2, W*2)
66
+ self.act = nn.PReLU(num_parameters=in_channels)
67
+
68
+ def forward(self, x):
69
+ return self.act(self.ps(self.conv(x)))
70
+
71
+
72
+ class ResidualBlock(nn.Module):
73
+ def __init__(self, in_channels):
74
+ super(ResidualBlock, self).__init__()
75
+
76
+ self.block1 = ConvBlock(
77
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
78
+ )
79
+ self.block2 = ConvBlock(
80
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1, use_act=False
81
+ )
82
+
83
+ def forward(self, x):
84
+ out = self.block1(x)
85
+ out = self.block2(out)
86
+ return out + x
87
+
88
+
89
+ class Generator(nn.Module):
90
+ """Swift-SRGAN Generator
91
+ Args:
92
+ in_channels (int): number of input image channels.
93
+ num_channels (int): number of hidden channels.
94
+ num_blocks (int): number of residual blocks.
95
+ upscale_factor (int): factor to upscale the image [2x, 4x, 8x].
96
+ Returns:
97
+ torch.Tensor: super resolution image
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ state_dict,
103
+ ):
104
+ super(Generator, self).__init__()
105
+ self.model_arch = "Swift-SRGAN"
106
+ self.sub_type = "SR"
107
+ self.state = state_dict
108
+ if "model" in self.state:
109
+ self.state = self.state["model"]
110
+
111
+ self.in_nc: int = self.state["initial.cnn.depthwise.weight"].shape[0]
112
+ self.out_nc: int = self.state["final_conv.pointwise.weight"].shape[0]
113
+ self.num_filters: int = self.state["initial.cnn.pointwise.weight"].shape[0]
114
+ self.num_blocks = len(
115
+ set([x.split(".")[1] for x in self.state.keys() if "residual" in x])
116
+ )
117
+ self.scale: int = 2 ** len(
118
+ set([x.split(".")[1] for x in self.state.keys() if "upsampler" in x])
119
+ )
120
+
121
+ in_channels = self.in_nc
122
+ num_channels = self.num_filters
123
+ num_blocks = self.num_blocks
124
+ upscale_factor = self.scale
125
+
126
+ self.supports_fp16 = True
127
+ self.supports_bfp16 = True
128
+ self.min_size_restriction = None
129
+
130
+ self.initial = ConvBlock(
131
+ in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
132
+ )
133
+ self.residual = nn.Sequential(
134
+ *[ResidualBlock(num_channels) for _ in range(num_blocks)]
135
+ )
136
+ self.convblock = ConvBlock(
137
+ num_channels,
138
+ num_channels,
139
+ kernel_size=3,
140
+ stride=1,
141
+ padding=1,
142
+ use_act=False,
143
+ )
144
+ self.upsampler = nn.Sequential(
145
+ *[
146
+ UpsampleBlock(num_channels, scale_factor=2)
147
+ for _ in range(upscale_factor // 2)
148
+ ]
149
+ )
150
+ self.final_conv = SeperableConv2d(
151
+ num_channels, in_channels, kernel_size=9, stride=1, padding=4
152
+ )
153
+
154
+ self.load_state_dict(self.state, strict=False)
155
+
156
+ def forward(self, x):
157
+ initial = self.initial(x)
158
+ x = self.residual(initial)
159
+ x = self.convblock(x) + initial
160
+ x = self.upsampler(x)
161
+ return (torch.tanh(self.final_conv(x)) + 1) / 2
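A small numeric check (illustrative only, not part of the upstream file) of the final activation in Generator.forward: (tanh(x) + 1) / 2 maps the network output from [-1, 1] into the [0, 1] image range.

import torch

t = torch.tensor([-5.0, 0.0, 5.0])
out = (torch.tanh(t) + 1) / 2
print(out)  # values close to 0, exactly 0.5, and close to 1
assert out.min() >= 0 and out.max() <= 1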
comfy_extras/chainner_models/architecture/Swin2SR.py ADDED
@@ -0,0 +1,1377 @@
1
+ # pylint: skip-file
2
+ # -----------------------------------------------------------------------------------
3
+ # Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
4
+ # Written by Conde and Choi et al.
5
+ # From: https://raw.githubusercontent.com/mv-lab/swin2sr/main/models/network_swin2sr.py
6
+ # -----------------------------------------------------------------------------------
7
+
8
+ import math
9
+ import re
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+
17
+ # Originally from the timm package
18
+ from .timm.drop import DropPath
19
+ from .timm.helpers import to_2tuple
20
+ from .timm.weight_init import trunc_normal_
21
+
22
+
23
+ class Mlp(nn.Module):
24
+ def __init__(
25
+ self,
26
+ in_features,
27
+ hidden_features=None,
28
+ out_features=None,
29
+ act_layer=nn.GELU,
30
+ drop=0.0,
31
+ ):
32
+ super().__init__()
33
+ out_features = out_features or in_features
34
+ hidden_features = hidden_features or in_features
35
+ self.fc1 = nn.Linear(in_features, hidden_features)
36
+ self.act = act_layer()
37
+ self.fc2 = nn.Linear(hidden_features, out_features)
38
+ self.drop = nn.Dropout(drop)
39
+
40
+ def forward(self, x):
41
+ x = self.fc1(x)
42
+ x = self.act(x)
43
+ x = self.drop(x)
44
+ x = self.fc2(x)
45
+ x = self.drop(x)
46
+ return x
47
+
48
+
49
+ def window_partition(x, window_size):
50
+ """
51
+ Args:
52
+ x: (B, H, W, C)
53
+ window_size (int): window size
54
+ Returns:
55
+ windows: (num_windows*B, window_size, window_size, C)
56
+ """
57
+ B, H, W, C = x.shape
58
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
59
+ windows = (
60
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
61
+ )
62
+ return windows
63
+
64
+
65
+ def window_reverse(windows, window_size, H, W):
66
+ """
67
+ Args:
68
+ windows: (num_windows*B, window_size, window_size, C)
69
+ window_size (int): Window size
70
+ H (int): Height of image
71
+ W (int): Width of image
72
+ Returns:
73
+ x: (B, H, W, C)
74
+ """
75
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
76
+ x = windows.view(
77
+ B, H // window_size, W // window_size, window_size, window_size, -1
78
+ )
79
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
80
+ return x
81
+
82
+
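A quick round-trip sanity check (not part of the upstream file, assuming window_partition and window_reverse above are importable): partitioning into windows and reversing should reproduce the original tensor exactly, since both helpers are pure permute/reshape operations.

import torch

x = torch.randn(2, 16, 16, 32)    # (B, H, W, C), with H and W divisible by window_size
windows = window_partition(x, 8)  # -> (2 * 2 * 2, 8, 8, 32)
assert torch.equal(window_reverse(windows, 8, 16, 16), x)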
83
+ class WindowAttention(nn.Module):
84
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
85
+ It supports both shifted and non-shifted windows.
86
+ Args:
87
+ dim (int): Number of input channels.
88
+ window_size (tuple[int]): The height and width of the window.
89
+ num_heads (int): Number of attention heads.
90
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
91
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
92
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
93
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ dim,
99
+ window_size,
100
+ num_heads,
101
+ qkv_bias=True,
102
+ attn_drop=0.0,
103
+ proj_drop=0.0,
104
+ pretrained_window_size=[0, 0],
105
+ ):
106
+ super().__init__()
107
+ self.dim = dim
108
+ self.window_size = window_size # Wh, Ww
109
+ self.pretrained_window_size = pretrained_window_size
110
+ self.num_heads = num_heads
111
+
112
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) # type: ignore
113
+
114
+ # mlp to generate continuous relative position bias
115
+ self.cpb_mlp = nn.Sequential(
116
+ nn.Linear(2, 512, bias=True),
117
+ nn.ReLU(inplace=True),
118
+ nn.Linear(512, num_heads, bias=False),
119
+ )
120
+
121
+ # get relative_coords_table
122
+ relative_coords_h = torch.arange(
123
+ -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
124
+ )
125
+ relative_coords_w = torch.arange(
126
+ -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
127
+ )
128
+ relative_coords_table = (
129
+ torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
130
+ .permute(1, 2, 0)
131
+ .contiguous()
132
+ .unsqueeze(0)
133
+ ) # 1, 2*Wh-1, 2*Ww-1, 2
134
+ if pretrained_window_size[0] > 0:
135
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
136
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
137
+ else:
138
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
139
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
140
+ relative_coords_table *= 8 # normalize to -8, 8
141
+ relative_coords_table = (
142
+ torch.sign(relative_coords_table)
143
+ * torch.log2(torch.abs(relative_coords_table) + 1.0)
144
+ / np.log2(8)
145
+ )
146
+
147
+ self.register_buffer("relative_coords_table", relative_coords_table)
148
+
149
+ # get pair-wise relative position index for each token inside the window
150
+ coords_h = torch.arange(self.window_size[0])
151
+ coords_w = torch.arange(self.window_size[1])
152
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
153
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
154
+ relative_coords = (
155
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
156
+ ) # 2, Wh*Ww, Wh*Ww
157
+ relative_coords = relative_coords.permute(
158
+ 1, 2, 0
159
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
160
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
161
+ relative_coords[:, :, 1] += self.window_size[1] - 1
162
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
163
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
164
+ self.register_buffer("relative_position_index", relative_position_index)
165
+
166
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
167
+ if qkv_bias:
168
+ self.q_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
169
+ self.v_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
170
+ else:
171
+ self.q_bias = None
172
+ self.v_bias = None
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+ self.softmax = nn.Softmax(dim=-1)
177
+
178
+ def forward(self, x, mask=None):
179
+ """
180
+ Args:
181
+ x: input features with shape of (num_windows*B, N, C)
182
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
183
+ """
184
+ B_, N, C = x.shape
185
+ qkv_bias = None
186
+ if self.q_bias is not None:
187
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # type: ignore
188
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
189
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
190
+ q, k, v = (
191
+ qkv[0],
192
+ qkv[1],
193
+ qkv[2],
194
+ ) # make torchscript happy (cannot use tensor as tuple)
195
+
196
+ # cosine attention
197
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
198
+ logit_scale = torch.clamp(
199
+ self.logit_scale,
200
+ max=torch.log(torch.tensor(1.0 / 0.01)).to(self.logit_scale.device),
201
+ ).exp()
202
+ attn = attn * logit_scale
203
+
204
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
205
+ -1, self.num_heads
206
+ )
207
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( # type: ignore
208
+ self.window_size[0] * self.window_size[1],
209
+ self.window_size[0] * self.window_size[1],
210
+ -1,
211
+ ) # Wh*Ww,Wh*Ww,nH
212
+ relative_position_bias = relative_position_bias.permute(
213
+ 2, 0, 1
214
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
215
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
216
+ attn = attn + relative_position_bias.unsqueeze(0)
217
+
218
+ if mask is not None:
219
+ nW = mask.shape[0]
220
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
221
+ 1
222
+ ).unsqueeze(0)
223
+ attn = attn.view(-1, self.num_heads, N, N)
224
+ attn = self.softmax(attn)
225
+ else:
226
+ attn = self.softmax(attn)
227
+
228
+ attn = self.attn_drop(attn)
229
+
230
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
231
+ x = self.proj(x)
232
+ x = self.proj_drop(x)
233
+ return x
234
+
235
+ def extra_repr(self) -> str:
236
+ return (
237
+ f"dim={self.dim}, window_size={self.window_size}, "
238
+ f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
239
+ )
240
+
241
+ def flops(self, N):
242
+ # calculate flops for 1 window with token length of N
243
+ flops = 0
244
+ # qkv = self.qkv(x)
245
+ flops += N * self.dim * 3 * self.dim
246
+ # attn = (q @ k.transpose(-2, -1))
247
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
248
+ # x = (attn @ v)
249
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
250
+ # x = self.proj(x)
251
+ flops += N * self.dim * self.dim
252
+ return flops
253
+
254
+
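The forward pass above uses SwinV2-style cosine attention. A minimal sketch of just that step (hypothetical shapes, not the class API): queries and keys are L2-normalized, so raw scores lie in [-1, 1], and are then multiplied by a learned logit scale that the code clamps so that its exponential never exceeds 1/0.01 = 100.

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 32)  # (batch*windows, heads, tokens, head_dim), illustrative only
k = torch.randn(1, 4, 16, 32)
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
print(attn.min().item(), attn.max().item())  # cosine similarities, roughly within [-1, 1]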
255
+ class SwinTransformerBlock(nn.Module):
256
+ r"""Swin Transformer Block.
257
+ Args:
258
+ dim (int): Number of input channels.
259
+ input_resolution (tuple[int]): Input resolution.
260
+ num_heads (int): Number of attention heads.
261
+ window_size (int): Window size.
262
+ shift_size (int): Shift size for SW-MSA.
263
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
264
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
265
+ drop (float, optional): Dropout rate. Default: 0.0
266
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
267
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
268
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
269
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
270
+ pretrained_window_size (int): Window size in pre-training.
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ dim,
276
+ input_resolution,
277
+ num_heads,
278
+ window_size=7,
279
+ shift_size=0,
280
+ mlp_ratio=4.0,
281
+ qkv_bias=True,
282
+ drop=0.0,
283
+ attn_drop=0.0,
284
+ drop_path=0.0,
285
+ act_layer=nn.GELU,
286
+ norm_layer=nn.LayerNorm,
287
+ pretrained_window_size=0,
288
+ ):
289
+ super().__init__()
290
+ self.dim = dim
291
+ self.input_resolution = input_resolution
292
+ self.num_heads = num_heads
293
+ self.window_size = window_size
294
+ self.shift_size = shift_size
295
+ self.mlp_ratio = mlp_ratio
296
+ if min(self.input_resolution) <= self.window_size:
297
+ # if window size is larger than input resolution, we don't partition windows
298
+ self.shift_size = 0
299
+ self.window_size = min(self.input_resolution)
300
+ assert (
301
+ 0 <= self.shift_size < self.window_size
302
+ ), "shift_size must be in [0, window_size)"
303
+
304
+ self.norm1 = norm_layer(dim)
305
+ self.attn = WindowAttention(
306
+ dim,
307
+ window_size=to_2tuple(self.window_size),
308
+ num_heads=num_heads,
309
+ qkv_bias=qkv_bias,
310
+ attn_drop=attn_drop,
311
+ proj_drop=drop,
312
+ pretrained_window_size=to_2tuple(pretrained_window_size),
313
+ )
314
+
315
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
316
+ self.norm2 = norm_layer(dim)
317
+ mlp_hidden_dim = int(dim * mlp_ratio)
318
+ self.mlp = Mlp(
319
+ in_features=dim,
320
+ hidden_features=mlp_hidden_dim,
321
+ act_layer=act_layer,
322
+ drop=drop,
323
+ )
324
+
325
+ if self.shift_size > 0:
326
+ attn_mask = self.calculate_mask(self.input_resolution)
327
+ else:
328
+ attn_mask = None
329
+
330
+ self.register_buffer("attn_mask", attn_mask)
331
+
332
+ def calculate_mask(self, x_size):
333
+ # calculate attention mask for SW-MSA
334
+ H, W = x_size
335
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
336
+ h_slices = (
337
+ slice(0, -self.window_size),
338
+ slice(-self.window_size, -self.shift_size),
339
+ slice(-self.shift_size, None),
340
+ )
341
+ w_slices = (
342
+ slice(0, -self.window_size),
343
+ slice(-self.window_size, -self.shift_size),
344
+ slice(-self.shift_size, None),
345
+ )
346
+ cnt = 0
347
+ for h in h_slices:
348
+ for w in w_slices:
349
+ img_mask[:, h, w, :] = cnt
350
+ cnt += 1
351
+
352
+ mask_windows = window_partition(
353
+ img_mask, self.window_size
354
+ ) # nW, window_size, window_size, 1
355
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
356
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
357
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
358
+ attn_mask == 0, float(0.0)
359
+ )
360
+
361
+ return attn_mask
362
+
363
+ def forward(self, x, x_size):
364
+ H, W = x_size
365
+ B, L, C = x.shape
366
+ # assert L == H * W, "input feature has wrong size"
367
+
368
+ shortcut = x
369
+ x = x.view(B, H, W, C)
370
+
371
+ # cyclic shift
372
+ if self.shift_size > 0:
373
+ shifted_x = torch.roll(
374
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
375
+ )
376
+ else:
377
+ shifted_x = x
378
+
379
+ # partition windows
380
+ x_windows = window_partition(
381
+ shifted_x, self.window_size
382
+ ) # nW*B, window_size, window_size, C
383
+ x_windows = x_windows.view(
384
+ -1, self.window_size * self.window_size, C
385
+ ) # nW*B, window_size*window_size, C
386
+
387
+ # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)
388
+ if self.input_resolution == x_size:
389
+ attn_windows = self.attn(
390
+ x_windows, mask=self.attn_mask
391
+ ) # nW*B, window_size*window_size, C
392
+ else:
393
+ attn_windows = self.attn(
394
+ x_windows, mask=self.calculate_mask(x_size).to(x.device)
395
+ )
396
+
397
+ # merge windows
398
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
399
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
400
+
401
+ # reverse cyclic shift
402
+ if self.shift_size > 0:
403
+ x = torch.roll(
404
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
405
+ )
406
+ else:
407
+ x = shifted_x
408
+ x = x.view(B, H * W, C)
409
+ x = shortcut + self.drop_path(self.norm1(x))
410
+
411
+ # FFN
412
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
413
+
414
+ return x
415
+
416
+ def extra_repr(self) -> str:
417
+ return (
418
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
419
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
420
+ )
421
+
422
+ def flops(self):
423
+ flops = 0
424
+ H, W = self.input_resolution
425
+ # norm1
426
+ flops += self.dim * H * W
427
+ # W-MSA/SW-MSA
428
+ nW = H * W / self.window_size / self.window_size
429
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
430
+ # mlp
431
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
432
+ # norm2
433
+ flops += self.dim * H * W
434
+ return flops
435
+
436
+
437
+ class PatchMerging(nn.Module):
438
+ r"""Patch Merging Layer.
439
+ Args:
440
+ input_resolution (tuple[int]): Resolution of input feature.
441
+ dim (int): Number of input channels.
442
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
443
+ """
444
+
445
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
446
+ super().__init__()
447
+ self.input_resolution = input_resolution
448
+ self.dim = dim
449
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
450
+ self.norm = norm_layer(2 * dim)
451
+
452
+ def forward(self, x):
453
+ """
454
+ x: B, H*W, C
455
+ """
456
+ H, W = self.input_resolution
457
+ B, L, C = x.shape
458
+ assert L == H * W, "input feature has wrong size"
459
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
460
+
461
+ x = x.view(B, H, W, C)
462
+
463
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
464
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
465
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
466
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
467
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
468
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
469
+
470
+ x = self.reduction(x)
471
+ x = self.norm(x)
472
+
473
+ return x
474
+
475
+ def extra_repr(self) -> str:
476
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
477
+
478
+ def flops(self):
479
+ H, W = self.input_resolution
480
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
481
+ flops += H * W * self.dim // 2
482
+ return flops
483
+
484
+
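An illustrative shape check for PatchMerging (the sizes are arbitrary): it halves the spatial resolution and doubles the channel count, taking (B, H*W, C) to (B, H/2 * W/2, 2C) via the 4C -> 2C linear reduction above.

import torch

pm = PatchMerging(input_resolution=(8, 8), dim=32)
y = pm(torch.randn(1, 64, 32))  # 64 tokens of dim 32 ...
print(y.shape)                  # torch.Size([1, 16, 64]) -- 16 tokens of dim 64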
485
+ class BasicLayer(nn.Module):
486
+ """A basic Swin Transformer layer for one stage.
487
+ Args:
488
+ dim (int): Number of input channels.
489
+ input_resolution (tuple[int]): Input resolution.
490
+ depth (int): Number of blocks.
491
+ num_heads (int): Number of attention heads.
492
+ window_size (int): Local window size.
493
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
494
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
495
+ drop (float, optional): Dropout rate. Default: 0.0
496
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
497
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
498
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
499
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
500
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
501
+ pretrained_window_size (int): Local window size in pre-training.
502
+ """
503
+
504
+ def __init__(
505
+ self,
506
+ dim,
507
+ input_resolution,
508
+ depth,
509
+ num_heads,
510
+ window_size,
511
+ mlp_ratio=4.0,
512
+ qkv_bias=True,
513
+ drop=0.0,
514
+ attn_drop=0.0,
515
+ drop_path=0.0,
516
+ norm_layer=nn.LayerNorm,
517
+ downsample=None,
518
+ use_checkpoint=False,
519
+ pretrained_window_size=0,
520
+ ):
521
+ super().__init__()
522
+ self.dim = dim
523
+ self.input_resolution = input_resolution
524
+ self.depth = depth
525
+ self.use_checkpoint = use_checkpoint
526
+
527
+ # build blocks
528
+ self.blocks = nn.ModuleList(
529
+ [
530
+ SwinTransformerBlock(
531
+ dim=dim,
532
+ input_resolution=input_resolution,
533
+ num_heads=num_heads,
534
+ window_size=window_size,
535
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
536
+ mlp_ratio=mlp_ratio,
537
+ qkv_bias=qkv_bias,
538
+ drop=drop,
539
+ attn_drop=attn_drop,
540
+ drop_path=drop_path[i]
541
+ if isinstance(drop_path, list)
542
+ else drop_path,
543
+ norm_layer=norm_layer,
544
+ pretrained_window_size=pretrained_window_size,
545
+ )
546
+ for i in range(depth)
547
+ ]
548
+ )
549
+
550
+ # patch merging layer
551
+ if downsample is not None:
552
+ self.downsample = downsample(
553
+ input_resolution, dim=dim, norm_layer=norm_layer
554
+ )
555
+ else:
556
+ self.downsample = None
557
+
558
+ def forward(self, x, x_size):
559
+ for blk in self.blocks:
560
+ if self.use_checkpoint:
561
+ x = checkpoint.checkpoint(blk, x, x_size)
562
+ else:
563
+ x = blk(x, x_size)
564
+ if self.downsample is not None:
565
+ x = self.downsample(x)
566
+ return x
567
+
568
+ def extra_repr(self) -> str:
569
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
570
+
571
+ def flops(self):
572
+ flops = 0
573
+ for blk in self.blocks:
574
+ flops += blk.flops() # type: ignore
575
+ if self.downsample is not None:
576
+ flops += self.downsample.flops()
577
+ return flops
578
+
579
+ def _init_respostnorm(self):
580
+ for blk in self.blocks:
581
+ nn.init.constant_(blk.norm1.bias, 0) # type: ignore
582
+ nn.init.constant_(blk.norm1.weight, 0) # type: ignore
583
+ nn.init.constant_(blk.norm2.bias, 0) # type: ignore
584
+ nn.init.constant_(blk.norm2.weight, 0) # type: ignore
585
+
586
+
587
+ class PatchEmbed(nn.Module):
588
+ r"""Image to Patch Embedding
589
+ Args:
590
+ img_size (int): Image size. Default: 224.
591
+ patch_size (int): Patch token size. Default: 4.
592
+ in_chans (int): Number of input image channels. Default: 3.
593
+ embed_dim (int): Number of linear projection output channels. Default: 96.
594
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
595
+ """
596
+
597
+ def __init__(
598
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
599
+ ):
600
+ super().__init__()
601
+ img_size = to_2tuple(img_size)
602
+ patch_size = to_2tuple(patch_size)
603
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
604
+ self.img_size = img_size
605
+ self.patch_size = patch_size
606
+ self.patches_resolution = patches_resolution
607
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
608
+
609
+ self.in_chans = in_chans
610
+ self.embed_dim = embed_dim
611
+
612
+ self.proj = nn.Conv2d(
613
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size # type: ignore
614
+ )
615
+ if norm_layer is not None:
616
+ self.norm = norm_layer(embed_dim)
617
+ else:
618
+ self.norm = None
619
+
620
+ def forward(self, x):
621
+ B, C, H, W = x.shape
622
+ # FIXME look at relaxing size constraints
623
+ # assert H == self.img_size[0] and W == self.img_size[1],
624
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
625
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
626
+ if self.norm is not None:
627
+ x = self.norm(x)
628
+ return x
629
+
630
+ def flops(self):
631
+ Ho, Wo = self.patches_resolution
632
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) # type: ignore
633
+ if self.norm is not None:
634
+ flops += Ho * Wo * self.embed_dim
635
+ return flops
636
+
637
+
638
+ class RSTB(nn.Module):
639
+ """Residual Swin Transformer Block (RSTB).
640
+
641
+ Args:
642
+ dim (int): Number of input channels.
643
+ input_resolution (tuple[int]): Input resolution.
644
+ depth (int): Number of blocks.
645
+ num_heads (int): Number of attention heads.
646
+ window_size (int): Local window size.
647
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
648
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
649
+ drop (float, optional): Dropout rate. Default: 0.0
650
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
651
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
652
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
653
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
654
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
655
+ img_size: Input image size.
656
+ patch_size: Patch size.
657
+ resi_connection: The convolutional block before residual connection.
658
+ """
659
+
660
+ def __init__(
661
+ self,
662
+ dim,
663
+ input_resolution,
664
+ depth,
665
+ num_heads,
666
+ window_size,
667
+ mlp_ratio=4.0,
668
+ qkv_bias=True,
669
+ drop=0.0,
670
+ attn_drop=0.0,
671
+ drop_path=0.0,
672
+ norm_layer=nn.LayerNorm,
673
+ downsample=None,
674
+ use_checkpoint=False,
675
+ img_size=224,
676
+ patch_size=4,
677
+ resi_connection="1conv",
678
+ ):
679
+ super(RSTB, self).__init__()
680
+
681
+ self.dim = dim
682
+ self.input_resolution = input_resolution
683
+
684
+ self.residual_group = BasicLayer(
685
+ dim=dim,
686
+ input_resolution=input_resolution,
687
+ depth=depth,
688
+ num_heads=num_heads,
689
+ window_size=window_size,
690
+ mlp_ratio=mlp_ratio,
691
+ qkv_bias=qkv_bias,
692
+ drop=drop,
693
+ attn_drop=attn_drop,
694
+ drop_path=drop_path,
695
+ norm_layer=norm_layer,
696
+ downsample=downsample,
697
+ use_checkpoint=use_checkpoint,
698
+ )
699
+
700
+ if resi_connection == "1conv":
701
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
702
+ elif resi_connection == "3conv":
703
+ # to save parameters and memory
704
+ self.conv = nn.Sequential(
705
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
706
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
707
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
708
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
709
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
710
+ )
711
+
712
+ self.patch_embed = PatchEmbed(
713
+ img_size=img_size,
714
+ patch_size=patch_size,
715
+ in_chans=dim,
716
+ embed_dim=dim,
717
+ norm_layer=None,
718
+ )
719
+
720
+ self.patch_unembed = PatchUnEmbed(
721
+ img_size=img_size,
722
+ patch_size=patch_size,
723
+ in_chans=dim,
724
+ embed_dim=dim,
725
+ norm_layer=None,
726
+ )
727
+
728
+ def forward(self, x, x_size):
729
+ return (
730
+ self.patch_embed(
731
+ self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
732
+ )
733
+ + x
734
+ )
735
+
736
+ def flops(self):
737
+ flops = 0
738
+ flops += self.residual_group.flops()
739
+ H, W = self.input_resolution
740
+ flops += H * W * self.dim * self.dim * 9
741
+ flops += self.patch_embed.flops()
742
+ flops += self.patch_unembed.flops()
743
+
744
+ return flops
745
+
746
+
747
+ class PatchUnEmbed(nn.Module):
748
+ r"""Image to Patch Unembedding
749
+
750
+ Args:
751
+ img_size (int): Image size. Default: 224.
752
+ patch_size (int): Patch token size. Default: 4.
753
+ in_chans (int): Number of input image channels. Default: 3.
754
+ embed_dim (int): Number of linear projection output channels. Default: 96.
755
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
756
+ """
757
+
758
+ def __init__(
759
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
760
+ ):
761
+ super().__init__()
762
+ img_size = to_2tuple(img_size)
763
+ patch_size = to_2tuple(patch_size)
764
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
765
+ self.img_size = img_size
766
+ self.patch_size = patch_size
767
+ self.patches_resolution = patches_resolution
768
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
769
+
770
+ self.in_chans = in_chans
771
+ self.embed_dim = embed_dim
772
+
773
+ def forward(self, x, x_size):
774
+ B, HW, C = x.shape
775
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
776
+ return x
777
+
778
+ def flops(self):
779
+ flops = 0
780
+ return flops
781
+
782
+
783
+ class Upsample(nn.Sequential):
784
+ """Upsample module.
785
+
786
+ Args:
787
+ scale (int): Scale factor. Supported scales: 2^n and 3.
788
+ num_feat (int): Channel number of intermediate features.
789
+ """
790
+
791
+ def __init__(self, scale, num_feat):
792
+ m = []
793
+ if (scale & (scale - 1)) == 0: # scale = 2^n
794
+ for _ in range(int(math.log(scale, 2))):
795
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
796
+ m.append(nn.PixelShuffle(2))
797
+ elif scale == 3:
798
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
799
+ m.append(nn.PixelShuffle(3))
800
+ else:
801
+ raise ValueError(
802
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
803
+ )
804
+ super(Upsample, self).__init__(*m)
805
+
806
+
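An illustrative shape check for Upsample (arbitrary feature sizes): a 4x instance is built as two conv + PixelShuffle(2) stages, so the channel count is preserved while each spatial dimension grows by the scale factor.

import torch

up = Upsample(scale=4, num_feat=64)
print(up(torch.randn(1, 64, 8, 8)).shape)  # torch.Size([1, 64, 32, 32])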
807
+ class Upsample_hf(nn.Sequential):
808
+ """Upsample module.
809
+
810
+ Args:
811
+ scale (int): Scale factor. Supported scales: 2^n and 3.
812
+ num_feat (int): Channel number of intermediate features.
813
+ """
814
+
815
+ def __init__(self, scale, num_feat):
816
+ m = []
817
+ if (scale & (scale - 1)) == 0: # scale = 2^n
818
+ for _ in range(int(math.log(scale, 2))):
819
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
820
+ m.append(nn.PixelShuffle(2))
821
+ elif scale == 3:
822
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
823
+ m.append(nn.PixelShuffle(3))
824
+ else:
825
+ raise ValueError(
826
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
827
+ )
828
+ super(Upsample_hf, self).__init__(*m)
829
+
830
+
831
+ class UpsampleOneStep(nn.Sequential):
832
+ """UpsampleOneStep module (unlike Upsample, it always uses exactly one conv followed by one pixelshuffle)
833
+ Used in lightweight SR to save parameters.
834
+
835
+ Args:
836
+ scale (int): Scale factor. Supported scales: 2^n and 3.
837
+ num_feat (int): Channel number of intermediate features.
838
+
839
+ """
840
+
841
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
842
+ self.num_feat = num_feat
843
+ self.input_resolution = input_resolution
844
+ m = []
845
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
846
+ m.append(nn.PixelShuffle(scale))
847
+ super(UpsampleOneStep, self).__init__(*m)
848
+
849
+ def flops(self):
850
+ H, W = self.input_resolution # type: ignore
851
+ flops = H * W * self.num_feat * 3 * 9
852
+ return flops
853
+
854
+
855
+ class Swin2SR(nn.Module):
856
+ r"""Swin2SR
857
+ A PyTorch implementation of `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
858
+
859
+ Args:
860
+ img_size (int | tuple(int)): Input image size. Default 64
861
+ patch_size (int | tuple(int)): Patch size. Default: 1
862
+ in_chans (int): Number of input image channels. Default: 3
863
+ embed_dim (int): Patch embedding dimension. Default: 96
864
+ depths (tuple(int)): Depth of each Swin Transformer layer.
865
+ num_heads (tuple(int)): Number of attention heads in different layers.
866
+ window_size (int): Window size. Default: 7
867
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
868
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
869
+ drop_rate (float): Dropout rate. Default: 0
870
+ attn_drop_rate (float): Attention dropout rate. Default: 0
871
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
872
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
873
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
874
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
875
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
876
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
877
+ img_range: Image range. 1. or 255.
878
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
879
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
880
+ """
881
+
882
+ def __init__(
883
+ self,
884
+ state_dict,
885
+ **kwargs,
886
+ ):
887
+ super(Swin2SR, self).__init__()
888
+
889
+ # Defaults
890
+ img_size = 128
891
+ patch_size = 1
892
+ in_chans = 3
893
+ embed_dim = 96
894
+ depths = [6, 6, 6, 6]
895
+ num_heads = [6, 6, 6, 6]
896
+ window_size = 7
897
+ mlp_ratio = 4.0
898
+ qkv_bias = True
899
+ drop_rate = 0.0
900
+ attn_drop_rate = 0.0
901
+ drop_path_rate = 0.1
902
+ norm_layer = nn.LayerNorm
903
+ ape = False
904
+ patch_norm = True
905
+ use_checkpoint = False
906
+ upscale = 2
907
+ img_range = 1.0
908
+ upsampler = ""
909
+ resi_connection = "1conv"
910
+ num_in_ch = in_chans
911
+ num_out_ch = in_chans
912
+ num_feat = 64
913
+
914
+ self.model_arch = "Swin2SR"
915
+ self.sub_type = "SR"
916
+ self.state = state_dict
917
+ if "params_ema" in self.state:
918
+ self.state = self.state["params_ema"]
919
+ elif "params" in self.state:
920
+ self.state = self.state["params"]
921
+
922
+ state_keys = self.state.keys()
923
+
924
+ if "conv_before_upsample.0.weight" in state_keys:
925
+ if "conv_aux.weight" in state_keys:
926
+ upsampler = "pixelshuffle_aux"
927
+ elif "conv_up1.weight" in state_keys:
928
+ upsampler = "nearest+conv"
929
+ else:
930
+ upsampler = "pixelshuffle"
931
+ supports_fp16 = False
932
+ elif "upsample.0.weight" in state_keys:
933
+ upsampler = "pixelshuffledirect"
934
+ else:
935
+ upsampler = ""
936
+
937
+ num_feat = (
938
+ self.state.get("conv_before_upsample.0.weight", None).shape[1]
939
+ if "conv_before_upsample.0.weight" in state_keys
940
+ else 64
941
+ )
942
+
943
+ num_in_ch = self.state["conv_first.weight"].shape[1]
944
+ in_chans = num_in_ch
945
+ if "conv_last.weight" in state_keys:
946
+ num_out_ch = self.state["conv_last.weight"].shape[0]
947
+ else:
948
+ num_out_ch = num_in_ch
949
+
950
+ upscale = 1
951
+ if upsampler == "nearest+conv":
952
+ upsample_keys = [
953
+ x for x in state_keys if "conv_up" in x and "bias" not in x
954
+ ]
955
+
956
+ for upsample_key in upsample_keys:
957
+ upscale *= 2
958
+ elif upsampler == "pixelshuffle" or upsampler == "pixelshuffle_aux":
959
+ upsample_keys = [
960
+ x
961
+ for x in state_keys
962
+ if "upsample" in x and "conv" not in x and "bias" not in x
963
+ ]
964
+ for upsample_key in upsample_keys:
965
+ shape = self.state[upsample_key].shape[0]
966
+ upscale *= math.sqrt(shape // num_feat)
967
+ upscale = int(upscale)
968
+ elif upsampler == "pixelshuffledirect":
969
+ upscale = int(
970
+ math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
971
+ )
972
+
973
+ max_layer_num = 0
974
+ max_block_num = 0
975
+ for key in state_keys:
976
+ result = re.match(
977
+ r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
978
+ )
979
+ if result:
980
+ layer_num, block_num = result.groups()
981
+ max_layer_num = max(max_layer_num, int(layer_num))
982
+ max_block_num = max(max_block_num, int(block_num))
983
+
984
+ depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
985
+
986
+ if (
987
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
988
+ in state_keys
989
+ ):
990
+ num_heads_num = self.state[
991
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
992
+ ].shape[-1]
993
+ num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
994
+ else:
995
+ num_heads = depths
996
+
997
+ embed_dim = self.state["conv_first.weight"].shape[0]
998
+
999
+ mlp_ratio = float(
1000
+ self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
1001
+ / embed_dim
1002
+ )
1003
+
1004
+ # TODO: could actually count the layers, but this should do
1005
+ if "layers.0.conv.4.weight" in state_keys:
1006
+ resi_connection = "3conv"
1007
+ else:
1008
+ resi_connection = "1conv"
1009
+
1010
+ window_size = int(
1011
+ math.sqrt(
1012
+ self.state[
1013
+ "layers.0.residual_group.blocks.0.attn.relative_position_index"
1014
+ ].shape[0]
1015
+ )
1016
+ )
1017
+
1018
+ if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
1019
+ img_size = int(
1020
+ math.sqrt(
1021
+ self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
1022
+ )
1023
+ * window_size
1024
+ )
1025
+
1026
+ # The JPEG models are the only ones with window-size 7, and they also use this range
1027
+ img_range = 255.0 if window_size == 7 else 1.0
1028
+
1029
+ self.in_nc = num_in_ch
1030
+ self.out_nc = num_out_ch
1031
+ self.num_feat = num_feat
1032
+ self.embed_dim = embed_dim
1033
+ self.num_heads = num_heads
1034
+ self.depths = depths
1035
+ self.window_size = window_size
1036
+ self.mlp_ratio = mlp_ratio
1037
+ self.scale = upscale
1038
+ self.upsampler = upsampler
1039
+ self.img_size = img_size
1040
+ self.img_range = img_range
1041
+ self.resi_connection = resi_connection
1042
+
1043
+ self.supports_fp16 = False # Too much weirdness to support this at the moment
1044
+ self.supports_bfp16 = True
1045
+ self.min_size_restriction = 16
1046
+
1047
+ ## END AUTO DETECTION
1048
+
1049
+ if in_chans == 3:
1050
+ rgb_mean = (0.4488, 0.4371, 0.4040)
1051
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
1052
+ else:
1053
+ self.mean = torch.zeros(1, 1, 1, 1)
1054
+ self.upscale = upscale
1055
+ self.upsampler = upsampler
1056
+ self.window_size = window_size
1057
+
1058
+ #####################################################################################################
1059
+ ################################### 1, shallow feature extraction ###################################
1060
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
1061
+
1062
+ #####################################################################################################
1063
+ ################################### 2, deep feature extraction ######################################
1064
+ self.num_layers = len(depths)
1065
+ self.embed_dim = embed_dim
1066
+ self.ape = ape
1067
+ self.patch_norm = patch_norm
1068
+ self.num_features = embed_dim
1069
+ self.mlp_ratio = mlp_ratio
1070
+
1071
+ # split image into non-overlapping patches
1072
+ self.patch_embed = PatchEmbed(
1073
+ img_size=img_size,
1074
+ patch_size=patch_size,
1075
+ in_chans=embed_dim,
1076
+ embed_dim=embed_dim,
1077
+ norm_layer=norm_layer if self.patch_norm else None,
1078
+ )
1079
+ num_patches = self.patch_embed.num_patches
1080
+ patches_resolution = self.patch_embed.patches_resolution
1081
+ self.patches_resolution = patches_resolution
1082
+
1083
+ # merge non-overlapping patches into image
1084
+ self.patch_unembed = PatchUnEmbed(
1085
+ img_size=img_size,
1086
+ patch_size=patch_size,
1087
+ in_chans=embed_dim,
1088
+ embed_dim=embed_dim,
1089
+ norm_layer=norm_layer if self.patch_norm else None,
1090
+ )
1091
+
1092
+ # absolute position embedding
1093
+ if self.ape:
1094
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # type: ignore
1095
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
1096
+
1097
+ self.pos_drop = nn.Dropout(p=drop_rate)
1098
+
1099
+ # stochastic depth
1100
+ dpr = [
1101
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
1102
+ ] # stochastic depth decay rule
1103
+
1104
+ # build Residual Swin Transformer blocks (RSTB)
1105
+ self.layers = nn.ModuleList()
1106
+ for i_layer in range(self.num_layers):
1107
+ layer = RSTB(
1108
+ dim=embed_dim,
1109
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1110
+ depth=depths[i_layer],
1111
+ num_heads=num_heads[i_layer],
1112
+ window_size=window_size,
1113
+ mlp_ratio=self.mlp_ratio,
1114
+ qkv_bias=qkv_bias,
1115
+ drop=drop_rate,
1116
+ attn_drop=attn_drop_rate,
1117
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # type: ignore # no impact on SR results
1118
+ norm_layer=norm_layer,
1119
+ downsample=None,
1120
+ use_checkpoint=use_checkpoint,
1121
+ img_size=img_size,
1122
+ patch_size=patch_size,
1123
+ resi_connection=resi_connection,
1124
+ )
1125
+ self.layers.append(layer)
1126
+
1127
+ if self.upsampler == "pixelshuffle_hf":
1128
+ self.layers_hf = nn.ModuleList()
1129
+ for i_layer in range(self.num_layers):
1130
+ layer = RSTB(
1131
+ dim=embed_dim,
1132
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1133
+ depth=depths[i_layer],
1134
+ num_heads=num_heads[i_layer],
1135
+ window_size=window_size,
1136
+ mlp_ratio=self.mlp_ratio,
1137
+ qkv_bias=qkv_bias,
1138
+ drop=drop_rate,
1139
+ attn_drop=attn_drop_rate,
1140
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # type: ignore # no impact on SR results
1141
+ norm_layer=norm_layer,
1142
+ downsample=None,
1143
+ use_checkpoint=use_checkpoint,
1144
+ img_size=img_size,
1145
+ patch_size=patch_size,
1146
+ resi_connection=resi_connection,
1147
+ )
1148
+ self.layers_hf.append(layer)
1149
+
1150
+ self.norm = norm_layer(self.num_features)
1151
+
1152
+ # build the last conv layer in deep feature extraction
1153
+ if resi_connection == "1conv":
1154
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1155
+ elif resi_connection == "3conv":
1156
+ # to save parameters and memory
1157
+ self.conv_after_body = nn.Sequential(
1158
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
1159
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1160
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
1161
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1162
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
1163
+ )
1164
+
1165
+ #####################################################################################################
1166
+ ################################ 3, high quality image reconstruction ################################
1167
+ if self.upsampler == "pixelshuffle":
1168
+ # for classical SR
1169
+ self.conv_before_upsample = nn.Sequential(
1170
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1171
+ )
1172
+ self.upsample = Upsample(upscale, num_feat)
1173
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1174
+ elif self.upsampler == "pixelshuffle_aux":
1175
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
1176
+ self.conv_before_upsample = nn.Sequential(
1177
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1178
+ )
1179
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1180
+ self.conv_after_aux = nn.Sequential(
1181
+ nn.Conv2d(3, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1182
+ )
1183
+ self.upsample = Upsample(upscale, num_feat)
1184
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1185
+
1186
+ elif self.upsampler == "pixelshuffle_hf":
1187
+ self.conv_before_upsample = nn.Sequential(
1188
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1189
+ )
1190
+ self.upsample = Upsample(upscale, num_feat)
1191
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
1192
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1193
+ self.conv_first_hf = nn.Sequential(
1194
+ nn.Conv2d(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)
1195
+ )
1196
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1197
+ self.conv_before_upsample_hf = nn.Sequential(
1198
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1199
+ )
1200
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1201
+
1202
+ elif self.upsampler == "pixelshuffledirect":
1203
+ # for lightweight SR (to save parameters)
1204
+ self.upsample = UpsampleOneStep(
1205
+ upscale,
1206
+ embed_dim,
1207
+ num_out_ch,
1208
+ (patches_resolution[0], patches_resolution[1]),
1209
+ )
1210
+ elif self.upsampler == "nearest+conv":
1211
+ # for real-world SR (less artifacts)
1212
+ assert self.upscale == 4, "only support x4 now."
1213
+ self.conv_before_upsample = nn.Sequential(
1214
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1215
+ )
1216
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1217
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1218
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1219
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1220
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1221
+ else:
1222
+ # for image denoising and JPEG compression artifact reduction
1223
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1224
+
1225
+ self.apply(self._init_weights)
1226
+
1227
+ self.load_state_dict(state_dict)
1228
+
1229
+ def _init_weights(self, m):
1230
+ if isinstance(m, nn.Linear):
1231
+ trunc_normal_(m.weight, std=0.02)
1232
+ if isinstance(m, nn.Linear) and m.bias is not None:
1233
+ nn.init.constant_(m.bias, 0)
1234
+ elif isinstance(m, nn.LayerNorm):
1235
+ nn.init.constant_(m.bias, 0)
1236
+ nn.init.constant_(m.weight, 1.0)
1237
+
1238
+ @torch.jit.ignore # type: ignore
1239
+ def no_weight_decay(self):
1240
+ return {"absolute_pos_embed"}
1241
+
1242
+ @torch.jit.ignore # type: ignore
1243
+ def no_weight_decay_keywords(self):
1244
+ return {"relative_position_bias_table"}
1245
+
1246
+ def check_image_size(self, x):
1247
+ _, _, h, w = x.size()
1248
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
1249
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
1250
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1251
+ return x
1252
+
1253
+ def forward_features(self, x):
1254
+ x_size = (x.shape[2], x.shape[3])
1255
+ x = self.patch_embed(x)
1256
+ if self.ape:
1257
+ x = x + self.absolute_pos_embed
1258
+ x = self.pos_drop(x)
1259
+
1260
+ for layer in self.layers:
1261
+ x = layer(x, x_size)
1262
+
1263
+ x = self.norm(x) # B L C
1264
+ x = self.patch_unembed(x, x_size)
1265
+
1266
+ return x
1267
+
1268
+ def forward_features_hf(self, x):
1269
+ x_size = (x.shape[2], x.shape[3])
1270
+ x = self.patch_embed(x)
1271
+ if self.ape:
1272
+ x = x + self.absolute_pos_embed
1273
+ x = self.pos_drop(x)
1274
+
1275
+ for layer in self.layers_hf:
1276
+ x = layer(x, x_size)
1277
+
1278
+ x = self.norm(x) # B L C
1279
+ x = self.patch_unembed(x, x_size)
1280
+
1281
+ return x
1282
+
1283
+ def forward(self, x):
1284
+ H, W = x.shape[2:]
1285
+ x = self.check_image_size(x)
1286
+
1287
+ self.mean = self.mean.type_as(x)
1288
+ x = (x - self.mean) * self.img_range
1289
+
1290
+ if self.upsampler == "pixelshuffle":
1291
+ # for classical SR
1292
+ x = self.conv_first(x)
1293
+ x = self.conv_after_body(self.forward_features(x)) + x
1294
+ x = self.conv_before_upsample(x)
1295
+ x = self.conv_last(self.upsample(x))
1296
+ elif self.upsampler == "pixelshuffle_aux":
1297
+ bicubic = F.interpolate(
1298
+ x,
1299
+ size=(H * self.upscale, W * self.upscale),
1300
+ mode="bicubic",
1301
+ align_corners=False,
1302
+ )
1303
+ bicubic = self.conv_bicubic(bicubic)
1304
+ x = self.conv_first(x)
1305
+ x = self.conv_after_body(self.forward_features(x)) + x
1306
+ x = self.conv_before_upsample(x)
1307
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
1308
+ x = self.conv_after_aux(aux)
1309
+ x = (
1310
+ self.upsample(x)[:, :, : H * self.upscale, : W * self.upscale]
1311
+ + bicubic[:, :, : H * self.upscale, : W * self.upscale]
1312
+ )
1313
+ x = self.conv_last(x)
1314
+ aux = aux / self.img_range + self.mean
1315
+ elif self.upsampler == "pixelshuffle_hf":
1316
+ # for classical SR with HF
1317
+ x = self.conv_first(x)
1318
+ x = self.conv_after_body(self.forward_features(x)) + x
1319
+ x_before = self.conv_before_upsample(x)
1320
+ x_out = self.conv_last(self.upsample(x_before))
1321
+
1322
+ x_hf = self.conv_first_hf(x_before)
1323
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
1324
+ x_hf = self.conv_before_upsample_hf(x_hf)
1325
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
1326
+ x = x_out + x_hf
1327
+ x_hf = x_hf / self.img_range + self.mean
1328
+
1329
+ elif self.upsampler == "pixelshuffledirect":
1330
+ # for lightweight SR
1331
+ x = self.conv_first(x)
1332
+ x = self.conv_after_body(self.forward_features(x)) + x
1333
+ x = self.upsample(x)
1334
+ elif self.upsampler == "nearest+conv":
1335
+ # for real-world SR
1336
+ x = self.conv_first(x)
1337
+ x = self.conv_after_body(self.forward_features(x)) + x
1338
+ x = self.conv_before_upsample(x)
1339
+ x = self.lrelu(
1340
+ self.conv_up1(
1341
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
1342
+ )
1343
+ )
1344
+ x = self.lrelu(
1345
+ self.conv_up2(
1346
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
1347
+ )
1348
+ )
1349
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1350
+ else:
1351
+ # for image denoising and JPEG compression artifact reduction
1352
+ x_first = self.conv_first(x)
1353
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1354
+ x = x + self.conv_last(res)
1355
+
1356
+ x = x / self.img_range + self.mean
1357
+ if self.upsampler == "pixelshuffle_aux":
1358
+ # NOTE: the original implementation also returns an "aux" output here; it is not used by this integration, so only the main output is returned.
1359
+ return x[:, :, : H * self.upscale, : W * self.upscale] # type: ignore
1360
+
1361
+ elif self.upsampler == "pixelshuffle_hf":
1362
+ x_out = x_out / self.img_range + self.mean # type: ignore
1363
+ return x_out[:, :, : H * self.upscale, : W * self.upscale], x[:, :, : H * self.upscale, : W * self.upscale], x_hf[:, :, : H * self.upscale, : W * self.upscale] # type: ignore
1364
+
1365
+ else:
1366
+ return x[:, :, : H * self.upscale, : W * self.upscale]
1367
+
1368
+ def flops(self):
1369
+ flops = 0
1370
+ H, W = self.patches_resolution
1371
+ flops += H * W * 3 * self.embed_dim * 9
1372
+ flops += self.patch_embed.flops()
1373
+ for i, layer in enumerate(self.layers):
1374
+ flops += layer.flops() # type: ignore
1375
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1376
+ flops += self.upsample.flops() # type: ignore
1377
+ return flops
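Both HAT above and SwinIR below rely on the same pad-then-crop trick: check_image_size() reflect-pads the input so H and W become multiples of window_size, and forward() crops the result back to H * upscale by W * upscale. A minimal sketch of that round trip, assuming only torch; the window_size, upscale, and tensor sizes here are illustrative, and the model call is hypothetical:

import torch
import torch.nn.functional as F

def pad_to_window(x: torch.Tensor, window_size: int) -> torch.Tensor:
    # Reflect-pad H and W up to the next multiple of window_size,
    # mirroring what check_image_size() does above.
    _, _, h, w = x.size()
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    return F.pad(x, (0, pad_w, 0, pad_h), "reflect")

window_size, upscale = 16, 4
x = torch.rand(1, 3, 70, 70)
h, w = x.shape[2:]
x_padded = pad_to_window(x, window_size)           # -> 1 x 3 x 80 x 80
# out = model(x_padded)                            # hypothetical model call
# out = out[:, :, : h * upscale, : w * upscale]    # crop back, as in forward()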
comfy_extras/chainner_models/architecture/SwinIR.py ADDED
@@ -0,0 +1,1208 @@
1
+ # pylint: skip-file
2
+ # -----------------------------------------------------------------------------------
3
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
4
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
5
+ # -----------------------------------------------------------------------------------
6
+
7
+ import math
8
+ import re
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+
15
+ # Originally from the timm package
16
+ from .timm.drop import DropPath
17
+ from .timm.helpers import to_2tuple
18
+ from .timm.weight_init import trunc_normal_
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ def __init__(
23
+ self,
24
+ in_features,
25
+ hidden_features=None,
26
+ out_features=None,
27
+ act_layer=nn.GELU,
28
+ drop=0.0,
29
+ ):
30
+ super().__init__()
31
+ out_features = out_features or in_features
32
+ hidden_features = hidden_features or in_features
33
+ self.fc1 = nn.Linear(in_features, hidden_features)
34
+ self.act = act_layer()
35
+ self.fc2 = nn.Linear(hidden_features, out_features)
36
+ self.drop = nn.Dropout(drop)
37
+
38
+ def forward(self, x):
39
+ x = self.fc1(x)
40
+ x = self.act(x)
41
+ x = self.drop(x)
42
+ x = self.fc2(x)
43
+ x = self.drop(x)
44
+ return x
45
+
46
+
47
+ def window_partition(x, window_size):
48
+ """
49
+ Args:
50
+ x: (B, H, W, C)
51
+ window_size (int): window size
52
+
53
+ Returns:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ """
56
+ B, H, W, C = x.shape
57
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
58
+ windows = (
59
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
60
+ )
61
+ return windows
62
+
63
+
64
+ def window_reverse(windows, window_size, H, W):
65
+ """
66
+ Args:
67
+ windows: (num_windows*B, window_size, window_size, C)
68
+ window_size (int): Window size
69
+ H (int): Height of image
70
+ W (int): Width of image
71
+
72
+ Returns:
73
+ x: (B, H, W, C)
74
+ """
75
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
76
+ x = windows.view(
77
+ B, H // window_size, W // window_size, window_size, window_size, -1
78
+ )
79
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
80
+ return x
81
+
82
+
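A property worth noting (illustrative self-check, assuming the two functions defined above plus torch): for any H and W divisible by window_size, window_reverse() undoes window_partition() exactly, since both only permute and reshape.

import torch

def _windows_roundtrip_check():
    x = torch.randn(2, 56, 56, 96)                    # B, H, W, C
    windows = window_partition(x, 7)                  # (2*8*8, 7, 7, 96)
    assert torch.equal(window_reverse(windows, 7, 56, 56), x)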
83
+ class WindowAttention(nn.Module):
84
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
85
+ It supports both of shifted and non-shifted window.
86
+
87
+ Args:
88
+ dim (int): Number of input channels.
89
+ window_size (tuple[int]): The height and width of the window.
90
+ num_heads (int): Number of attention heads.
91
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
92
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
93
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
94
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ dim,
100
+ window_size,
101
+ num_heads,
102
+ qkv_bias=True,
103
+ qk_scale=None,
104
+ attn_drop=0.0,
105
+ proj_drop=0.0,
106
+ ):
107
+ super().__init__()
108
+ self.dim = dim
109
+ self.window_size = window_size # Wh, Ww
110
+ self.num_heads = num_heads
111
+ head_dim = dim // num_heads
112
+ self.scale = qk_scale or head_dim**-0.5
113
+
114
+ # define a parameter table of relative position bias
115
+ self.relative_position_bias_table = nn.Parameter( # type: ignore
116
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
117
+ ) # 2*Wh-1 * 2*Ww-1, nH
118
+
119
+ # get pair-wise relative position index for each token inside the window
120
+ coords_h = torch.arange(self.window_size[0])
121
+ coords_w = torch.arange(self.window_size[1])
122
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
123
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
124
+ relative_coords = (
125
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
126
+ ) # 2, Wh*Ww, Wh*Ww
127
+ relative_coords = relative_coords.permute(
128
+ 1, 2, 0
129
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
130
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
131
+ relative_coords[:, :, 1] += self.window_size[1] - 1
132
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
133
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
134
+ self.register_buffer("relative_position_index", relative_position_index)
135
+
136
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
137
+ self.attn_drop = nn.Dropout(attn_drop)
138
+ self.proj = nn.Linear(dim, dim)
139
+
140
+ self.proj_drop = nn.Dropout(proj_drop)
141
+
142
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
143
+ self.softmax = nn.Softmax(dim=-1)
144
+
145
+ def forward(self, x, mask=None):
146
+ """
147
+ Args:
148
+ x: input features with shape of (num_windows*B, N, C)
149
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
150
+ """
151
+ B_, N, C = x.shape
152
+ qkv = (
153
+ self.qkv(x)
154
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
155
+ .permute(2, 0, 3, 1, 4)
156
+ )
157
+ q, k, v = (
158
+ qkv[0],
159
+ qkv[1],
160
+ qkv[2],
161
+ ) # make torchscript happy (cannot use tensor as tuple)
162
+
163
+ q = q * self.scale
164
+ attn = q @ k.transpose(-2, -1)
165
+
166
+ relative_position_bias = self.relative_position_bias_table[
167
+ self.relative_position_index.view(-1) # type: ignore
168
+ ].view(
169
+ self.window_size[0] * self.window_size[1],
170
+ self.window_size[0] * self.window_size[1],
171
+ -1,
172
+ ) # Wh*Ww,Wh*Ww,nH
173
+ relative_position_bias = relative_position_bias.permute(
174
+ 2, 0, 1
175
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
176
+ attn = attn + relative_position_bias.unsqueeze(0)
177
+
178
+ if mask is not None:
179
+ nW = mask.shape[0]
180
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
181
+ 1
182
+ ).unsqueeze(0)
183
+ attn = attn.view(-1, self.num_heads, N, N)
184
+ attn = self.softmax(attn)
185
+ else:
186
+ attn = self.softmax(attn)
187
+
188
+ attn = self.attn_drop(attn)
189
+
190
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
191
+ x = self.proj(x)
192
+ x = self.proj_drop(x)
193
+ return x
194
+
195
+ def extra_repr(self) -> str:
196
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
197
+
198
+ def flops(self, N):
199
+ # calculate flops for 1 window with token length of N
200
+ flops = 0
201
+ # qkv = self.qkv(x)
202
+ flops += N * self.dim * 3 * self.dim
203
+ # attn = (q @ k.transpose(-2, -1))
204
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
205
+ # x = (attn @ v)
206
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
207
+ # x = self.proj(x)
208
+ flops += N * self.dim * self.dim
209
+ return flops
210
+
211
+
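The relative position bias above is the only learned positional term: for a Wh x Ww window the table holds (2*Wh - 1) * (2*Ww - 1) rows (one bias vector per possible token offset, per head), and relative_position_index maps every token pair to a row of that table. A small shape sketch, assuming the WindowAttention class above and an illustrative 7x7 window with 6 heads:

def _relative_bias_shape_check():
    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=6)
    # (2*7 - 1) ** 2 = 169 distinct relative offsets, one bias per head
    assert attn.relative_position_bias_table.shape == (169, 6)
    assert int(attn.relative_position_index.max()) == 168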
212
+ class SwinTransformerBlock(nn.Module):
213
+ r"""Swin Transformer Block.
214
+
215
+ Args:
216
+ dim (int): Number of input channels.
217
+ input_resolution (tuple[int]): Input resolution.
218
+ num_heads (int): Number of attention heads.
219
+ window_size (int): Window size.
220
+ shift_size (int): Shift size for SW-MSA.
221
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
222
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
223
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
224
+ drop (float, optional): Dropout rate. Default: 0.0
225
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
226
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
227
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
228
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ dim,
234
+ input_resolution,
235
+ num_heads,
236
+ window_size=7,
237
+ shift_size=0,
238
+ mlp_ratio=4.0,
239
+ qkv_bias=True,
240
+ qk_scale=None,
241
+ drop=0.0,
242
+ attn_drop=0.0,
243
+ drop_path=0.0,
244
+ act_layer=nn.GELU,
245
+ norm_layer=nn.LayerNorm,
246
+ ):
247
+ super().__init__()
248
+ self.dim = dim
249
+ self.input_resolution = input_resolution
250
+ self.num_heads = num_heads
251
+ self.window_size = window_size
252
+ self.shift_size = shift_size
253
+ self.mlp_ratio = mlp_ratio
254
+ if min(self.input_resolution) <= self.window_size:
255
+ # if window size is larger than input resolution, we don't partition windows
256
+ self.shift_size = 0
257
+ self.window_size = min(self.input_resolution)
258
+ assert (
259
+ 0 <= self.shift_size < self.window_size
260
+ ), "shift_size must be in [0, window_size)"
261
+
262
+ self.norm1 = norm_layer(dim)
263
+ self.attn = WindowAttention(
264
+ dim,
265
+ window_size=to_2tuple(self.window_size),
266
+ num_heads=num_heads,
267
+ qkv_bias=qkv_bias,
268
+ qk_scale=qk_scale,
269
+ attn_drop=attn_drop,
270
+ proj_drop=drop,
271
+ )
272
+
273
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
274
+ self.norm2 = norm_layer(dim)
275
+ mlp_hidden_dim = int(dim * mlp_ratio)
276
+ self.mlp = Mlp(
277
+ in_features=dim,
278
+ hidden_features=mlp_hidden_dim,
279
+ act_layer=act_layer,
280
+ drop=drop,
281
+ )
282
+
283
+ if self.shift_size > 0:
284
+ attn_mask = self.calculate_mask(self.input_resolution)
285
+ else:
286
+ attn_mask = None
287
+
288
+ self.register_buffer("attn_mask", attn_mask)
289
+
290
+ def calculate_mask(self, x_size):
291
+ # calculate attention mask for SW-MSA
292
+ H, W = x_size
293
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
294
+ h_slices = (
295
+ slice(0, -self.window_size),
296
+ slice(-self.window_size, -self.shift_size),
297
+ slice(-self.shift_size, None),
298
+ )
299
+ w_slices = (
300
+ slice(0, -self.window_size),
301
+ slice(-self.window_size, -self.shift_size),
302
+ slice(-self.shift_size, None),
303
+ )
304
+ cnt = 0
305
+ for h in h_slices:
306
+ for w in w_slices:
307
+ img_mask[:, h, w, :] = cnt
308
+ cnt += 1
309
+
310
+ mask_windows = window_partition(
311
+ img_mask, self.window_size
312
+ ) # nW, window_size, window_size, 1
313
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
314
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
315
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
316
+ attn_mask == 0, float(0.0)
317
+ )
318
+
319
+ return attn_mask
320
+
321
+ def forward(self, x, x_size):
322
+ H, W = x_size
323
+ B, L, C = x.shape
324
+ # assert L == H * W, "input feature has wrong size"
325
+
326
+ shortcut = x
327
+ x = self.norm1(x)
328
+ x = x.view(B, H, W, C)
329
+
330
+ # cyclic shift
331
+ if self.shift_size > 0:
332
+ shifted_x = torch.roll(
333
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
334
+ )
335
+ else:
336
+ shifted_x = x
337
+
338
+ # partition windows
339
+ x_windows = window_partition(
340
+ shifted_x, self.window_size
341
+ ) # nW*B, window_size, window_size, C
342
+ x_windows = x_windows.view(
343
+ -1, self.window_size * self.window_size, C
344
+ ) # nW*B, window_size*window_size, C
345
+
346
+ # W-MSA/SW-MSA (the attention mask is recomputed when the input size differs from the training resolution, so inference works on images whose shapes are any multiple of the window size)
347
+ if self.input_resolution == x_size:
348
+ attn_windows = self.attn(
349
+ x_windows, mask=self.attn_mask
350
+ ) # nW*B, window_size*window_size, C
351
+ else:
352
+ attn_windows = self.attn(
353
+ x_windows, mask=self.calculate_mask(x_size).to(x.device)
354
+ )
355
+
356
+ # merge windows
357
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
358
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
359
+
360
+ # reverse cyclic shift
361
+ if self.shift_size > 0:
362
+ x = torch.roll(
363
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
364
+ )
365
+ else:
366
+ x = shifted_x
367
+ x = x.view(B, H * W, C)
368
+
369
+ # FFN
370
+ x = shortcut + self.drop_path(x)
371
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
372
+
373
+ return x
374
+
375
+ def extra_repr(self) -> str:
376
+ return (
377
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
378
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
379
+ )
380
+
381
+ def flops(self):
382
+ flops = 0
383
+ H, W = self.input_resolution
384
+ # norm1
385
+ flops += self.dim * H * W
386
+ # W-MSA/SW-MSA
387
+ nW = H * W / self.window_size / self.window_size
388
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
389
+ # mlp
390
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
391
+ # norm2
392
+ flops += self.dim * H * W
393
+ return flops
394
+
395
+
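calculate_mask() above assigns every pixel a group id per (shifted) window region and sets mismatching pairs to -100.0, which the softmax then suppresses. As a shape sketch (illustrative sizes, assuming the SwinTransformerBlock class above): a 56x56 resolution with window_size=7 and shift_size=3 yields one 49x49 mask per window.

def _attn_mask_shape_check():
    blk = SwinTransformerBlock(
        dim=96, input_resolution=(56, 56), num_heads=6, window_size=7, shift_size=3
    )
    assert blk.attn_mask.shape == (64, 49, 49)  # 8*8 windows, 49 tokens each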
396
+ class PatchMerging(nn.Module):
397
+ r"""Patch Merging Layer.
398
+
399
+ Args:
400
+ input_resolution (tuple[int]): Resolution of input feature.
401
+ dim (int): Number of input channels.
402
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
403
+ """
404
+
405
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
406
+ super().__init__()
407
+ self.input_resolution = input_resolution
408
+ self.dim = dim
409
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
410
+ self.norm = norm_layer(4 * dim)
411
+
412
+ def forward(self, x):
413
+ """
414
+ x: B, H*W, C
415
+ """
416
+ H, W = self.input_resolution
417
+ B, L, C = x.shape
418
+ assert L == H * W, "input feature has wrong size"
419
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
420
+
421
+ x = x.view(B, H, W, C)
422
+
423
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
424
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
425
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
426
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
427
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
428
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
429
+
430
+ x = self.norm(x)
431
+ x = self.reduction(x)
432
+
433
+ return x
434
+
435
+ def extra_repr(self) -> str:
436
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
437
+
438
+ def flops(self):
439
+ H, W = self.input_resolution
440
+ flops = H * W * self.dim
441
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
442
+ return flops
443
+
444
+
445
+ class BasicLayer(nn.Module):
446
+ """A basic Swin Transformer layer for one stage.
447
+
448
+ Args:
449
+ dim (int): Number of input channels.
450
+ input_resolution (tuple[int]): Input resolution.
451
+ depth (int): Number of blocks.
452
+ num_heads (int): Number of attention heads.
453
+ window_size (int): Local window size.
454
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
455
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
456
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
457
+ drop (float, optional): Dropout rate. Default: 0.0
458
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
459
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
460
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
461
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
462
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
463
+ """
464
+
465
+ def __init__(
466
+ self,
467
+ dim,
468
+ input_resolution,
469
+ depth,
470
+ num_heads,
471
+ window_size,
472
+ mlp_ratio=4.0,
473
+ qkv_bias=True,
474
+ qk_scale=None,
475
+ drop=0.0,
476
+ attn_drop=0.0,
477
+ drop_path=0.0,
478
+ norm_layer=nn.LayerNorm,
479
+ downsample=None,
480
+ use_checkpoint=False,
481
+ ):
482
+ super().__init__()
483
+ self.dim = dim
484
+ self.input_resolution = input_resolution
485
+ self.depth = depth
486
+ self.use_checkpoint = use_checkpoint
487
+
488
+ # build blocks
489
+ self.blocks = nn.ModuleList(
490
+ [
491
+ SwinTransformerBlock(
492
+ dim=dim,
493
+ input_resolution=input_resolution,
494
+ num_heads=num_heads,
495
+ window_size=window_size,
496
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
497
+ mlp_ratio=mlp_ratio,
498
+ qkv_bias=qkv_bias,
499
+ qk_scale=qk_scale,
500
+ drop=drop,
501
+ attn_drop=attn_drop,
502
+ drop_path=drop_path[i]
503
+ if isinstance(drop_path, list)
504
+ else drop_path,
505
+ norm_layer=norm_layer,
506
+ )
507
+ for i in range(depth)
508
+ ]
509
+ )
510
+
511
+ # patch merging layer
512
+ if downsample is not None:
513
+ self.downsample = downsample(
514
+ input_resolution, dim=dim, norm_layer=norm_layer
515
+ )
516
+ else:
517
+ self.downsample = None
518
+
519
+ def forward(self, x, x_size):
520
+ for blk in self.blocks:
521
+ if self.use_checkpoint:
522
+ x = checkpoint.checkpoint(blk, x, x_size)
523
+ else:
524
+ x = blk(x, x_size)
525
+ if self.downsample is not None:
526
+ x = self.downsample(x)
527
+ return x
528
+
529
+ def extra_repr(self) -> str:
530
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
531
+
532
+ def flops(self):
533
+ flops = 0
534
+ for blk in self.blocks:
535
+ flops += blk.flops() # type: ignore
536
+ if self.downsample is not None:
537
+ flops += self.downsample.flops()
538
+ return flops
539
+
540
+
541
+ class RSTB(nn.Module):
542
+ """Residual Swin Transformer Block (RSTB).
543
+
544
+ Args:
545
+ dim (int): Number of input channels.
546
+ input_resolution (tuple[int]): Input resolution.
547
+ depth (int): Number of blocks.
548
+ num_heads (int): Number of attention heads.
549
+ window_size (int): Local window size.
550
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
551
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
552
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
553
+ drop (float, optional): Dropout rate. Default: 0.0
554
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
555
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
556
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
557
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
558
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
559
+ img_size: Input image size.
560
+ patch_size: Patch size.
561
+ resi_connection: The convolutional block before residual connection.
562
+ """
563
+
564
+ def __init__(
565
+ self,
566
+ dim,
567
+ input_resolution,
568
+ depth,
569
+ num_heads,
570
+ window_size,
571
+ mlp_ratio=4.0,
572
+ qkv_bias=True,
573
+ qk_scale=None,
574
+ drop=0.0,
575
+ attn_drop=0.0,
576
+ drop_path=0.0,
577
+ norm_layer=nn.LayerNorm,
578
+ downsample=None,
579
+ use_checkpoint=False,
580
+ img_size=224,
581
+ patch_size=4,
582
+ resi_connection="1conv",
583
+ ):
584
+ super(RSTB, self).__init__()
585
+
586
+ self.dim = dim
587
+ self.input_resolution = input_resolution
588
+
589
+ self.residual_group = BasicLayer(
590
+ dim=dim,
591
+ input_resolution=input_resolution,
592
+ depth=depth,
593
+ num_heads=num_heads,
594
+ window_size=window_size,
595
+ mlp_ratio=mlp_ratio,
596
+ qkv_bias=qkv_bias,
597
+ qk_scale=qk_scale,
598
+ drop=drop,
599
+ attn_drop=attn_drop,
600
+ drop_path=drop_path,
601
+ norm_layer=norm_layer,
602
+ downsample=downsample,
603
+ use_checkpoint=use_checkpoint,
604
+ )
605
+
606
+ if resi_connection == "1conv":
607
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
608
+ elif resi_connection == "3conv":
609
+ # to save parameters and memory
610
+ self.conv = nn.Sequential(
611
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
612
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
613
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
614
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
615
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
616
+ )
617
+
618
+ self.patch_embed = PatchEmbed(
619
+ img_size=img_size,
620
+ patch_size=patch_size,
621
+ in_chans=0,
622
+ embed_dim=dim,
623
+ norm_layer=None,
624
+ )
625
+
626
+ self.patch_unembed = PatchUnEmbed(
627
+ img_size=img_size,
628
+ patch_size=patch_size,
629
+ in_chans=0,
630
+ embed_dim=dim,
631
+ norm_layer=None,
632
+ )
633
+
634
+ def forward(self, x, x_size):
635
+ return (
636
+ self.patch_embed(
637
+ self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
638
+ )
639
+ + x
640
+ )
641
+
642
+ def flops(self):
643
+ flops = 0
644
+ flops += self.residual_group.flops()
645
+ H, W = self.input_resolution
646
+ flops += H * W * self.dim * self.dim * 9
647
+ flops += self.patch_embed.flops()
648
+ flops += self.patch_unembed.flops()
649
+
650
+ return flops
651
+
652
+
653
+ class PatchEmbed(nn.Module):
654
+ r"""Image to Patch Embedding
655
+
656
+ Args:
657
+ img_size (int): Image size. Default: 224.
658
+ patch_size (int): Patch token size. Default: 4.
659
+ in_chans (int): Number of input image channels. Default: 3.
660
+ embed_dim (int): Number of linear projection output channels. Default: 96.
661
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
662
+ """
663
+
664
+ def __init__(
665
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
666
+ ):
667
+ super().__init__()
668
+ img_size = to_2tuple(img_size)
669
+ patch_size = to_2tuple(patch_size)
670
+ patches_resolution = [
671
+ img_size[0] // patch_size[0], # type: ignore
672
+ img_size[1] // patch_size[1], # type: ignore
673
+ ]
674
+ self.img_size = img_size
675
+ self.patch_size = patch_size
676
+ self.patches_resolution = patches_resolution
677
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
678
+
679
+ self.in_chans = in_chans
680
+ self.embed_dim = embed_dim
681
+
682
+ if norm_layer is not None:
683
+ self.norm = norm_layer(embed_dim)
684
+ else:
685
+ self.norm = None
686
+
687
+ def forward(self, x):
688
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
689
+ if self.norm is not None:
690
+ x = self.norm(x)
691
+ return x
692
+
693
+ def flops(self):
694
+ flops = 0
695
+ H, W = self.img_size
696
+ if self.norm is not None:
697
+ flops += H * W * self.embed_dim # type: ignore
698
+ return flops
699
+
700
+
701
+ class PatchUnEmbed(nn.Module):
702
+ r"""Image to Patch Unembedding
703
+
704
+ Args:
705
+ img_size (int): Image size. Default: 224.
706
+ patch_size (int): Patch token size. Default: 4.
707
+ in_chans (int): Number of input image channels. Default: 3.
708
+ embed_dim (int): Number of linear projection output channels. Default: 96.
709
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
710
+ """
711
+
712
+ def __init__(
713
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
714
+ ):
715
+ super().__init__()
716
+ img_size = to_2tuple(img_size)
717
+ patch_size = to_2tuple(patch_size)
718
+ patches_resolution = [
719
+ img_size[0] // patch_size[0], # type: ignore
720
+ img_size[1] // patch_size[1], # type: ignore
721
+ ]
722
+ self.img_size = img_size
723
+ self.patch_size = patch_size
724
+ self.patches_resolution = patches_resolution
725
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
726
+
727
+ self.in_chans = in_chans
728
+ self.embed_dim = embed_dim
729
+
730
+ def forward(self, x, x_size):
731
+ B, HW, C = x.shape
732
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
733
+ return x
734
+
735
+ def flops(self):
736
+ flops = 0
737
+ return flops
738
+
739
+
740
+ class Upsample(nn.Sequential):
741
+ """Upsample module.
742
+
743
+ Args:
744
+ scale (int): Scale factor. Supported scales: 2^n and 3.
745
+ num_feat (int): Channel number of intermediate features.
746
+ """
747
+
748
+ def __init__(self, scale, num_feat):
749
+ m = []
750
+ if (scale & (scale - 1)) == 0: # scale = 2^n
751
+ for _ in range(int(math.log(scale, 2))):
752
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
753
+ m.append(nn.PixelShuffle(2))
754
+ elif scale == 3:
755
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
756
+ m.append(nn.PixelShuffle(3))
757
+ else:
758
+ raise ValueError(
759
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
760
+ )
761
+ super(Upsample, self).__init__(*m)
762
+
763
+
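For a power-of-two scale, Upsample stacks one (3x3 conv to 4*num_feat, PixelShuffle(2)) pair per factor of two, so scale=4 doubles H and W twice; scale=3 instead uses a single conv to 9*num_feat followed by PixelShuffle(3). A shape-only sketch, assuming the class above and illustrative sizes:

import torch

def _upsample_shape_check():
    up = Upsample(4, 64)
    y = up(torch.randn(1, 64, 32, 32))
    assert y.shape == (1, 64, 128, 128)  # 32 -> 64 -> 128 via two PixelShuffle(2) stages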
764
+ class UpsampleOneStep(nn.Sequential):
765
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
766
+ Used in lightweight SR to save parameters.
767
+
768
+ Args:
769
+ scale (int): Scale factor. Supported scales: 2^n and 3.
770
+ num_feat (int): Channel number of intermediate features.
771
+
772
+ """
773
+
774
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
775
+ self.num_feat = num_feat
776
+ self.input_resolution = input_resolution
777
+ m = []
778
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
779
+ m.append(nn.PixelShuffle(scale))
780
+ super(UpsampleOneStep, self).__init__(*m)
781
+
782
+ def flops(self):
783
+ H, W = self.input_resolution # type: ignore
784
+ flops = H * W * self.num_feat * 3 * 9
785
+ return flops
786
+
787
+
788
+ class SwinIR(nn.Module):
789
+ r"""SwinIR
790
+ A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
791
+
792
+ Args:
793
+ img_size (int | tuple(int)): Input image size. Default 64
794
+ patch_size (int | tuple(int)): Patch size. Default: 1
795
+ in_chans (int): Number of input image channels. Default: 3
796
+ embed_dim (int): Patch embedding dimension. Default: 96
797
+ depths (tuple(int)): Depth of each Swin Transformer layer.
798
+ num_heads (tuple(int)): Number of attention heads in different layers.
799
+ window_size (int): Window size. Default: 7
800
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
801
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
802
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
803
+ drop_rate (float): Dropout rate. Default: 0
804
+ attn_drop_rate (float): Attention dropout rate. Default: 0
805
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
806
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
807
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
808
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
809
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
810
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
811
+ img_range: Image range. 1. or 255.
812
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
813
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
814
+ """
815
+
816
+ def __init__(
817
+ self,
818
+ state_dict,
819
+ **kwargs,
820
+ ):
821
+ super(SwinIR, self).__init__()
822
+
823
+ # Defaults
824
+ img_size = 64
825
+ patch_size = 1
826
+ in_chans = 3
827
+ embed_dim = 96
828
+ depths = [6, 6, 6, 6]
829
+ num_heads = [6, 6, 6, 6]
830
+ window_size = 7
831
+ mlp_ratio = 4.0
832
+ qkv_bias = True
833
+ qk_scale = None
834
+ drop_rate = 0.0
835
+ attn_drop_rate = 0.0
836
+ drop_path_rate = 0.1
837
+ norm_layer = nn.LayerNorm
838
+ ape = False
839
+ patch_norm = True
840
+ use_checkpoint = False
841
+ upscale = 2
842
+ img_range = 1.0
843
+ upsampler = ""
844
+ resi_connection = "1conv"
845
+ num_feat = 64
846
+ num_in_ch = in_chans
847
+ num_out_ch = in_chans
848
+ supports_fp16 = True
849
+
850
+ self.model_arch = "SwinIR"
851
+ self.sub_type = "SR"
852
+ self.state = state_dict
853
+ if "params_ema" in self.state:
854
+ self.state = self.state["params_ema"]
855
+ elif "params" in self.state:
856
+ self.state = self.state["params"]
857
+
858
+ state_keys = self.state.keys()
859
+
860
+ if "conv_before_upsample.0.weight" in state_keys:
861
+ if "conv_up1.weight" in state_keys:
862
+ upsampler = "nearest+conv"
863
+ else:
864
+ upsampler = "pixelshuffle"
865
+ supports_fp16 = False
866
+ elif "upsample.0.weight" in state_keys:
867
+ upsampler = "pixelshuffledirect"
868
+ else:
869
+ upsampler = ""
870
+
871
+ num_feat = (
872
+ self.state.get("conv_before_upsample.0.weight", None).shape[1]
873
+ if self.state.get("conv_before_upsample.weight", None)
874
+ else 64
875
+ )
876
+
877
+ num_in_ch = self.state["conv_first.weight"].shape[1]
878
+ in_chans = num_in_ch
879
+ if "conv_last.weight" in state_keys:
880
+ num_out_ch = self.state["conv_last.weight"].shape[0]
881
+ else:
882
+ num_out_ch = num_in_ch
883
+
884
+ upscale = 1
885
+ if upsampler == "nearest+conv":
886
+ upsample_keys = [
887
+ x for x in state_keys if "conv_up" in x and "bias" not in x
888
+ ]
889
+
890
+ for upsample_key in upsample_keys:
891
+ upscale *= 2
892
+ elif upsampler == "pixelshuffle":
893
+ upsample_keys = [
894
+ x
895
+ for x in state_keys
896
+ if "upsample" in x and "conv" not in x and "bias" not in x
897
+ ]
898
+ for upsample_key in upsample_keys:
899
+ shape = self.state[upsample_key].shape[0]
900
+ upscale *= math.sqrt(shape // num_feat)
901
+ upscale = int(upscale)
902
+ elif upsampler == "pixelshuffledirect":
903
+ upscale = int(
904
+ math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
905
+ )
906
+
907
+ max_layer_num = 0
908
+ max_block_num = 0
909
+ for key in state_keys:
910
+ result = re.match(
911
+ r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
912
+ )
913
+ if result:
914
+ layer_num, block_num = result.groups()
915
+ max_layer_num = max(max_layer_num, int(layer_num))
916
+ max_block_num = max(max_block_num, int(block_num))
917
+
918
+ depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
919
+
920
+ if (
921
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
922
+ in state_keys
923
+ ):
924
+ num_heads_num = self.state[
925
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
926
+ ].shape[-1]
927
+ num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
928
+ else:
929
+ num_heads = depths
930
+
931
+ embed_dim = self.state["conv_first.weight"].shape[0]
932
+
933
+ mlp_ratio = float(
934
+ self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
935
+ / embed_dim
936
+ )
937
+
938
+ # TODO: could actually count the layers, but this should do
939
+ if "layers.0.conv.4.weight" in state_keys:
940
+ resi_connection = "3conv"
941
+ else:
942
+ resi_connection = "1conv"
943
+
944
+ window_size = int(
945
+ math.sqrt(
946
+ self.state[
947
+ "layers.0.residual_group.blocks.0.attn.relative_position_index"
948
+ ].shape[0]
949
+ )
950
+ )
951
+
952
+ if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
953
+ img_size = int(
954
+ math.sqrt(
955
+ self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
956
+ )
957
+ * window_size
958
+ )
959
+
960
+ # The JPEG models are the only ones with window-size 7, and they also use this range
961
+ img_range = 255.0 if window_size == 7 else 1.0
962
+
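All of the hyper-parameters above are recovered purely from the checkpoint's key names and tensor shapes. For example, depths is inferred by scanning keys of the form "layers.<L>.residual_group.blocks.<B>.norm1.weight" with the regex used above; the key names below are illustrative examples, not taken from a real checkpoint:

import re

example_keys = [
    "layers.0.residual_group.blocks.0.norm1.weight",
    "layers.0.residual_group.blocks.5.norm1.weight",
    "layers.3.residual_group.blocks.5.norm1.weight",
]
max_layer_num, max_block_num = 0, 0
for key in example_keys:
    result = re.match(r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key)
    if result:
        layer_num, block_num = result.groups()
        max_layer_num = max(max_layer_num, int(layer_num))
        max_block_num = max(max_block_num, int(block_num))
depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
assert depths == [6, 6, 6, 6]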
963
+ self.in_nc = num_in_ch
964
+ self.out_nc = num_out_ch
965
+ self.num_feat = num_feat
966
+ self.embed_dim = embed_dim
967
+ self.num_heads = num_heads
968
+ self.depths = depths
969
+ self.window_size = window_size
970
+ self.mlp_ratio = mlp_ratio
971
+ self.scale = upscale
972
+ self.upsampler = upsampler
973
+ self.img_size = img_size
974
+ self.img_range = img_range
975
+
976
+ self.supports_fp16 = False # Too much weirdness to support this at the moment
977
+ self.supports_bfp16 = True
978
+ self.min_size_restriction = 16
979
+
980
+ self.img_range = img_range
981
+ if in_chans == 3:
982
+ rgb_mean = (0.4488, 0.4371, 0.4040)
983
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
984
+ else:
985
+ self.mean = torch.zeros(1, 1, 1, 1)
986
+ self.upscale = upscale
987
+ self.upsampler = upsampler
988
+ self.window_size = window_size
989
+
990
+ #####################################################################################################
991
+ ################################### 1, shallow feature extraction ###################################
992
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
993
+
994
+ #####################################################################################################
995
+ ################################### 2, deep feature extraction ######################################
996
+ self.num_layers = len(depths)
997
+ self.embed_dim = embed_dim
998
+ self.ape = ape
999
+ self.patch_norm = patch_norm
1000
+ self.num_features = embed_dim
1001
+ self.mlp_ratio = mlp_ratio
1002
+
1003
+ # split image into non-overlapping patches
1004
+ self.patch_embed = PatchEmbed(
1005
+ img_size=img_size,
1006
+ patch_size=patch_size,
1007
+ in_chans=embed_dim,
1008
+ embed_dim=embed_dim,
1009
+ norm_layer=norm_layer if self.patch_norm else None,
1010
+ )
1011
+ num_patches = self.patch_embed.num_patches
1012
+ patches_resolution = self.patch_embed.patches_resolution
1013
+ self.patches_resolution = patches_resolution
1014
+
1015
+ # merge non-overlapping patches into image
1016
+ self.patch_unembed = PatchUnEmbed(
1017
+ img_size=img_size,
1018
+ patch_size=patch_size,
1019
+ in_chans=embed_dim,
1020
+ embed_dim=embed_dim,
1021
+ norm_layer=norm_layer if self.patch_norm else None,
1022
+ )
1023
+
1024
+ # absolute position embedding
1025
+ if self.ape:
1026
+ self.absolute_pos_embed = nn.Parameter( # type: ignore
1027
+ torch.zeros(1, num_patches, embed_dim)
1028
+ )
1029
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
1030
+
1031
+ self.pos_drop = nn.Dropout(p=drop_rate)
1032
+
1033
+ # stochastic depth
1034
+ dpr = [
1035
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
1036
+ ] # stochastic depth decay rule
1037
+
1038
+ # build Residual Swin Transformer blocks (RSTB)
1039
+ self.layers = nn.ModuleList()
1040
+ for i_layer in range(self.num_layers):
1041
+ layer = RSTB(
1042
+ dim=embed_dim,
1043
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1044
+ depth=depths[i_layer],
1045
+ num_heads=num_heads[i_layer],
1046
+ window_size=window_size,
1047
+ mlp_ratio=self.mlp_ratio,
1048
+ qkv_bias=qkv_bias,
1049
+ qk_scale=qk_scale,
1050
+ drop=drop_rate,
1051
+ attn_drop=attn_drop_rate,
1052
+ drop_path=dpr[
1053
+ sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) # type: ignore
1054
+ ], # no impact on SR results
1055
+ norm_layer=norm_layer,
1056
+ downsample=None,
1057
+ use_checkpoint=use_checkpoint,
1058
+ img_size=img_size,
1059
+ patch_size=patch_size,
1060
+ resi_connection=resi_connection,
1061
+ )
1062
+ self.layers.append(layer)
1063
+ self.norm = norm_layer(self.num_features)
1064
+
1065
+ # build the last conv layer in deep feature extraction
1066
+ if resi_connection == "1conv":
1067
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1068
+ elif resi_connection == "3conv":
1069
+ # to save parameters and memory
1070
+ self.conv_after_body = nn.Sequential(
1071
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
1072
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1073
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
1074
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1075
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
1076
+ )
1077
+
1078
+ #####################################################################################################
1079
+ ################################ 3, high quality image reconstruction ################################
1080
+ if self.upsampler == "pixelshuffle":
1081
+ # for classical SR
1082
+ self.conv_before_upsample = nn.Sequential(
1083
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1084
+ )
1085
+ self.upsample = Upsample(upscale, num_feat)
1086
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1087
+ elif self.upsampler == "pixelshuffledirect":
1088
+ # for lightweight SR (to save parameters)
1089
+ self.upsample = UpsampleOneStep(
1090
+ upscale,
1091
+ embed_dim,
1092
+ num_out_ch,
1093
+ (patches_resolution[0], patches_resolution[1]),
1094
+ )
1095
+ elif self.upsampler == "nearest+conv":
1096
+ # for real-world SR (less artifacts)
1097
+ self.conv_before_upsample = nn.Sequential(
1098
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1099
+ )
1100
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1101
+ if self.upscale == 4:
1102
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1103
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1104
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1105
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1106
+ else:
1107
+ # for image denoising and JPEG compression artifact reduction
1108
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1109
+
1110
+ self.apply(self._init_weights)
1111
+ self.load_state_dict(self.state, strict=False)
1112
+
1113
+ def _init_weights(self, m):
1114
+ if isinstance(m, nn.Linear):
1115
+ trunc_normal_(m.weight, std=0.02)
1116
+ if isinstance(m, nn.Linear) and m.bias is not None:
1117
+ nn.init.constant_(m.bias, 0)
1118
+ elif isinstance(m, nn.LayerNorm):
1119
+ nn.init.constant_(m.bias, 0)
1120
+ nn.init.constant_(m.weight, 1.0)
1121
+
1122
+ @torch.jit.ignore # type: ignore
1123
+ def no_weight_decay(self):
1124
+ return {"absolute_pos_embed"}
1125
+
1126
+ @torch.jit.ignore # type: ignore
1127
+ def no_weight_decay_keywords(self):
1128
+ return {"relative_position_bias_table"}
1129
+
1130
+ def check_image_size(self, x):
1131
+ _, _, h, w = x.size()
1132
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
1133
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
1134
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1135
+ return x
1136
+
1137
+ def forward_features(self, x):
1138
+ x_size = (x.shape[2], x.shape[3])
1139
+ x = self.patch_embed(x)
1140
+ if self.ape:
1141
+ x = x + self.absolute_pos_embed
1142
+ x = self.pos_drop(x)
1143
+
1144
+ for layer in self.layers:
1145
+ x = layer(x, x_size)
1146
+
1147
+ x = self.norm(x) # B L C
1148
+ x = self.patch_unembed(x, x_size)
1149
+
1150
+ return x
1151
+
1152
+ def forward(self, x):
1153
+ H, W = x.shape[2:]
1154
+ x = self.check_image_size(x)
1155
+
1156
+ self.mean = self.mean.type_as(x)
1157
+ x = (x - self.mean) * self.img_range
1158
+
1159
+ if self.upsampler == "pixelshuffle":
1160
+ # for classical SR
1161
+ x = self.conv_first(x)
1162
+ x = self.conv_after_body(self.forward_features(x)) + x
1163
+ x = self.conv_before_upsample(x)
1164
+ x = self.conv_last(self.upsample(x))
1165
+ elif self.upsampler == "pixelshuffledirect":
1166
+ # for lightweight SR
1167
+ x = self.conv_first(x)
1168
+ x = self.conv_after_body(self.forward_features(x)) + x
1169
+ x = self.upsample(x)
1170
+ elif self.upsampler == "nearest+conv":
1171
+ # for real-world SR
1172
+ x = self.conv_first(x)
1173
+ x = self.conv_after_body(self.forward_features(x)) + x
1174
+ x = self.conv_before_upsample(x)
1175
+ x = self.lrelu(
1176
+ self.conv_up1(
1177
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") # type: ignore
1178
+ )
1179
+ )
1180
+ if self.upscale == 4:
1181
+ x = self.lrelu(
1182
+ self.conv_up2(
1183
+ torch.nn.functional.interpolate( # type: ignore
1184
+ x, scale_factor=2, mode="nearest"
1185
+ )
1186
+ )
1187
+ )
1188
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1189
+ else:
1190
+ # for image denoising and JPEG compression artifact reduction
1191
+ x_first = self.conv_first(x)
1192
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1193
+ x = x + self.conv_last(res)
1194
+
1195
+ x = x / self.img_range + self.mean
1196
+
1197
+ return x[:, :, : H * self.upscale, : W * self.upscale]
1198
+
1199
+ def flops(self):
1200
+ flops = 0
1201
+ H, W = self.patches_resolution
1202
+ flops += H * W * 3 * self.embed_dim * 9
1203
+ flops += self.patch_embed.flops()
1204
+ for i, layer in enumerate(self.layers):
1205
+ flops += layer.flops() # type: ignore
1206
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1207
+ flops += self.upsample.flops() # type: ignore
1208
+ return flops
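Because the constructor takes the raw state_dict and infers everything else from it, loading a SwinIR checkpoint needs no extra configuration. A hedged usage sketch, assuming the SwinIR class defined above; "model.pth" is a placeholder path, not a file shipped with this repository:

import torch

state_dict = torch.load("model.pth", map_location="cpu")
model = SwinIR(state_dict).eval()
with torch.no_grad():
    lr = torch.rand(1, model.in_nc, 64, 64)
    sr = model(lr)  # shape: (1, out_nc, 64 * model.scale, 64 * model.scale)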
comfy_extras/chainner_models/architecture/__init__.py ADDED
File without changes
comfy_extras/chainner_models/architecture/block.py ADDED
@@ -0,0 +1,513 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ from collections import OrderedDict
7
+ from typing import Literal
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ ####################
13
+ # Basic blocks
14
+ ####################
15
+
16
+
17
+ def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
18
+ # helper selecting activation
19
+ # neg_slope: for leakyrelu and init of prelu
20
+ # n_prelu: for p_relu num_parameters
21
+ act_type = act_type.lower()
22
+ if act_type == "relu":
23
+ layer = nn.ReLU(inplace)
24
+ elif act_type == "leakyrelu":
25
+ layer = nn.LeakyReLU(neg_slope, inplace)
26
+ elif act_type == "prelu":
27
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
28
+ else:
29
+ raise NotImplementedError(
30
+ "activation layer [{:s}] is not found".format(act_type)
31
+ )
32
+ return layer
33
+
34
+
35
+ def norm(norm_type: str, nc: int):
36
+ # helper selecting normalization layer
37
+ norm_type = norm_type.lower()
38
+ if norm_type == "batch":
39
+ layer = nn.BatchNorm2d(nc, affine=True)
40
+ elif norm_type == "instance":
41
+ layer = nn.InstanceNorm2d(nc, affine=False)
42
+ else:
43
+ raise NotImplementedError(
44
+ "normalization layer [{:s}] is not found".format(norm_type)
45
+ )
46
+ return layer
47
+
48
+
49
+ def pad(pad_type: str, padding):
50
+ # helper selecting padding layer
51
+ # if padding is 'zero', do by conv layers
52
+ pad_type = pad_type.lower()
53
+ if padding == 0:
54
+ return None
55
+ if pad_type == "reflect":
56
+ layer = nn.ReflectionPad2d(padding)
57
+ elif pad_type == "replicate":
58
+ layer = nn.ReplicationPad2d(padding)
59
+ else:
60
+ raise NotImplementedError(
61
+ "padding layer [{:s}] is not implemented".format(pad_type)
62
+ )
63
+ return layer
64
+
65
+
66
+ def get_valid_padding(kernel_size, dilation):
67
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
68
+ padding = (kernel_size - 1) // 2
69
+ return padding
70
+
71
+
72
+ class ConcatBlock(nn.Module):
73
+ # Concat the output of a submodule to its input
74
+ def __init__(self, submodule):
75
+ super(ConcatBlock, self).__init__()
76
+ self.sub = submodule
77
+
78
+ def forward(self, x):
79
+ output = torch.cat((x, self.sub(x)), dim=1)
80
+ return output
81
+
82
+ def __repr__(self):
83
+ tmpstr = "Identity .. \n|"
84
+ modstr = self.sub.__repr__().replace("\n", "\n|")
85
+ tmpstr = tmpstr + modstr
86
+ return tmpstr
87
+
88
+
89
+ class ShortcutBlock(nn.Module):
90
+ # Elementwise sum the output of a submodule to its input
91
+ def __init__(self, submodule):
92
+ super(ShortcutBlock, self).__init__()
93
+ self.sub = submodule
94
+
95
+ def forward(self, x):
96
+ output = x + self.sub(x)
97
+ return output
98
+
99
+ def __repr__(self):
100
+ tmpstr = "Identity + \n|"
101
+ modstr = self.sub.__repr__().replace("\n", "\n|")
102
+ tmpstr = tmpstr + modstr
103
+ return tmpstr
104
+
105
+
106
+ class ShortcutBlockSPSR(nn.Module):
107
+ # Elementwise sum the output of a submodule to its input
108
+ def __init__(self, submodule):
109
+ super(ShortcutBlockSPSR, self).__init__()
110
+ self.sub = submodule
111
+
112
+ def forward(self, x):
113
+ return x, self.sub
114
+
115
+ def __repr__(self):
116
+ tmpstr = "Identity + \n|"
117
+ modstr = self.sub.__repr__().replace("\n", "\n|")
118
+ tmpstr = tmpstr + modstr
119
+ return tmpstr
120
+
121
+
122
+ def sequential(*args):
123
+ # Flatten Sequential. It unwraps nn.Sequential.
124
+ if len(args) == 1:
125
+ if isinstance(args[0], OrderedDict):
126
+ raise NotImplementedError("sequential does not support OrderedDict input.")
127
+ return args[0] # No sequential is needed.
128
+ modules = []
129
+ for module in args:
130
+ if isinstance(module, nn.Sequential):
131
+ for submodule in module.children():
132
+ modules.append(submodule)
133
+ elif isinstance(module, nn.Module):
134
+ modules.append(module)
135
+ return nn.Sequential(*modules)
136
+
137
+
138
+ ConvMode = Literal["CNA", "NAC", "CNAC"]
139
+
140
+
141
+ def conv_block(
142
+ in_nc: int,
143
+ out_nc: int,
144
+ kernel_size,
145
+ stride=1,
146
+ dilation=1,
147
+ groups=1,
148
+ bias=True,
149
+ pad_type="zero",
150
+ norm_type: str | None = None,
151
+ act_type: str | None = "relu",
152
+ mode: ConvMode = "CNA",
153
+ ):
154
+ """
155
+ Conv layer with padding, normalization, activation
156
+ mode: CNA --> Conv -> Norm -> Act
157
+ NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
158
+ """
159
+ assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
160
+ padding = get_valid_padding(kernel_size, dilation)
161
+ p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
162
+ padding = padding if pad_type == "zero" else 0
163
+
164
+ c = nn.Conv2d(
165
+ in_nc,
166
+ out_nc,
167
+ kernel_size=kernel_size,
168
+ stride=stride,
169
+ padding=padding,
170
+ dilation=dilation,
171
+ bias=bias,
172
+ groups=groups,
173
+ )
174
+ a = act(act_type) if act_type else None
175
+ if mode in ("CNA", "CNAC"):
176
+ n = norm(norm_type, out_nc) if norm_type else None
177
+ return sequential(p, c, n, a)
178
+ elif mode == "NAC":
179
+ if norm_type is None and act_type is not None:
180
+ a = act(act_type, inplace=False)
181
+ # Important!
182
+ # input----ReLU(inplace)----Conv--+----output
183
+ # |________________________|
184
+ # inplace ReLU will modify the input, therefore wrong output
185
+ n = norm(norm_type, in_nc) if norm_type else None
186
+ return sequential(n, a, p, c)
187
+ else:
188
+ assert False, f"Invalid conv mode {mode}"
189
+
190
+
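With the defaults (pad_type="zero", norm_type=None, act_type="relu", mode="CNA"), conv_block() reduces to a padded Conv2d followed by ReLU; sequential() silently drops the None placeholders. A small shape sketch, assuming the helpers defined above and illustrative sizes:

import torch

def _conv_block_check():
    block = conv_block(3, 64, kernel_size=3)   # Conv2d(3, 64, 3, padding=1) + ReLU
    y = block(torch.randn(1, 3, 32, 32))
    assert y.shape == (1, 64, 32, 32)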
191
+ ####################
192
+ # Useful blocks
193
+ ####################
194
+
195
+
196
+ class ResNetBlock(nn.Module):
197
+ """
198
+ ResNet Block, 3-3 style
199
+ with extra residual scaling used in EDSR
200
+ (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ in_nc,
206
+ mid_nc,
207
+ out_nc,
208
+ kernel_size=3,
209
+ stride=1,
210
+ dilation=1,
211
+ groups=1,
212
+ bias=True,
213
+ pad_type="zero",
214
+ norm_type=None,
215
+ act_type="relu",
216
+ mode: ConvMode = "CNA",
217
+ res_scale=1,
218
+ ):
219
+ super(ResNetBlock, self).__init__()
220
+ conv0 = conv_block(
221
+ in_nc,
222
+ mid_nc,
223
+ kernel_size,
224
+ stride,
225
+ dilation,
226
+ groups,
227
+ bias,
228
+ pad_type,
229
+ norm_type,
230
+ act_type,
231
+ mode,
232
+ )
233
+ if mode == "CNA":
234
+ act_type = None
235
+ if mode == "CNAC": # Residual path: |-CNAC-|
236
+ act_type = None
237
+ norm_type = None
238
+ conv1 = conv_block(
239
+ mid_nc,
240
+ out_nc,
241
+ kernel_size,
242
+ stride,
243
+ dilation,
244
+ groups,
245
+ bias,
246
+ pad_type,
247
+ norm_type,
248
+ act_type,
249
+ mode,
250
+ )
251
+ # if in_nc != out_nc:
252
+ # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
253
+ # None, None)
254
+ # print('Need a projecter in ResNetBlock.')
255
+ # else:
256
+ # self.project = lambda x:x
257
+ self.res = sequential(conv0, conv1)
258
+ self.res_scale = res_scale
259
+
260
+ def forward(self, x):
261
+ res = self.res(x).mul(self.res_scale)
262
+ return x + res
263
+
264
+
265
+ class RRDB(nn.Module):
266
+ """
267
+ Residual in Residual Dense Block
268
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
269
+ """
270
+
271
+ def __init__(
272
+ self,
273
+ nf,
274
+ kernel_size=3,
275
+ gc=32,
276
+ stride=1,
277
+ bias: bool = True,
278
+ pad_type="zero",
279
+ norm_type=None,
280
+ act_type="leakyrelu",
281
+ mode: ConvMode = "CNA",
282
+ _convtype="Conv2D",
283
+ _spectral_norm=False,
284
+ plus=False,
285
+ ):
286
+ super(RRDB, self).__init__()
287
+ self.RDB1 = ResidualDenseBlock_5C(
288
+ nf,
289
+ kernel_size,
290
+ gc,
291
+ stride,
292
+ bias,
293
+ pad_type,
294
+ norm_type,
295
+ act_type,
296
+ mode,
297
+ plus=plus,
298
+ )
299
+ self.RDB2 = ResidualDenseBlock_5C(
300
+ nf,
301
+ kernel_size,
302
+ gc,
303
+ stride,
304
+ bias,
305
+ pad_type,
306
+ norm_type,
307
+ act_type,
308
+ mode,
309
+ plus=plus,
310
+ )
311
+ self.RDB3 = ResidualDenseBlock_5C(
312
+ nf,
313
+ kernel_size,
314
+ gc,
315
+ stride,
316
+ bias,
317
+ pad_type,
318
+ norm_type,
319
+ act_type,
320
+ mode,
321
+ plus=plus,
322
+ )
323
+
324
+ def forward(self, x):
325
+ out = self.RDB1(x)
326
+ out = self.RDB2(out)
327
+ out = self.RDB3(out)
328
+ return out * 0.2 + x
329
+
330
+
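As the forward above shows, an RRDB is three dense blocks applied in sequence with one more scaled skip on the outside, i.e. x + 0.2 * RDB3(RDB2(RDB1(x))). A shape-preserving sanity check (same hypothetical import path as above):

import torch

from comfy_extras.chainner_models.architecture.block import RRDB

rrdb = RRDB(nf=64, gc=32)
x = torch.randn(1, 64, 24, 24)
print(rrdb(x).shape)  # torch.Size([1, 64, 24, 24]) -- channel count and size preserved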
331
+ class ResidualDenseBlock_5C(nn.Module):
332
+ """
333
+ Residual Dense Block
334
+ style: 5 convs
335
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
336
+ Modified options that can be used:
337
+ - "Partial Convolution based Padding" arXiv:1811.11718
338
+ - "Spectral normalization" arXiv:1802.05957
339
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
340
+ {Rakotonirina} and A. {Rasoanaivo}
341
+
342
+ Args:
343
+ nf (int): Channel number of intermediate features (num_feat).
344
+ gc (int): Channels for each growth (num_grow_ch: growth channel,
345
+ i.e. intermediate channels).
346
+ convtype (str): the type of convolution to use. Default: 'Conv2D'
347
+ gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
348
+ trainable parameters)
349
+ plus (bool): enable the additional residual paths from ESRGAN+
350
+ (adds trainable parameters)
351
+ """
352
+
353
+ def __init__(
354
+ self,
355
+ nf=64,
356
+ kernel_size=3,
357
+ gc=32,
358
+ stride=1,
359
+ bias: bool = True,
360
+ pad_type="zero",
361
+ norm_type=None,
362
+ act_type="leakyrelu",
363
+ mode: ConvMode = "CNA",
364
+ plus=False,
365
+ ):
366
+ super(ResidualDenseBlock_5C, self).__init__()
367
+
368
+ ## +
369
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
370
+ ## +
371
+
372
+ self.conv1 = conv_block(
373
+ nf,
374
+ gc,
375
+ kernel_size,
376
+ stride,
377
+ bias=bias,
378
+ pad_type=pad_type,
379
+ norm_type=norm_type,
380
+ act_type=act_type,
381
+ mode=mode,
382
+ )
383
+ self.conv2 = conv_block(
384
+ nf + gc,
385
+ gc,
386
+ kernel_size,
387
+ stride,
388
+ bias=bias,
389
+ pad_type=pad_type,
390
+ norm_type=norm_type,
391
+ act_type=act_type,
392
+ mode=mode,
393
+ )
394
+ self.conv3 = conv_block(
395
+ nf + 2 * gc,
396
+ gc,
397
+ kernel_size,
398
+ stride,
399
+ bias=bias,
400
+ pad_type=pad_type,
401
+ norm_type=norm_type,
402
+ act_type=act_type,
403
+ mode=mode,
404
+ )
405
+ self.conv4 = conv_block(
406
+ nf + 3 * gc,
407
+ gc,
408
+ kernel_size,
409
+ stride,
410
+ bias=bias,
411
+ pad_type=pad_type,
412
+ norm_type=norm_type,
413
+ act_type=act_type,
414
+ mode=mode,
415
+ )
416
+ if mode == "CNA":
417
+ last_act = None
418
+ else:
419
+ last_act = act_type
420
+ self.conv5 = conv_block(
421
+ nf + 4 * gc,
422
+ nf,
423
+ 3,
424
+ stride,
425
+ bias=bias,
426
+ pad_type=pad_type,
427
+ norm_type=norm_type,
428
+ act_type=last_act,
429
+ mode=mode,
430
+ )
431
+
432
+ def forward(self, x):
433
+ x1 = self.conv1(x)
434
+ x2 = self.conv2(torch.cat((x, x1), 1))
435
+ if self.conv1x1:
436
+ # pylint: disable=not-callable
437
+ x2 = x2 + self.conv1x1(x) # +
438
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
439
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
440
+ if self.conv1x1:
441
+ x4 = x4 + x2 # +
442
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
443
+ return x5 * 0.2 + x
444
+
445
+
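The dense wiring above grows the input of each conv by gc channels per step and then projects back to nf. The bookkeeping, spelled out for the default nf=64, gc=32:

# Each conv sees the block input concatenated with every earlier growth output.
nf, gc = 64, 32
for k in range(4):
    print(f"conv{k + 1} in_channels = {nf + k * gc}")  # 64, 96, 128, 160
print(f"conv5 in_channels = {nf + 4 * gc}, out_channels = {nf}")  # 192 -> 64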
446
+ def conv1x1(in_planes, out_planes, stride=1):
447
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
448
+
449
+
450
+ ####################
451
+ # Upsampler
452
+ ####################
453
+
454
+
455
+ def pixelshuffle_block(
456
+ in_nc: int,
457
+ out_nc: int,
458
+ upscale_factor=2,
459
+ kernel_size=3,
460
+ stride=1,
461
+ bias=True,
462
+ pad_type="zero",
463
+ norm_type: str | None = None,
464
+ act_type="relu",
465
+ ):
466
+ """
467
+ Pixel shuffle layer
468
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
469
+ Neural Network, CVPR17)
470
+ """
471
+ conv = conv_block(
472
+ in_nc,
473
+ out_nc * (upscale_factor**2),
474
+ kernel_size,
475
+ stride,
476
+ bias=bias,
477
+ pad_type=pad_type,
478
+ norm_type=None,
479
+ act_type=None,
480
+ )
481
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
482
+
483
+ n = norm(norm_type, out_nc) if norm_type else None
484
+ a = act(act_type) if act_type else None
485
+ return sequential(conv, pixel_shuffle, n, a)
486
+
487
+
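The conv in pixelshuffle_block produces out_nc * upscale_factor**2 channels precisely so that nn.PixelShuffle can fold the extra channel factor into space. A minimal illustration with plain PyTorch:

import torch
import torch.nn as nn

r = 2
x = torch.randn(1, 3 * r * r, 16, 16)  # what the conv would emit for out_nc=3
print(nn.PixelShuffle(r)(x).shape)     # torch.Size([1, 3, 32, 32]): (C*r*r, H, W) -> (C, r*H, r*W)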
488
+ def upconv_block(
489
+ in_nc: int,
490
+ out_nc: int,
491
+ upscale_factor=2,
492
+ kernel_size=3,
493
+ stride=1,
494
+ bias=True,
495
+ pad_type="zero",
496
+ norm_type: str | None = None,
497
+ act_type="relu",
498
+ mode="nearest",
499
+ ):
500
+ # Up conv
501
+ # described in https://distill.pub/2016/deconv-checkerboard/
502
+ upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
503
+ conv = conv_block(
504
+ in_nc,
505
+ out_nc,
506
+ kernel_size,
507
+ stride,
508
+ bias=bias,
509
+ pad_type=pad_type,
510
+ norm_type=norm_type,
511
+ act_type=act_type,
512
+ )
513
+ return sequential(upsample, conv)
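upconv_block is the resize-convolution alternative to transposed convolution described in the distill.pub article linked above: upsample first (nearest by default), then a stride-1 conv, which avoids checkerboard artifacts. A sketch of the same composition using plain torch.nn modules:

import torch
import torch.nn as nn

up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
)
print(up(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 32, 32, 32])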
comfy_extras/chainner_models/architecture/face/LICENSE-GFPGAN ADDED
@@ -0,0 +1,351 @@
1
+ Tencent is pleased to support the open source community by making GFPGAN available.
2
+
3
+ Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
+
5
+ GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
6
+
7
+
8
+ Terms of the Apache License Version 2.0:
9
+ ---------------------------------------------
10
+ Apache License
11
+
12
+ Version 2.0, January 2004
13
+
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+ 1. Definitions.
18
+
19
+ “License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
20
+
21
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
22
+
23
+ “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
28
+
29
+ “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
30
+
31
+ “Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
32
+
33
+ “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
34
+
35
+ “Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
36
+
37
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
40
+
41
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
42
+
43
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44
+
45
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
46
+
47
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
48
+
49
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
50
+
51
+ If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
52
+
53
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
54
+
55
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
56
+
57
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
58
+
59
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
60
+
61
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62
+
63
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
64
+
65
+ END OF TERMS AND CONDITIONS
66
+
67
+
68
+
69
+ Other dependencies and licenses:
70
+
71
+
72
+ Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
73
+ ---------------------------------------------
74
+ 1. basicsr
75
+ Copyright 2018-2020 BasicSR Authors
76
+
77
+
78
+ This BasicSR project is released under the Apache 2.0 license.
79
+
80
+ A copy of Apache 2.0 is included in this file.
81
+
82
+ StyleGAN2
83
+ The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
84
+ The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
85
+ DFDNet
86
+ The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
87
+
88
+ Terms of the Nvidia License:
89
+ ---------------------------------------------
90
+
91
+ 1. Definitions
92
+
93
+ "Licensor" means any person or entity that distributes its Work.
94
+
95
+ "Software" means the original work of authorship made available under
96
+ this License.
97
+
98
+ "Work" means the Software and any additions to or derivative works of
99
+ the Software that are made available under this License.
100
+
101
+ "Nvidia Processors" means any central processing unit (CPU), graphics
102
+ processing unit (GPU), field-programmable gate array (FPGA),
103
+ application-specific integrated circuit (ASIC) or any combination
104
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
105
+
106
+ The terms "reproduce," "reproduction," "derivative works," and
107
+ "distribution" have the meaning as provided under U.S. copyright law;
108
+ provided, however, that for the purposes of this License, derivative
109
+ works shall not include works that remain separable from, or merely
110
+ link (or bind by name) to the interfaces of, the Work.
111
+
112
+ Works, including the Software, are "made available" under this License
113
+ by including in or with the Work either (a) a copyright notice
114
+ referencing the applicability of this License to the Work, or (b) a
115
+ copy of this License.
116
+
117
+ 2. License Grants
118
+
119
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
120
+ License, each Licensor grants to you a perpetual, worldwide,
121
+ non-exclusive, royalty-free, copyright license to reproduce,
122
+ prepare derivative works of, publicly display, publicly perform,
123
+ sublicense and distribute its Work and any resulting derivative
124
+ works in any form.
125
+
126
+ 3. Limitations
127
+
128
+ 3.1 Redistribution. You may reproduce or distribute the Work only
129
+ if (a) you do so under this License, (b) you include a complete
130
+ copy of this License with your distribution, and (c) you retain
131
+ without modification any copyright, patent, trademark, or
132
+ attribution notices that are present in the Work.
133
+
134
+ 3.2 Derivative Works. You may specify that additional or different
135
+ terms apply to the use, reproduction, and distribution of your
136
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
137
+ provide that the use limitation in Section 3.3 applies to your
138
+ derivative works, and (b) you identify the specific derivative
139
+ works that are subject to Your Terms. Notwithstanding Your Terms,
140
+ this License (including the redistribution requirements in Section
141
+ 3.1) will continue to apply to the Work itself.
142
+
143
+ 3.3 Use Limitation. The Work and any derivative works thereof only
144
+ may be used or intended for use non-commercially. The Work or
145
+ derivative works thereof may be used or intended for use by Nvidia
146
+ or its affiliates commercially or non-commercially. As used herein,
147
+ "non-commercially" means for research or evaluation purposes only.
148
+
149
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
150
+ against any Licensor (including any claim, cross-claim or
151
+ counterclaim in a lawsuit) to enforce any patents that you allege
152
+ are infringed by any Work, then your rights under this License from
153
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
154
+ terminate immediately.
155
+
156
+ 3.5 Trademarks. This License does not grant any rights to use any
157
+ Licensor's or its affiliates' names, logos, or trademarks, except
158
+ as necessary to reproduce the notices described in this License.
159
+
160
+ 3.6 Termination. If you violate any term of this License, then your
161
+ rights under this License (including the grants in Sections 2.1 and
162
+ 2.2) will terminate immediately.
163
+
164
+ 4. Disclaimer of Warranty.
165
+
166
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
167
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
168
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
169
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
170
+ THIS LICENSE.
171
+
172
+ 5. Limitation of Liability.
173
+
174
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
175
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
176
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
177
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
178
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
179
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
180
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
181
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
182
+ THE POSSIBILITY OF SUCH DAMAGES.
183
+
184
+ MIT License
185
+
186
+ Copyright (c) 2019 Kim Seonghyeon
187
+
188
+ Permission is hereby granted, free of charge, to any person obtaining a copy
189
+ of this software and associated documentation files (the "Software"), to deal
190
+ in the Software without restriction, including without limitation the rights
191
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
192
+ copies of the Software, and to permit persons to whom the Software is
193
+ furnished to do so, subject to the following conditions:
194
+
195
+ The above copyright notice and this permission notice shall be included in all
196
+ copies or substantial portions of the Software.
197
+
198
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
199
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
200
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
201
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
202
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
203
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
204
+ SOFTWARE.
205
+
206
+
207
+
208
+ Open Source Software licensed under the BSD 3-Clause license:
209
+ ---------------------------------------------
210
+ 1. torchvision
211
+ Copyright (c) Soumith Chintala 2016,
212
+ All rights reserved.
213
+
214
+ 2. torch
215
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
216
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
217
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
218
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
219
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
220
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
221
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
222
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
223
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
224
+
225
+
226
+ Terms of the BSD 3-Clause License:
227
+ ---------------------------------------------
228
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
229
+
230
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
231
+
232
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
233
+
234
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
235
+
236
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
237
+
238
+
239
+
240
+ Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
241
+ ---------------------------------------------
242
+ 1. numpy
243
+ Copyright (c) 2005-2020, NumPy Developers.
244
+ All rights reserved.
245
+
246
+ A copy of BSD 3-Clause License is included in this file.
247
+
248
+ The NumPy repository and source distributions bundle several libraries that are
249
+ compatibly licensed. We list these here.
250
+
251
+ Name: Numpydoc
252
+ Files: doc/sphinxext/numpydoc/*
253
+ License: BSD-2-Clause
254
+ For details, see doc/sphinxext/LICENSE.txt
255
+
256
+ Name: scipy-sphinx-theme
257
+ Files: doc/scipy-sphinx-theme/*
258
+ License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
259
+ For details, see doc/scipy-sphinx-theme/LICENSE.txt
260
+
261
+ Name: lapack-lite
262
+ Files: numpy/linalg/lapack_lite/*
263
+ License: BSD-3-Clause
264
+ For details, see numpy/linalg/lapack_lite/LICENSE.txt
265
+
266
+ Name: tempita
267
+ Files: tools/npy_tempita/*
268
+ License: MIT
269
+ For details, see tools/npy_tempita/license.txt
270
+
271
+ Name: dragon4
272
+ Files: numpy/core/src/multiarray/dragon4.c
273
+ License: MIT
274
+ For license text, see numpy/core/src/multiarray/dragon4.c
275
+
276
+
277
+
278
+ Open Source Software licensed under the MIT license:
279
+ ---------------------------------------------
280
+ 1. facexlib
281
+ Copyright (c) 2020 Xintao Wang
282
+
283
+ 2. opencv-python
284
+ Copyright (c) Olli-Pekka Heinisuo
285
+ Please note that only files in cv2 package are used.
286
+
287
+
288
+ Terms of the MIT License:
289
+ ---------------------------------------------
290
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
291
+
292
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
293
+
294
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
295
+
296
+
297
+
298
+ Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
299
+ ---------------------------------------------
300
+ 1. tqdm
301
+ Copyright (c) 2013 noamraph
302
+
303
+ `tqdm` is a product of collaborative work.
304
+ Unless otherwise stated, all authors (see commit logs) retain copyright
305
+ for their respective work, and release the work under the MIT licence
306
+ (text below).
307
+
308
+ Exceptions or notable authors are listed below
309
+ in reverse chronological order:
310
+
311
+ * files: *
312
+ MPLv2.0 2015-2020 (c) Casper da Costa-Luis
313
+ [casperdcl](https://github.com/casperdcl).
314
+ * files: tqdm/_tqdm.py
315
+ MIT 2016 (c) [PR #96] on behalf of Google Inc.
316
+ * files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
317
+ MIT 2013 (c) Noam Yorav-Raphael, original author.
318
+
319
+ [PR #96]: https://github.com/tqdm/tqdm/pull/96
320
+
321
+
322
+ Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
323
+ -----------------------------------------------
324
+
325
+ This Source Code Form is subject to the terms of the
326
+ Mozilla Public License, v. 2.0.
327
+ If a copy of the MPL was not distributed with this file,
328
+ You can obtain one at https://mozilla.org/MPL/2.0/.
329
+
330
+
331
+ MIT License (MIT)
332
+ -----------------
333
+
334
+ Copyright (c) 2013 noamraph
335
+
336
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
337
+ this software and associated documentation files (the "Software"), to deal in
338
+ the Software without restriction, including without limitation the rights to
339
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
340
+ the Software, and to permit persons to whom the Software is furnished to do so,
341
+ subject to the following conditions:
342
+
343
+ The above copyright notice and this permission notice shall be included in all
344
+ copies or substantial portions of the Software.
345
+
346
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
347
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
348
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
349
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
350
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
351
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
comfy_extras/chainner_models/architecture/face/LICENSE-RestoreFormer ADDED
@@ -0,0 +1,351 @@
1
+ Tencent is pleased to support the open source community by making GFPGAN available.
2
+
3
+ Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
+
5
+ GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
6
+
7
+
8
+ Terms of the Apache License Version 2.0:
9
+ ---------------------------------------------
10
+ Apache License
11
+
12
+ Version 2.0, January 2004
13
+
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+ 1. Definitions.
18
+
19
+ “License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
20
+
21
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
22
+
23
+ “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
28
+
29
+ “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
30
+
31
+ “Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
32
+
33
+ “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
34
+
35
+ “Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
36
+
37
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
40
+
41
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
42
+
43
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44
+
45
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
46
+
47
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
48
+
49
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
50
+
51
+ If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
52
+
53
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
54
+
55
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
56
+
57
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
58
+
59
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
60
+
61
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62
+
63
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
64
+
65
+ END OF TERMS AND CONDITIONS
66
+
67
+
68
+
69
+ Other dependencies and licenses:
70
+
71
+
72
+ Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
73
+ ---------------------------------------------
74
+ 1. basicsr
75
+ Copyright 2018-2020 BasicSR Authors
76
+
77
+
78
+ This BasicSR project is released under the Apache 2.0 license.
79
+
80
+ A copy of Apache 2.0 is included in this file.
81
+
82
+ StyleGAN2
83
+ The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
84
+ The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
85
+ DFDNet
86
+ The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
87
+
88
+ Terms of the Nvidia License:
89
+ ---------------------------------------------
90
+
91
+ 1. Definitions
92
+
93
+ "Licensor" means any person or entity that distributes its Work.
94
+
95
+ "Software" means the original work of authorship made available under
96
+ this License.
97
+
98
+ "Work" means the Software and any additions to or derivative works of
99
+ the Software that are made available under this License.
100
+
101
+ "Nvidia Processors" means any central processing unit (CPU), graphics
102
+ processing unit (GPU), field-programmable gate array (FPGA),
103
+ application-specific integrated circuit (ASIC) or any combination
104
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
105
+
106
+ The terms "reproduce," "reproduction," "derivative works," and
107
+ "distribution" have the meaning as provided under U.S. copyright law;
108
+ provided, however, that for the purposes of this License, derivative
109
+ works shall not include works that remain separable from, or merely
110
+ link (or bind by name) to the interfaces of, the Work.
111
+
112
+ Works, including the Software, are "made available" under this License
113
+ by including in or with the Work either (a) a copyright notice
114
+ referencing the applicability of this License to the Work, or (b) a
115
+ copy of this License.
116
+
117
+ 2. License Grants
118
+
119
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
120
+ License, each Licensor grants to you a perpetual, worldwide,
121
+ non-exclusive, royalty-free, copyright license to reproduce,
122
+ prepare derivative works of, publicly display, publicly perform,
123
+ sublicense and distribute its Work and any resulting derivative
124
+ works in any form.
125
+
126
+ 3. Limitations
127
+
128
+ 3.1 Redistribution. You may reproduce or distribute the Work only
129
+ if (a) you do so under this License, (b) you include a complete
130
+ copy of this License with your distribution, and (c) you retain
131
+ without modification any copyright, patent, trademark, or
132
+ attribution notices that are present in the Work.
133
+
134
+ 3.2 Derivative Works. You may specify that additional or different
135
+ terms apply to the use, reproduction, and distribution of your
136
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
137
+ provide that the use limitation in Section 3.3 applies to your
138
+ derivative works, and (b) you identify the specific derivative
139
+ works that are subject to Your Terms. Notwithstanding Your Terms,
140
+ this License (including the redistribution requirements in Section
141
+ 3.1) will continue to apply to the Work itself.
142
+
143
+ 3.3 Use Limitation. The Work and any derivative works thereof only
144
+ may be used or intended for use non-commercially. The Work or
145
+ derivative works thereof may be used or intended for use by Nvidia
146
+ or its affiliates commercially or non-commercially. As used herein,
147
+ "non-commercially" means for research or evaluation purposes only.
148
+
149
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
150
+ against any Licensor (including any claim, cross-claim or
151
+ counterclaim in a lawsuit) to enforce any patents that you allege
152
+ are infringed by any Work, then your rights under this License from
153
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
154
+ terminate immediately.
155
+
156
+ 3.5 Trademarks. This License does not grant any rights to use any
157
+ Licensor's or its affiliates' names, logos, or trademarks, except
158
+ as necessary to reproduce the notices described in this License.
159
+
160
+ 3.6 Termination. If you violate any term of this License, then your
161
+ rights under this License (including the grants in Sections 2.1 and
162
+ 2.2) will terminate immediately.
163
+
164
+ 4. Disclaimer of Warranty.
165
+
166
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
167
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
168
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
169
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
170
+ THIS LICENSE.
171
+
172
+ 5. Limitation of Liability.
173
+
174
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
175
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
176
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
177
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
178
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
179
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
180
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
181
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
182
+ THE POSSIBILITY OF SUCH DAMAGES.
183
+
184
+ MIT License
185
+
186
+ Copyright (c) 2019 Kim Seonghyeon
187
+
188
+ Permission is hereby granted, free of charge, to any person obtaining a copy
189
+ of this software and associated documentation files (the "Software"), to deal
190
+ in the Software without restriction, including without limitation the rights
191
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
192
+ copies of the Software, and to permit persons to whom the Software is
193
+ furnished to do so, subject to the following conditions:
194
+
195
+ The above copyright notice and this permission notice shall be included in all
196
+ copies or substantial portions of the Software.
197
+
198
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
199
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
200
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
201
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
202
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
203
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
204
+ SOFTWARE.
205
+
206
+
207
+
208
+ Open Source Software licensed under the BSD 3-Clause license:
209
+ ---------------------------------------------
210
+ 1. torchvision
211
+ Copyright (c) Soumith Chintala 2016,
212
+ All rights reserved.
213
+
214
+ 2. torch
215
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
216
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
217
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
218
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
219
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
220
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
221
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
222
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
223
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
224
+
225
+
226
+ Terms of the BSD 3-Clause License:
227
+ ---------------------------------------------
228
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
229
+
230
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
231
+
232
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
233
+
234
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
235
+
236
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
237
+
238
+
239
+
240
+ Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
241
+ ---------------------------------------------
242
+ 1. numpy
243
+ Copyright (c) 2005-2020, NumPy Developers.
244
+ All rights reserved.
245
+
246
+ A copy of BSD 3-Clause License is included in this file.
247
+
248
+ The NumPy repository and source distributions bundle several libraries that are
249
+ compatibly licensed. We list these here.
250
+
251
+ Name: Numpydoc
252
+ Files: doc/sphinxext/numpydoc/*
253
+ License: BSD-2-Clause
254
+ For details, see doc/sphinxext/LICENSE.txt
255
+
256
+ Name: scipy-sphinx-theme
257
+ Files: doc/scipy-sphinx-theme/*
258
+ License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
259
+ For details, see doc/scipy-sphinx-theme/LICENSE.txt
260
+
261
+ Name: lapack-lite
262
+ Files: numpy/linalg/lapack_lite/*
263
+ License: BSD-3-Clause
264
+ For details, see numpy/linalg/lapack_lite/LICENSE.txt
265
+
266
+ Name: tempita
267
+ Files: tools/npy_tempita/*
268
+ License: MIT
269
+ For details, see tools/npy_tempita/license.txt
270
+
271
+ Name: dragon4
272
+ Files: numpy/core/src/multiarray/dragon4.c
273
+ License: MIT
274
+ For license text, see numpy/core/src/multiarray/dragon4.c
275
+
276
+
277
+
278
+ Open Source Software licensed under the MIT license:
279
+ ---------------------------------------------
280
+ 1. facexlib
281
+ Copyright (c) 2020 Xintao Wang
282
+
283
+ 2. opencv-python
284
+ Copyright (c) Olli-Pekka Heinisuo
285
+ Please note that only files in cv2 package are used.
286
+
287
+
288
+ Terms of the MIT License:
289
+ ---------------------------------------------
290
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
291
+
292
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
293
+
294
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
295
+
296
+
297
+
298
+ Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
299
+ ---------------------------------------------
300
+ 1. tqdm
301
+ Copyright (c) 2013 noamraph
302
+
303
+ `tqdm` is a product of collaborative work.
304
+ Unless otherwise stated, all authors (see commit logs) retain copyright
305
+ for their respective work, and release the work under the MIT licence
306
+ (text below).
307
+
308
+ Exceptions or notable authors are listed below
309
+ in reverse chronological order:
310
+
311
+ * files: *
312
+ MPLv2.0 2015-2020 (c) Casper da Costa-Luis
313
+ [casperdcl](https://github.com/casperdcl).
314
+ * files: tqdm/_tqdm.py
315
+ MIT 2016 (c) [PR #96] on behalf of Google Inc.
316
+ * files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
317
+ MIT 2013 (c) Noam Yorav-Raphael, original author.
318
+
319
+ [PR #96]: https://github.com/tqdm/tqdm/pull/96
320
+
321
+
322
+ Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
323
+ -----------------------------------------------
324
+
325
+ This Source Code Form is subject to the terms of the
326
+ Mozilla Public License, v. 2.0.
327
+ If a copy of the MPL was not distributed with this file,
328
+ You can obtain one at https://mozilla.org/MPL/2.0/.
329
+
330
+
331
+ MIT License (MIT)
332
+ -----------------
333
+
334
+ Copyright (c) 2013 noamraph
335
+
336
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
337
+ this software and associated documentation files (the "Software"), to deal in
338
+ the Software without restriction, including without limitation the rights to
339
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
340
+ the Software, and to permit persons to whom the Software is furnished to do so,
341
+ subject to the following conditions:
342
+
343
+ The above copyright notice and this permission notice shall be included in all
344
+ copies or substantial portions of the Software.
345
+
346
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
347
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
348
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
349
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
350
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
351
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
comfy_extras/chainner_models/architecture/face/LICENSE-codeformer ADDED
@@ -0,0 +1,35 @@
1
+ S-Lab License 1.0
2
+
3
+ Copyright 2022 S-Lab
4
+
5
+ Redistribution and use for non-commercial purpose in source and
6
+ binary forms, with or without modification, are permitted provided
7
+ that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright
10
+ notice, this list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright
13
+ notice, this list of conditions and the following disclaimer in
14
+ the documentation and/or other materials provided with the
15
+ distribution.
16
+
17
+ 3. Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived
19
+ from this software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+
33
+ In the event that redistribution and/or use for commercial purpose in
34
+ source or binary forms, with or without modification is required,
35
+ please contact the contributor(s) of the work.
comfy_extras/chainner_models/architecture/face/arcface_arch.py ADDED
@@ -0,0 +1,265 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def conv3x3(inplanes, outplanes, stride=1):
5
+ """A simple wrapper for 3x3 convolution with padding.
6
+
7
+ Args:
8
+ inplanes (int): Channel number of inputs.
9
+ outplanes (int): Channel number of outputs.
10
+ stride (int): Stride in convolution. Default: 1.
11
+ """
12
+ return nn.Conv2d(
13
+ inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
14
+ )
15
+
16
+
17
+ class BasicBlock(nn.Module):
18
+ """Basic residual block used in the ResNetArcFace architecture.
19
+
20
+ Args:
21
+ inplanes (int): Channel number of inputs.
22
+ planes (int): Channel number of outputs.
23
+ stride (int): Stride in convolution. Default: 1.
24
+ downsample (nn.Module): The downsample module. Default: None.
25
+ """
26
+
27
+ expansion = 1 # output channel expansion ratio
28
+
29
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
30
+ super(BasicBlock, self).__init__()
31
+ self.conv1 = conv3x3(inplanes, planes, stride)
32
+ self.bn1 = nn.BatchNorm2d(planes)
33
+ self.relu = nn.ReLU(inplace=True)
34
+ self.conv2 = conv3x3(planes, planes)
35
+ self.bn2 = nn.BatchNorm2d(planes)
36
+ self.downsample = downsample
37
+ self.stride = stride
38
+
39
+ def forward(self, x):
40
+ residual = x
41
+
42
+ out = self.conv1(x)
43
+ out = self.bn1(out)
44
+ out = self.relu(out)
45
+
46
+ out = self.conv2(out)
47
+ out = self.bn2(out)
48
+
49
+ if self.downsample is not None:
50
+ residual = self.downsample(x)
51
+
52
+ out += residual
53
+ out = self.relu(out)
54
+
55
+ return out
56
+
57
+
58
+ class IRBlock(nn.Module):
59
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
60
+
61
+ Args:
62
+ inplanes (int): Channel number of inputs.
63
+ planes (int): Channel number of outputs.
64
+ stride (int): Stride in convolution. Default: 1.
65
+ downsample (nn.Module): The downsample module. Default: None.
66
+ use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
67
+ """
68
+
69
+ expansion = 1 # output channel expansion ratio
70
+
71
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
72
+ super(IRBlock, self).__init__()
73
+ self.bn0 = nn.BatchNorm2d(inplanes)
74
+ self.conv1 = conv3x3(inplanes, inplanes)
75
+ self.bn1 = nn.BatchNorm2d(inplanes)
76
+ self.prelu = nn.PReLU()
77
+ self.conv2 = conv3x3(inplanes, planes, stride)
78
+ self.bn2 = nn.BatchNorm2d(planes)
79
+ self.downsample = downsample
80
+ self.stride = stride
81
+ self.use_se = use_se
82
+ if self.use_se:
83
+ self.se = SEBlock(planes)
84
+
85
+ def forward(self, x):
86
+ residual = x
87
+ out = self.bn0(x)
88
+ out = self.conv1(out)
89
+ out = self.bn1(out)
90
+ out = self.prelu(out)
91
+
92
+ out = self.conv2(out)
93
+ out = self.bn2(out)
94
+ if self.use_se:
95
+ out = self.se(out)
96
+
97
+ if self.downsample is not None:
98
+ residual = self.downsample(x)
99
+
100
+ out += residual
101
+ out = self.prelu(out)
102
+
103
+ return out
104
+
105
+
106
+ class Bottleneck(nn.Module):
107
+ """Bottleneck block used in the ResNetArcFace architecture.
108
+
109
+ Args:
110
+ inplanes (int): Channel number of inputs.
111
+ planes (int): Channel number of outputs.
112
+ stride (int): Stride in convolution. Default: 1.
113
+ downsample (nn.Module): The downsample module. Default: None.
114
+ """
115
+
116
+ expansion = 4 # output channel expansion ratio
117
+
118
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
119
+ super(Bottleneck, self).__init__()
120
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
121
+ self.bn1 = nn.BatchNorm2d(planes)
122
+ self.conv2 = nn.Conv2d(
123
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
124
+ )
125
+ self.bn2 = nn.BatchNorm2d(planes)
126
+ self.conv3 = nn.Conv2d(
127
+ planes, planes * self.expansion, kernel_size=1, bias=False
128
+ )
129
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
130
+ self.relu = nn.ReLU(inplace=True)
131
+ self.downsample = downsample
132
+ self.stride = stride
133
+
134
+ def forward(self, x):
135
+ residual = x
136
+
137
+ out = self.conv1(x)
138
+ out = self.bn1(out)
139
+ out = self.relu(out)
140
+
141
+ out = self.conv2(out)
142
+ out = self.bn2(out)
143
+ out = self.relu(out)
144
+
145
+ out = self.conv3(out)
146
+ out = self.bn3(out)
147
+
148
+ if self.downsample is not None:
149
+ residual = self.downsample(x)
150
+
151
+ out += residual
152
+ out = self.relu(out)
153
+
154
+ return out
155
+
156
+
157
+ class SEBlock(nn.Module):
158
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
159
+
160
+ Args:
161
+ channel (int): Channel number of inputs.
162
+ reduction (int): Channel reduction ratio. Default: 16.
163
+ """
164
+
165
+ def __init__(self, channel, reduction=16):
166
+ super(SEBlock, self).__init__()
167
+ self.avg_pool = nn.AdaptiveAvgPool2d(
168
+ 1
169
+ ) # pool to 1x1 without spatial information
170
+ self.fc = nn.Sequential(
171
+ nn.Linear(channel, channel // reduction),
172
+ nn.PReLU(),
173
+ nn.Linear(channel // reduction, channel),
174
+ nn.Sigmoid(),
175
+ )
176
+
177
+ def forward(self, x):
178
+ b, c, _, _ = x.size()
179
+ y = self.avg_pool(x).view(b, c)
180
+ y = self.fc(y).view(b, c, 1, 1)
181
+ return x * y
182
+
183
+
184
+ class ResNetArcFace(nn.Module):
185
+ """ArcFace with ResNet architectures.
186
+
187
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
188
+
189
+ Args:
190
+ block (str): Block used in the ArcFace architecture.
191
+ layers (tuple(int)): Block numbers in each layer.
192
+ use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
193
+ """
194
+
195
+ def __init__(self, block, layers, use_se=True):
196
+ if block == "IRBlock":
197
+ block = IRBlock
198
+ self.inplanes = 64
199
+ self.use_se = use_se
200
+ super(ResNetArcFace, self).__init__()
201
+
202
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
203
+ self.bn1 = nn.BatchNorm2d(64)
204
+ self.prelu = nn.PReLU()
205
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
206
+ self.layer1 = self._make_layer(block, 64, layers[0])
207
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
208
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
209
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
210
+ self.bn4 = nn.BatchNorm2d(512)
211
+ self.dropout = nn.Dropout()
212
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
213
+ self.bn5 = nn.BatchNorm1d(512)
214
+
215
+ # initialization
216
+ for m in self.modules():
217
+ if isinstance(m, nn.Conv2d):
218
+ nn.init.xavier_normal_(m.weight)
219
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
220
+ nn.init.constant_(m.weight, 1)
221
+ nn.init.constant_(m.bias, 0)
222
+ elif isinstance(m, nn.Linear):
223
+ nn.init.xavier_normal_(m.weight)
224
+ nn.init.constant_(m.bias, 0)
225
+
226
+ def _make_layer(self, block, planes, num_blocks, stride=1):
227
+ downsample = None
228
+ if stride != 1 or self.inplanes != planes * block.expansion:
229
+ downsample = nn.Sequential(
230
+ nn.Conv2d(
231
+ self.inplanes,
232
+ planes * block.expansion,
233
+ kernel_size=1,
234
+ stride=stride,
235
+ bias=False,
236
+ ),
237
+ nn.BatchNorm2d(planes * block.expansion),
238
+ )
239
+ layers = []
240
+ layers.append(
241
+ block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
242
+ )
243
+ self.inplanes = planes
244
+ for _ in range(1, num_blocks):
245
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
246
+
247
+ return nn.Sequential(*layers)
248
+
249
+ def forward(self, x):
250
+ x = self.conv1(x)
251
+ x = self.bn1(x)
252
+ x = self.prelu(x)
253
+ x = self.maxpool(x)
254
+
255
+ x = self.layer1(x)
256
+ x = self.layer2(x)
257
+ x = self.layer3(x)
258
+ x = self.layer4(x)
259
+ x = self.bn4(x)
260
+ x = self.dropout(x)
261
+ x = x.view(x.size(0), -1)
262
+ x = self.fc5(x)
263
+ x = self.bn5(x)
264
+
265
+ return x
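For orientation, here is a minimal, hedged usage sketch of the ResNetArcFace class defined above. The IRBlock/(2, 2, 2, 2) layer configuration and the 128x128 grayscale input are assumptions (they match how GFPGAN-style identity models usually drive ArcFace); the 128x128 size is implied by fc5 = nn.Linear(512 * 8 * 8, 512), since the four strided stages reduce 128 down to 8.

import torch
# assuming the repo root is on sys.path; otherwise import ResNetArcFace from wherever this file lives
from comfy_extras.chainner_models.architecture.face.arcface_arch import ResNetArcFace

# assumed ResNet18-style block counts; block="IRBlock" selects the improved residual block
net = ResNetArcFace(block="IRBlock", layers=(2, 2, 2, 2), use_se=False).eval()

with torch.no_grad():
    face = torch.randn(1, 1, 128, 128)  # single-channel (grayscale) aligned face crop
    embedding = net(face)               # 512-d identity embedding

print(embedding.shape)  # torch.Size([1, 512])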
comfy_extras/chainner_models/architecture/face/codeformer.py ADDED
@@ -0,0 +1,790 @@
1
+ """
2
+ Modified from https://github.com/sczhou/CodeFormer
3
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
4
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
5
+ This specific version of the arch was taken from an old version of GFPGAN. If this is a problem, please contact me.
6
+ """
7
+ import math
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import logging as logger
14
+ from torch import Tensor
15
+
16
+
17
+ class VectorQuantizer(nn.Module):
18
+ def __init__(self, codebook_size, emb_dim, beta):
19
+ super(VectorQuantizer, self).__init__()
20
+ self.codebook_size = codebook_size # number of embeddings
21
+ self.emb_dim = emb_dim # dimension of embedding
22
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
23
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
24
+ self.embedding.weight.data.uniform_(
25
+ -1.0 / self.codebook_size, 1.0 / self.codebook_size
26
+ )
27
+
28
+ def forward(self, z):
29
+ # reshape z -> (batch, height, width, channel) and flatten
30
+ z = z.permute(0, 2, 3, 1).contiguous()
31
+ z_flattened = z.view(-1, self.emb_dim)
32
+
33
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
34
+ d = (
35
+ (z_flattened**2).sum(dim=1, keepdim=True)
36
+ + (self.embedding.weight**2).sum(1)
37
+ - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
38
+ )
39
+
40
+ mean_distance = torch.mean(d)
41
+ # find closest encodings
42
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
43
+ min_encoding_scores, min_encoding_indices = torch.topk(
44
+ d, 1, dim=1, largest=False
45
+ )
46
+ # [0-1], higher score, higher confidence
47
+ min_encoding_scores = torch.exp(-min_encoding_scores / 10)
48
+
49
+ min_encodings = torch.zeros(
50
+ min_encoding_indices.shape[0], self.codebook_size
51
+ ).to(z)
52
+ min_encodings.scatter_(1, min_encoding_indices, 1)
53
+
54
+ # get quantized latent vectors
55
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
56
+ # compute loss for embedding
57
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
58
+ (z_q - z.detach()) ** 2
59
+ )
60
+ # preserve gradients
61
+ z_q = z + (z_q - z).detach()
62
+
63
+ # perplexity
64
+ e_mean = torch.mean(min_encodings, dim=0)
65
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
66
+ # reshape back to match original input shape
67
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
68
+
69
+ return (
70
+ z_q,
71
+ loss,
72
+ {
73
+ "perplexity": perplexity,
74
+ "min_encodings": min_encodings,
75
+ "min_encoding_indices": min_encoding_indices,
76
+ "min_encoding_scores": min_encoding_scores,
77
+ "mean_distance": mean_distance,
78
+ },
79
+ )
80
+
81
+ def get_codebook_feat(self, indices, shape):
82
+ # input indices: batch*token_num -> (batch*token_num)*1
83
+ # shape: batch, height, width, channel
84
+ indices = indices.view(-1, 1)
85
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
86
+ min_encodings.scatter_(1, indices, 1)
87
+ # get quantized latent vectors
88
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
89
+
90
+ if shape is not None: # reshape back to match original input shape
91
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
92
+
93
+ return z_q
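A small sketch of the nearest-neighbour quantization round trip, assuming the VectorQuantizer class above is in scope; the 1x256x16x16 input mirrors the 16x16 latent grid CodeFormer uses further down.

import torch

vq = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
z = torch.randn(1, 256, 16, 16)  # encoder output, BCHW

z_q, codebook_loss, stats = vq(z)
print(z_q.shape)                            # torch.Size([1, 256, 16, 16])
print(stats["min_encoding_indices"].shape)  # torch.Size([256, 1]): one code index per spatial position

# get_codebook_feat rebuilds the quantized tensor from the indices alone
z_q2 = vq.get_codebook_feat(stats["min_encoding_indices"], shape=(1, 16, 16, 256))
print(torch.allclose(z_q, z_q2, atol=1e-6))  # True, up to straight-through rounding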
94
+
95
+
96
+ class GumbelQuantizer(nn.Module):
97
+ def __init__(
98
+ self,
99
+ codebook_size,
100
+ emb_dim,
101
+ num_hiddens,
102
+ straight_through=False,
103
+ kl_weight=5e-4,
104
+ temp_init=1.0,
105
+ ):
106
+ super().__init__()
107
+ self.codebook_size = codebook_size # number of embeddings
108
+ self.emb_dim = emb_dim # dimension of embedding
109
+ self.straight_through = straight_through
110
+ self.temperature = temp_init
111
+ self.kl_weight = kl_weight
112
+ self.proj = nn.Conv2d(
113
+ num_hiddens, codebook_size, 1
114
+ ) # projects last encoder layer to quantized logits
115
+ self.embed = nn.Embedding(codebook_size, emb_dim)
116
+
117
+ def forward(self, z):
118
+ hard = self.straight_through if self.training else True
119
+
120
+ logits = self.proj(z)
121
+
122
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
123
+
124
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
125
+
126
+ # + kl divergence to the prior loss
127
+ qy = F.softmax(logits, dim=1)
128
+ diff = (
129
+ self.kl_weight
130
+ * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
131
+ )
132
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
133
+
134
+ return z_q, diff, {"min_encoding_indices": min_encoding_indices}
135
+
136
+
137
+ class Downsample(nn.Module):
138
+ def __init__(self, in_channels):
139
+ super().__init__()
140
+ self.conv = torch.nn.Conv2d(
141
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
142
+ )
143
+
144
+ def forward(self, x):
145
+ pad = (0, 1, 0, 1)
146
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
147
+ x = self.conv(x)
148
+ return x
149
+
150
+
151
+ class Upsample(nn.Module):
152
+ def __init__(self, in_channels):
153
+ super().__init__()
154
+ self.conv = nn.Conv2d(
155
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
156
+ )
157
+
158
+ def forward(self, x):
159
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
160
+ x = self.conv(x)
161
+
162
+ return x
163
+
164
+
165
+ class AttnBlock(nn.Module):
166
+ def __init__(self, in_channels):
167
+ super().__init__()
168
+ self.in_channels = in_channels
169
+
170
+ self.norm = normalize(in_channels)
171
+ self.q = torch.nn.Conv2d(
172
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
173
+ )
174
+ self.k = torch.nn.Conv2d(
175
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
176
+ )
177
+ self.v = torch.nn.Conv2d(
178
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
179
+ )
180
+ self.proj_out = torch.nn.Conv2d(
181
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
182
+ )
183
+
184
+ def forward(self, x):
185
+ h_ = x
186
+ h_ = self.norm(h_)
187
+ q = self.q(h_)
188
+ k = self.k(h_)
189
+ v = self.v(h_)
190
+
191
+ # compute attention
192
+ b, c, h, w = q.shape
193
+ q = q.reshape(b, c, h * w)
194
+ q = q.permute(0, 2, 1)
195
+ k = k.reshape(b, c, h * w)
196
+ w_ = torch.bmm(q, k)
197
+ w_ = w_ * (int(c) ** (-0.5))
198
+ w_ = F.softmax(w_, dim=2)
199
+
200
+ # attend to values
201
+ v = v.reshape(b, c, h * w)
202
+ w_ = w_.permute(0, 2, 1)
203
+ h_ = torch.bmm(v, w_)
204
+ h_ = h_.reshape(b, c, h, w)
205
+
206
+ h_ = self.proj_out(h_)
207
+
208
+ return x + h_
209
+
210
+
211
+ class Encoder(nn.Module):
212
+ def __init__(
213
+ self,
214
+ in_channels,
215
+ nf,
216
+ out_channels,
217
+ ch_mult,
218
+ num_res_blocks,
219
+ resolution,
220
+ attn_resolutions,
221
+ ):
222
+ super().__init__()
223
+ self.nf = nf
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.attn_resolutions = attn_resolutions
228
+
229
+ curr_res = self.resolution
230
+ in_ch_mult = (1,) + tuple(ch_mult)
231
+
232
+ blocks = []
233
+ # initial convolution
234
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
235
+
236
+ # residual and downsampling blocks, with attention on smaller res (16x16)
237
+ for i in range(self.num_resolutions):
238
+ block_in_ch = nf * in_ch_mult[i]
239
+ block_out_ch = nf * ch_mult[i]
240
+ for _ in range(self.num_res_blocks):
241
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
242
+ block_in_ch = block_out_ch
243
+ if curr_res in attn_resolutions:
244
+ blocks.append(AttnBlock(block_in_ch))
245
+
246
+ if i != self.num_resolutions - 1:
247
+ blocks.append(Downsample(block_in_ch))
248
+ curr_res = curr_res // 2
249
+
250
+ # non-local attention block
251
+ blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
252
+ blocks.append(AttnBlock(block_in_ch)) # type: ignore
253
+ blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
254
+
255
+ # normalise and convert to latent size
256
+ blocks.append(normalize(block_in_ch)) # type: ignore
257
+ blocks.append(
258
+ nn.Conv2d(block_in_ch, out_channels, kernel_size=3, stride=1, padding=1) # type: ignore
259
+ )
260
+ self.blocks = nn.ModuleList(blocks)
261
+
262
+ def forward(self, x):
263
+ for block in self.blocks:
264
+ x = block(x)
265
+
266
+ return x
267
+
268
+
269
+ class Generator(nn.Module):
270
+ def __init__(self, nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim):
271
+ super().__init__()
272
+ self.nf = nf
273
+ self.ch_mult = ch_mult
274
+ self.num_resolutions = len(self.ch_mult)
275
+ self.num_res_blocks = res_blocks
276
+ self.resolution = img_size
277
+ self.attn_resolutions = attn_resolutions
278
+ self.in_channels = emb_dim
279
+ self.out_channels = 3
280
+ block_in_ch = self.nf * self.ch_mult[-1]
281
+ curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
282
+
283
+ blocks = []
284
+ # initial conv
285
+ blocks.append(
286
+ nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
287
+ )
288
+
289
+ # non-local attention block
290
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
291
+ blocks.append(AttnBlock(block_in_ch))
292
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
293
+
294
+ for i in reversed(range(self.num_resolutions)):
295
+ block_out_ch = self.nf * self.ch_mult[i]
296
+
297
+ for _ in range(self.num_res_blocks):
298
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
299
+ block_in_ch = block_out_ch
300
+
301
+ if curr_res in self.attn_resolutions:
302
+ blocks.append(AttnBlock(block_in_ch))
303
+
304
+ if i != 0:
305
+ blocks.append(Upsample(block_in_ch))
306
+ curr_res = curr_res * 2
307
+
308
+ blocks.append(normalize(block_in_ch))
309
+ blocks.append(
310
+ nn.Conv2d(
311
+ block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
312
+ )
313
+ )
314
+
315
+ self.blocks = nn.ModuleList(blocks)
316
+
317
+ def forward(self, x):
318
+ for block in self.blocks:
319
+ x = block(x)
320
+
321
+ return x
322
+
323
+
324
+ class VQAutoEncoder(nn.Module):
325
+ def __init__(
326
+ self,
327
+ img_size,
328
+ nf,
329
+ ch_mult,
330
+ quantizer="nearest",
331
+ res_blocks=2,
332
+ attn_resolutions=[16],
333
+ codebook_size=1024,
334
+ emb_dim=256,
335
+ beta=0.25,
336
+ gumbel_straight_through=False,
337
+ gumbel_kl_weight=1e-8,
338
+ model_path=None,
339
+ ):
340
+ super().__init__()
341
+ self.in_channels = 3
342
+ self.nf = nf
343
+ self.n_blocks = res_blocks
344
+ self.codebook_size = codebook_size
345
+ self.embed_dim = emb_dim
346
+ self.ch_mult = ch_mult
347
+ self.resolution = img_size
348
+ self.attn_resolutions = attn_resolutions
349
+ self.quantizer_type = quantizer
350
+ self.encoder = Encoder(
351
+ self.in_channels,
352
+ self.nf,
353
+ self.embed_dim,
354
+ self.ch_mult,
355
+ self.n_blocks,
356
+ self.resolution,
357
+ self.attn_resolutions,
358
+ )
359
+ if self.quantizer_type == "nearest":
360
+ self.beta = beta # 0.25
361
+ self.quantize = VectorQuantizer(
362
+ self.codebook_size, self.embed_dim, self.beta
363
+ )
364
+ elif self.quantizer_type == "gumbel":
365
+ self.gumbel_num_hiddens = emb_dim
366
+ self.straight_through = gumbel_straight_through
367
+ self.kl_weight = gumbel_kl_weight
368
+ self.quantize = GumbelQuantizer(
369
+ self.codebook_size,
370
+ self.embed_dim,
371
+ self.gumbel_num_hiddens,
372
+ self.straight_through,
373
+ self.kl_weight,
374
+ )
375
+ self.generator = Generator(
376
+ nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim
377
+ )
378
+
379
+ if model_path is not None:
380
+ chkpt = torch.load(model_path, map_location="cpu")
381
+ if "params_ema" in chkpt:
382
+ self.load_state_dict(
383
+ torch.load(model_path, map_location="cpu")["params_ema"]
384
+ )
385
+ logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
386
+ elif "params" in chkpt:
387
+ self.load_state_dict(
388
+ torch.load(model_path, map_location="cpu")["params"]
389
+ )
390
+ logger.info(f"vqgan is loaded from: {model_path} [params]")
391
+ else:
392
+ raise ValueError("Wrong params!")
393
+
394
+ def forward(self, x):
395
+ x = self.encoder(x)
396
+ quant, codebook_loss, quant_stats = self.quantize(x)
397
+ x = self.generator(quant)
398
+ return x, codebook_loss, quant_stats
399
+
400
+
401
+ def calc_mean_std(feat, eps=1e-5):
402
+ """Calculate mean and std for adaptive_instance_normalization.
403
+ Args:
404
+ feat (Tensor): 4D tensor.
405
+ eps (float): A small value added to the variance to avoid
406
+ divide-by-zero. Default: 1e-5.
407
+ """
408
+ size = feat.size()
409
+ assert len(size) == 4, "The input feature should be 4D tensor."
410
+ b, c = size[:2]
411
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
412
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
413
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
414
+ return feat_mean, feat_std
415
+
416
+
417
+ def adaptive_instance_normalization(content_feat, style_feat):
418
+ """Adaptive instance normalization.
419
+ Adjust the reference features so that they have a similar color and illumination
420
+ to those of the degraded features.
421
+ Args:
422
+ content_feat (Tensor): The reference feature.
423
+ style_feat (Tensor): The degraded features.
424
+ """
425
+ size = content_feat.size()
426
+ style_mean, style_std = calc_mean_std(style_feat)
427
+ content_mean, content_std = calc_mean_std(content_feat)
428
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
429
+ size
430
+ )
431
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
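As a quick check of what calc_mean_std and adaptive_instance_normalization do (assuming both functions above are in scope): the content features keep their spatial pattern but take on the per-channel mean and standard deviation of the style features.

import torch

content = torch.randn(1, 4, 16, 16) * 3.0 + 5.0  # arbitrary statistics
style = torch.randn(1, 4, 16, 16) * 0.5 - 1.0

out = adaptive_instance_normalization(content, style)

style_mean, style_std = calc_mean_std(style)
out_mean, out_std = calc_mean_std(out)
print(torch.allclose(out_mean, style_mean, atol=1e-4))  # True
print(torch.allclose(out_std, style_std, atol=1e-2))    # True, up to the eps term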
432
+
433
+
434
+ class PositionEmbeddingSine(nn.Module):
435
+ """
436
+ This is a more standard version of the position embedding, very similar to the one
437
+ used by the Attention is all you need paper, generalized to work on images.
438
+ """
439
+
440
+ def __init__(
441
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
442
+ ):
443
+ super().__init__()
444
+ self.num_pos_feats = num_pos_feats
445
+ self.temperature = temperature
446
+ self.normalize = normalize
447
+ if scale is not None and normalize is False:
448
+ raise ValueError("normalize should be True if scale is passed")
449
+ if scale is None:
450
+ scale = 2 * math.pi
451
+ self.scale = scale
452
+
453
+ def forward(self, x, mask=None):
454
+ if mask is None:
455
+ mask = torch.zeros(
456
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
457
+ )
458
+ not_mask = ~mask # pylint: disable=invalid-unary-operand-type
459
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
460
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
461
+ if self.normalize:
462
+ eps = 1e-6
463
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
464
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
465
+
466
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
467
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
468
+
469
+ pos_x = x_embed[:, :, :, None] / dim_t
470
+ pos_y = y_embed[:, :, :, None] / dim_t
471
+ pos_x = torch.stack(
472
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
473
+ ).flatten(3)
474
+ pos_y = torch.stack(
475
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
476
+ ).flatten(3)
477
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
478
+ return pos
479
+
480
+
481
+ def _get_activation_fn(activation):
482
+ """Return an activation function given a string"""
483
+ if activation == "relu":
484
+ return F.relu
485
+ if activation == "gelu":
486
+ return F.gelu
487
+ if activation == "glu":
488
+ return F.glu
489
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
490
+
491
+
492
+ class TransformerSALayer(nn.Module):
493
+ def __init__(
494
+ self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
495
+ ):
496
+ super().__init__()
497
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
498
+ # Implementation of Feedforward model - MLP
499
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
500
+ self.dropout = nn.Dropout(dropout)
501
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
502
+
503
+ self.norm1 = nn.LayerNorm(embed_dim)
504
+ self.norm2 = nn.LayerNorm(embed_dim)
505
+ self.dropout1 = nn.Dropout(dropout)
506
+ self.dropout2 = nn.Dropout(dropout)
507
+
508
+ self.activation = _get_activation_fn(activation)
509
+
510
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
511
+ return tensor if pos is None else tensor + pos
512
+
513
+ def forward(
514
+ self,
515
+ tgt,
516
+ tgt_mask: Optional[Tensor] = None,
517
+ tgt_key_padding_mask: Optional[Tensor] = None,
518
+ query_pos: Optional[Tensor] = None,
519
+ ):
520
+ # self attention
521
+ tgt2 = self.norm1(tgt)
522
+ q = k = self.with_pos_embed(tgt2, query_pos)
523
+ tgt2 = self.self_attn(
524
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
525
+ )[0]
526
+ tgt = tgt + self.dropout1(tgt2)
527
+
528
+ # ffn
529
+ tgt2 = self.norm2(tgt)
530
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
531
+ tgt = tgt + self.dropout2(tgt2)
532
+ return tgt
533
+
534
+
535
+ def normalize(in_channels):
536
+ return torch.nn.GroupNorm(
537
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
538
+ )
539
+
540
+
541
+ @torch.jit.script # type: ignore
542
+ def swish(x):
543
+ return x * torch.sigmoid(x)
544
+
545
+
546
+ class ResBlock(nn.Module):
547
+ def __init__(self, in_channels, out_channels=None):
548
+ super(ResBlock, self).__init__()
549
+ self.in_channels = in_channels
550
+ self.out_channels = in_channels if out_channels is None else out_channels
551
+ self.norm1 = normalize(in_channels)
552
+ self.conv1 = nn.Conv2d(
553
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
554
+ )
555
+ self.norm2 = normalize(out_channels)
556
+ self.conv2 = nn.Conv2d(
557
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
558
+ )
559
+ if self.in_channels != self.out_channels:
560
+ self.conv_out = nn.Conv2d(
561
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0 # type: ignore
562
+ )
563
+
564
+ def forward(self, x_in):
565
+ x = x_in
566
+ x = self.norm1(x)
567
+ x = swish(x)
568
+ x = self.conv1(x)
569
+ x = self.norm2(x)
570
+ x = swish(x)
571
+ x = self.conv2(x)
572
+ if self.in_channels != self.out_channels:
573
+ x_in = self.conv_out(x_in)
574
+
575
+ return x + x_in
576
+
577
+
578
+ class Fuse_sft_block(nn.Module):
579
+ def __init__(self, in_ch, out_ch):
580
+ super().__init__()
581
+ self.encode_enc = ResBlock(2 * in_ch, out_ch)
582
+
583
+ self.scale = nn.Sequential(
584
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
585
+ nn.LeakyReLU(0.2, True),
586
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
587
+ )
588
+
589
+ self.shift = nn.Sequential(
590
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
591
+ nn.LeakyReLU(0.2, True),
592
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
593
+ )
594
+
595
+ def forward(self, enc_feat, dec_feat, w=1):
596
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
597
+ scale = self.scale(enc_feat)
598
+ shift = self.shift(enc_feat)
599
+ residual = w * (dec_feat * scale + shift)
600
+ out = dec_feat + residual
601
+ return out
602
+
603
+
604
+ class CodeFormer(VQAutoEncoder):
605
+ def __init__(self, state_dict):
606
+ dim_embd = 512
607
+ n_head = 8
608
+ n_layers = 9
609
+ codebook_size = 1024
610
+ latent_size = 256
611
+ connect_list = ["32", "64", "128", "256"]
612
+ fix_modules = ["quantize", "generator"]
613
+
614
+ # This is just a guess as I only have one model to look at
615
+ position_emb = state_dict["position_emb"]
616
+ dim_embd = position_emb.shape[1]
617
+ latent_size = position_emb.shape[0]
618
+
619
+ try:
620
+ n_layers = len(
621
+ set([x.split(".")[1] for x in state_dict.keys() if "ft_layers" in x])
622
+ )
623
+ except:
624
+ pass
625
+
626
+ codebook_size = state_dict["quantize.embedding.weight"].shape[0]
627
+
628
+ # This is also just another guess
629
+ n_head_exp = (
630
+ state_dict["ft_layers.0.self_attn.in_proj_weight"].shape[0] // dim_embd
631
+ )
632
+ n_head = 2**n_head_exp
633
+
634
+ in_nc = state_dict["encoder.blocks.0.weight"].shape[1]
635
+
636
+ self.model_arch = "CodeFormer"
637
+ self.sub_type = "Face SR"
638
+ self.scale = 8
639
+ self.in_nc = in_nc
640
+ self.out_nc = in_nc
641
+
642
+ self.state = state_dict
643
+
644
+ self.supports_fp16 = False
645
+ self.supports_bf16 = True
646
+ self.min_size_restriction = 16
647
+
648
+ super(CodeFormer, self).__init__(
649
+ 512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
650
+ )
651
+
652
+ if fix_modules is not None:
653
+ for module in fix_modules:
654
+ for param in getattr(self, module).parameters():
655
+ param.requires_grad = False
656
+
657
+ self.connect_list = connect_list
658
+ self.n_layers = n_layers
659
+ self.dim_embd = dim_embd
660
+ self.dim_mlp = dim_embd * 2
661
+
662
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) # type: ignore
663
+ self.feat_emb = nn.Linear(256, self.dim_embd)
664
+
665
+ # transformer
666
+ self.ft_layers = nn.Sequential(
667
+ *[
668
+ TransformerSALayer(
669
+ embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
670
+ )
671
+ for _ in range(self.n_layers)
672
+ ]
673
+ )
674
+
675
+ # logits_predict head
676
+ self.idx_pred_layer = nn.Sequential(
677
+ nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
678
+ )
679
+
680
+ self.channels = {
681
+ "16": 512,
682
+ "32": 256,
683
+ "64": 256,
684
+ "128": 128,
685
+ "256": 128,
686
+ "512": 64,
687
+ }
688
+
689
+ # after second residual block for > 16, before attn layer for ==16
690
+ self.fuse_encoder_block = {
691
+ "512": 2,
692
+ "256": 5,
693
+ "128": 8,
694
+ "64": 11,
695
+ "32": 14,
696
+ "16": 18,
697
+ }
698
+ # after first residual block for > 16, before attn layer for ==16
699
+ self.fuse_generator_block = {
700
+ "16": 6,
701
+ "32": 9,
702
+ "64": 12,
703
+ "128": 15,
704
+ "256": 18,
705
+ "512": 21,
706
+ }
707
+
708
+ # fuse_convs_dict
709
+ self.fuse_convs_dict = nn.ModuleDict()
710
+ for f_size in self.connect_list:
711
+ in_ch = self.channels[f_size]
712
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
713
+
714
+ self.load_state_dict(state_dict)
715
+
716
+ def _init_weights(self, module):
717
+ if isinstance(module, (nn.Linear, nn.Embedding)):
718
+ module.weight.data.normal_(mean=0.0, std=0.02)
719
+ if isinstance(module, nn.Linear) and module.bias is not None:
720
+ module.bias.data.zero_()
721
+ elif isinstance(module, nn.LayerNorm):
722
+ module.bias.data.zero_()
723
+ module.weight.data.fill_(1.0)
724
+
725
+ def forward(self, x, weight=0.5, **kwargs):
726
+ detach_16 = True
727
+ code_only = False
728
+ adain = True
729
+ # ################### Encoder #####################
730
+ enc_feat_dict = {}
731
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
732
+ for i, block in enumerate(self.encoder.blocks):
733
+ x = block(x)
734
+ if i in out_list:
735
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
736
+
737
+ lq_feat = x
738
+ # ################# Transformer ###################
739
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
740
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
741
+ # BCHW -> BC(HW) -> (HW)BC
742
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
743
+ query_emb = feat_emb
744
+ # Transformer encoder
745
+ for layer in self.ft_layers:
746
+ query_emb = layer(query_emb, query_pos=pos_emb)
747
+
748
+ # output logits
749
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
750
+ logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
751
+
752
+ if code_only: # for training stage II
753
+ # logits doesn't need softmax before cross_entropy loss
754
+ return logits, lq_feat
755
+
756
+ # ################# Quantization ###################
757
+ # if self.training:
758
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
759
+ # # b(hw)c -> bc(hw) -> bchw
760
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
761
+ # ------------
762
+ soft_one_hot = F.softmax(logits, dim=2)
763
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
764
+ quant_feat = self.quantize.get_codebook_feat(
765
+ top_idx, shape=[x.shape[0], 16, 16, 256] # type: ignore
766
+ )
767
+ # preserve gradients
768
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
769
+
770
+ if detach_16:
771
+ quant_feat = quant_feat.detach() # for training stage III
772
+ if adain:
773
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
774
+
775
+ # ################## Generator ####################
776
+ x = quant_feat
777
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
778
+
779
+ for i, block in enumerate(self.generator.blocks):
780
+ x = block(x)
781
+ if i in fuse_list: # fuse after i-th block
782
+ f_size = str(x.shape[-1])
783
+ if weight > 0:
784
+ x = self.fuse_convs_dict[f_size](
785
+ enc_feat_dict[f_size].detach(), x, weight
786
+ )
787
+ out = x
788
+ # logits doesn't need softmax before cross_entropy loss
789
+ # return out, logits, lq_feat
790
+ return out, logits
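A hedged sketch of how this CodeFormer class is typically driven. The checkpoint path is a placeholder; the class infers dim_embd, n_layers, n_head and the codebook size from the state dict, and the hard-coded 16x16 latent grid in forward() implies a 512x512 aligned face crop as input.

import torch
# assuming the CodeFormer class above is importable, e.g.
# from comfy_extras.chainner_models.architecture.face.codeformer import CodeFormer

chkpt = torch.load("codeformer.pth", map_location="cpu")  # placeholder path to a standard CodeFormer checkpoint
state_dict = chkpt.get("params_ema", chkpt)

model = CodeFormer(state_dict).eval()

with torch.no_grad():
    face = torch.rand(1, 3, 512, 512)           # aligned face crop in [0, 1]
    restored, logits = model(face, weight=0.5)  # weight trades encoder fidelity against codebook quality

print(restored.shape)  # torch.Size([1, 3, 512, 512])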
comfy_extras/chainner_models/architecture/face/fused_act.py ADDED
@@ -0,0 +1,81 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.autograd import Function
8
+
9
+ fused_act_ext = None
10
+
11
+
12
+ class FusedLeakyReLUFunctionBackward(Function):
13
+ @staticmethod
14
+ def forward(ctx, grad_output, out, negative_slope, scale):
15
+ ctx.save_for_backward(out)
16
+ ctx.negative_slope = negative_slope
17
+ ctx.scale = scale
18
+
19
+ empty = grad_output.new_empty(0)
20
+
21
+ grad_input = fused_act_ext.fused_bias_act(
22
+ grad_output, empty, out, 3, 1, negative_slope, scale
23
+ )
24
+
25
+ dim = [0]
26
+
27
+ if grad_input.ndim > 2:
28
+ dim += list(range(2, grad_input.ndim))
29
+
30
+ grad_bias = grad_input.sum(dim).detach()
31
+
32
+ return grad_input, grad_bias
33
+
34
+ @staticmethod
35
+ def backward(ctx, gradgrad_input, gradgrad_bias):
36
+ (out,) = ctx.saved_tensors
37
+ gradgrad_out = fused_act_ext.fused_bias_act(
38
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
39
+ )
40
+
41
+ return gradgrad_out, None, None, None
42
+
43
+
44
+ class FusedLeakyReLUFunction(Function):
45
+ @staticmethod
46
+ def forward(ctx, input, bias, negative_slope, scale):
47
+ empty = input.new_empty(0)
48
+ out = fused_act_ext.fused_bias_act(
49
+ input, bias, empty, 3, 0, negative_slope, scale
50
+ )
51
+ ctx.save_for_backward(out)
52
+ ctx.negative_slope = negative_slope
53
+ ctx.scale = scale
54
+
55
+ return out
56
+
57
+ @staticmethod
58
+ def backward(ctx, grad_output):
59
+ (out,) = ctx.saved_tensors
60
+
61
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
62
+ grad_output, out, ctx.negative_slope, ctx.scale
63
+ )
64
+
65
+ return grad_input, grad_bias, None, None
66
+
67
+
68
+ class FusedLeakyReLU(nn.Module):
69
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
70
+ super().__init__()
71
+
72
+ self.bias = nn.Parameter(torch.zeros(channel))
73
+ self.negative_slope = negative_slope
74
+ self.scale = scale
75
+
76
+ def forward(self, input):
77
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
78
+
79
+
80
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
81
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
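Note that fused_act_ext is stubbed out as None above, so FusedLeakyReLU cannot actually execute in this copy; the operation the CUDA extension fuses is equivalent to the following plain-PyTorch sketch (bias add, leaky ReLU, rescale):

import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2**0.5):
    # broadcast the per-channel bias over the remaining dimensions
    bias = bias.view(1, -1, *([1] * (x.ndim - 2)))
    return F.leaky_relu(x + bias, negative_slope) * scale

x = torch.randn(2, 8, 4, 4)
b = torch.zeros(8)
print(fused_leaky_relu_reference(x, b).shape)  # torch.Size([2, 8, 4, 4])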
comfy_extras/chainner_models/architecture/face/gfpgan_bilinear_arch.py ADDED
@@ -0,0 +1,389 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from .gfpganv1_arch import ResUpBlock
10
+ from .stylegan2_bilinear_arch import (
11
+ ConvLayer,
12
+ EqualConv2d,
13
+ EqualLinear,
14
+ ResBlock,
15
+ ScaledLeakyReLU,
16
+ StyleGAN2GeneratorBilinear,
17
+ )
18
+
19
+
20
+ class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
21
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
22
+ It is the bilinear version. It does not use the complicated UpFirDnSmooth function, which is not friendly for
23
+ deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
24
+ Args:
25
+ out_size (int): The spatial size of outputs.
26
+ num_style_feat (int): Channel number of style features. Default: 512.
27
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
28
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
29
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
30
+ narrow (float): The narrow ratio for channels. Default: 1.
31
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ out_size,
37
+ num_style_feat=512,
38
+ num_mlp=8,
39
+ channel_multiplier=2,
40
+ lr_mlp=0.01,
41
+ narrow=1,
42
+ sft_half=False,
43
+ ):
44
+ super(StyleGAN2GeneratorBilinearSFT, self).__init__(
45
+ out_size,
46
+ num_style_feat=num_style_feat,
47
+ num_mlp=num_mlp,
48
+ channel_multiplier=channel_multiplier,
49
+ lr_mlp=lr_mlp,
50
+ narrow=narrow,
51
+ )
52
+ self.sft_half = sft_half
53
+
54
+ def forward(
55
+ self,
56
+ styles,
57
+ conditions,
58
+ input_is_latent=False,
59
+ noise=None,
60
+ randomize_noise=True,
61
+ truncation=1,
62
+ truncation_latent=None,
63
+ inject_index=None,
64
+ return_latents=False,
65
+ ):
66
+ """Forward function for StyleGAN2GeneratorBilinearSFT.
67
+ Args:
68
+ styles (list[Tensor]): Sample codes of styles.
69
+ conditions (list[Tensor]): SFT conditions to generators.
70
+ input_is_latent (bool): Whether input is latent style. Default: False.
71
+ noise (Tensor | None): Input noise or None. Default: None.
72
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
73
+ truncation (float): The truncation ratio. Default: 1.
74
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
75
+ inject_index (int | None): The injection index for mixing noise. Default: None.
76
+ return_latents (bool): Whether to return style latents. Default: False.
77
+ """
78
+ # style codes -> latents with Style MLP layer
79
+ if not input_is_latent:
80
+ styles = [self.style_mlp(s) for s in styles]
81
+ # noises
82
+ if noise is None:
83
+ if randomize_noise:
84
+ noise = [None] * self.num_layers # for each style conv layer
85
+ else: # use the stored noise
86
+ noise = [
87
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
88
+ ]
89
+ # style truncation
90
+ if truncation < 1:
91
+ style_truncation = []
92
+ for style in styles:
93
+ style_truncation.append(
94
+ truncation_latent + truncation * (style - truncation_latent)
95
+ )
96
+ styles = style_truncation
97
+ # get style latents with injection
98
+ if len(styles) == 1:
99
+ inject_index = self.num_latent
100
+
101
+ if styles[0].ndim < 3:
102
+ # repeat latent code for all the layers
103
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
104
+ else: # used for encoder with different latent code for each layer
105
+ latent = styles[0]
106
+ elif len(styles) == 2: # mixing noises
107
+ if inject_index is None:
108
+ inject_index = random.randint(1, self.num_latent - 1)
109
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
110
+ latent2 = (
111
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
112
+ )
113
+ latent = torch.cat([latent1, latent2], 1)
114
+
115
+ # main generation
116
+ out = self.constant_input(latent.shape[0])
117
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
118
+ skip = self.to_rgb1(out, latent[:, 1])
119
+
120
+ i = 1
121
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
122
+ self.style_convs[::2],
123
+ self.style_convs[1::2],
124
+ noise[1::2],
125
+ noise[2::2],
126
+ self.to_rgbs,
127
+ ):
128
+ out = conv1(out, latent[:, i], noise=noise1)
129
+
130
+ # the conditions may have fewer levels
131
+ if i < len(conditions):
132
+ # SFT part to combine the conditions
133
+ if self.sft_half: # only apply SFT to half of the channels
134
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
135
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
136
+ out = torch.cat([out_same, out_sft], dim=1)
137
+ else: # apply SFT to all the channels
138
+ out = out * conditions[i - 1] + conditions[i]
139
+
140
+ out = conv2(out, latent[:, i + 1], noise=noise2)
141
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
142
+ i += 2
143
+
144
+ image = skip
145
+
146
+ if return_latents:
147
+ return image, latent
148
+ else:
149
+ return image, None
150
+
151
+
152
+ class GFPGANBilinear(nn.Module):
153
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
154
+ It is the bilinear version and it does not use the complicated UpFirDnSmooth function, which is not friendly for
155
+ deployment. It can be easily converted to the clean version: GFPGANv1Clean.
156
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
157
+ Args:
158
+ out_size (int): The spatial size of outputs.
159
+ num_style_feat (int): Channel number of style features. Default: 512.
160
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
161
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
162
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
163
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
164
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
165
+ input_is_latent (bool): Whether input is latent style. Default: False.
166
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
167
+ narrow (float): The narrow ratio for channels. Default: 1.
168
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ out_size,
174
+ num_style_feat=512,
175
+ channel_multiplier=1,
176
+ decoder_load_path=None,
177
+ fix_decoder=True,
178
+ # for stylegan decoder
179
+ num_mlp=8,
180
+ lr_mlp=0.01,
181
+ input_is_latent=False,
182
+ different_w=False,
183
+ narrow=1,
184
+ sft_half=False,
185
+ ):
186
+ super(GFPGANBilinear, self).__init__()
187
+ self.input_is_latent = input_is_latent
188
+ self.different_w = different_w
189
+ self.num_style_feat = num_style_feat
190
+ self.min_size_restriction = 512
191
+
192
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
193
+ channels = {
194
+ "4": int(512 * unet_narrow),
195
+ "8": int(512 * unet_narrow),
196
+ "16": int(512 * unet_narrow),
197
+ "32": int(512 * unet_narrow),
198
+ "64": int(256 * channel_multiplier * unet_narrow),
199
+ "128": int(128 * channel_multiplier * unet_narrow),
200
+ "256": int(64 * channel_multiplier * unet_narrow),
201
+ "512": int(32 * channel_multiplier * unet_narrow),
202
+ "1024": int(16 * channel_multiplier * unet_narrow),
203
+ }
204
+
205
+ self.log_size = int(math.log(out_size, 2))
206
+ first_out_size = 2 ** (int(math.log(out_size, 2)))
207
+
208
+ self.conv_body_first = ConvLayer(
209
+ 3, channels[f"{first_out_size}"], 1, bias=True, activate=True
210
+ )
211
+
212
+ # downsample
213
+ in_channels = channels[f"{first_out_size}"]
214
+ self.conv_body_down = nn.ModuleList()
215
+ for i in range(self.log_size, 2, -1):
216
+ out_channels = channels[f"{2**(i - 1)}"]
217
+ self.conv_body_down.append(ResBlock(in_channels, out_channels))
218
+ in_channels = out_channels
219
+
220
+ self.final_conv = ConvLayer(
221
+ in_channels, channels["4"], 3, bias=True, activate=True
222
+ )
223
+
224
+ # upsample
225
+ in_channels = channels["4"]
226
+ self.conv_body_up = nn.ModuleList()
227
+ for i in range(3, self.log_size + 1):
228
+ out_channels = channels[f"{2**i}"]
229
+ self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
230
+ in_channels = out_channels
231
+
232
+ # to RGB
233
+ self.toRGB = nn.ModuleList()
234
+ for i in range(3, self.log_size + 1):
235
+ self.toRGB.append(
236
+ EqualConv2d(
237
+ channels[f"{2**i}"],
238
+ 3,
239
+ 1,
240
+ stride=1,
241
+ padding=0,
242
+ bias=True,
243
+ bias_init_val=0,
244
+ )
245
+ )
246
+
247
+ if different_w:
248
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
249
+ else:
250
+ linear_out_channel = num_style_feat
251
+
252
+ self.final_linear = EqualLinear(
253
+ channels["4"] * 4 * 4,
254
+ linear_out_channel,
255
+ bias=True,
256
+ bias_init_val=0,
257
+ lr_mul=1,
258
+ activation=None,
259
+ )
260
+
261
+ # the decoder: stylegan2 generator with SFT modulations
262
+ self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
263
+ out_size=out_size,
264
+ num_style_feat=num_style_feat,
265
+ num_mlp=num_mlp,
266
+ channel_multiplier=channel_multiplier,
267
+ lr_mlp=lr_mlp,
268
+ narrow=narrow,
269
+ sft_half=sft_half,
270
+ )
271
+
272
+ # load pre-trained stylegan2 model if necessary
273
+ if decoder_load_path:
274
+ self.stylegan_decoder.load_state_dict(
275
+ torch.load(
276
+ decoder_load_path, map_location=lambda storage, loc: storage
277
+ )["params_ema"]
278
+ )
279
+ # fix decoder without updating params
280
+ if fix_decoder:
281
+ for _, param in self.stylegan_decoder.named_parameters():
282
+ param.requires_grad = False
283
+
284
+ # for SFT modulations (scale and shift)
285
+ self.condition_scale = nn.ModuleList()
286
+ self.condition_shift = nn.ModuleList()
287
+ for i in range(3, self.log_size + 1):
288
+ out_channels = channels[f"{2**i}"]
289
+ if sft_half:
290
+ sft_out_channels = out_channels
291
+ else:
292
+ sft_out_channels = out_channels * 2
293
+ self.condition_scale.append(
294
+ nn.Sequential(
295
+ EqualConv2d(
296
+ out_channels,
297
+ out_channels,
298
+ 3,
299
+ stride=1,
300
+ padding=1,
301
+ bias=True,
302
+ bias_init_val=0,
303
+ ),
304
+ ScaledLeakyReLU(0.2),
305
+ EqualConv2d(
306
+ out_channels,
307
+ sft_out_channels,
308
+ 3,
309
+ stride=1,
310
+ padding=1,
311
+ bias=True,
312
+ bias_init_val=1,
313
+ ),
314
+ )
315
+ )
316
+ self.condition_shift.append(
317
+ nn.Sequential(
318
+ EqualConv2d(
319
+ out_channels,
320
+ out_channels,
321
+ 3,
322
+ stride=1,
323
+ padding=1,
324
+ bias=True,
325
+ bias_init_val=0,
326
+ ),
327
+ ScaledLeakyReLU(0.2),
328
+ EqualConv2d(
329
+ out_channels,
330
+ sft_out_channels,
331
+ 3,
332
+ stride=1,
333
+ padding=1,
334
+ bias=True,
335
+ bias_init_val=0,
336
+ ),
337
+ )
338
+ )
339
+
340
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
341
+ """Forward function for GFPGANBilinear.
342
+ Args:
343
+ x (Tensor): Input images.
344
+ return_latents (bool): Whether to return style latents. Default: False.
345
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
346
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
347
+ """
348
+ conditions = []
349
+ unet_skips = []
350
+ out_rgbs = []
351
+
352
+ # encoder
353
+ feat = self.conv_body_first(x)
354
+ for i in range(self.log_size - 2):
355
+ feat = self.conv_body_down[i](feat)
356
+ unet_skips.insert(0, feat)
357
+
358
+ feat = self.final_conv(feat)
359
+
360
+ # style code
361
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
362
+ if self.different_w:
363
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
364
+
365
+ # decode
366
+ for i in range(self.log_size - 2):
367
+ # add unet skip
368
+ feat = feat + unet_skips[i]
369
+ # ResUpLayer
370
+ feat = self.conv_body_up[i](feat)
371
+ # generate scale and shift for SFT layers
372
+ scale = self.condition_scale[i](feat)
373
+ conditions.append(scale.clone())
374
+ shift = self.condition_shift[i](feat)
375
+ conditions.append(shift.clone())
376
+ # generate rgb images
377
+ if return_rgb:
378
+ out_rgbs.append(self.toRGB[i](feat))
379
+
380
+ # decoder
381
+ image, _ = self.stylegan_decoder(
382
+ [style_code],
383
+ conditions,
384
+ return_latents=return_latents,
385
+ input_is_latent=self.input_is_latent,
386
+ randomize_noise=randomize_noise,
387
+ )
388
+
389
+ return image, out_rgbs
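The SFT ("spatial feature transform") modulation used in the decoder above reduces to a per-pixel affine transform of the decoder features (or of half of them when sft_half=True). A standalone sketch of the sft_half branch, using nothing beyond plain PyTorch:

import torch

out = torch.randn(1, 64, 32, 32)     # decoder features at some resolution
scale = torch.randn(1, 32, 32, 32)   # condition_scale output (half the channels when sft_half=True)
shift = torch.randn(1, 32, 32, 32)   # condition_shift output

out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
out_sft = out_sft * scale + shift    # spatially varying affine modulation
modulated = torch.cat([out_same, out_sft], dim=1)
print(modulated.shape)               # torch.Size([1, 64, 32, 32])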
comfy_extras/chainner_models/architecture/face/gfpganv1_arch.py ADDED
@@ -0,0 +1,566 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .fused_act import FusedLeakyReLU
11
+ from .stylegan2_arch import (
12
+ ConvLayer,
13
+ EqualConv2d,
14
+ EqualLinear,
15
+ ResBlock,
16
+ ScaledLeakyReLU,
17
+ StyleGAN2Generator,
18
+ )
19
+
20
+
21
+ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
22
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
23
+ Args:
24
+ out_size (int): The spatial size of outputs.
25
+ num_style_feat (int): Channel number of style features. Default: 512.
26
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
27
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
28
+ resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
29
+ applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
30
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
31
+ narrow (float): The narrow ratio for channels. Default: 1.
32
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ out_size,
38
+ num_style_feat=512,
39
+ num_mlp=8,
40
+ channel_multiplier=2,
41
+ resample_kernel=(1, 3, 3, 1),
42
+ lr_mlp=0.01,
43
+ narrow=1,
44
+ sft_half=False,
45
+ ):
46
+ super(StyleGAN2GeneratorSFT, self).__init__(
47
+ out_size,
48
+ num_style_feat=num_style_feat,
49
+ num_mlp=num_mlp,
50
+ channel_multiplier=channel_multiplier,
51
+ resample_kernel=resample_kernel,
52
+ lr_mlp=lr_mlp,
53
+ narrow=narrow,
54
+ )
55
+ self.sft_half = sft_half
56
+
57
+ def forward(
58
+ self,
59
+ styles,
60
+ conditions,
61
+ input_is_latent=False,
62
+ noise=None,
63
+ randomize_noise=True,
64
+ truncation=1,
65
+ truncation_latent=None,
66
+ inject_index=None,
67
+ return_latents=False,
68
+ ):
69
+ """Forward function for StyleGAN2GeneratorSFT.
70
+ Args:
71
+ styles (list[Tensor]): Sample codes of styles.
72
+ conditions (list[Tensor]): SFT conditions to generators.
73
+ input_is_latent (bool): Whether input is latent style. Default: False.
74
+ noise (Tensor | None): Input noise or None. Default: None.
75
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
76
+ truncation (float): The truncation ratio. Default: 1.
77
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
78
+ inject_index (int | None): The injection index for mixing noise. Default: None.
79
+ return_latents (bool): Whether to return style latents. Default: False.
80
+ """
81
+ # style codes -> latents with Style MLP layer
82
+ if not input_is_latent:
83
+ styles = [self.style_mlp(s) for s in styles]
84
+ # noises
85
+ if noise is None:
86
+ if randomize_noise:
87
+ noise = [None] * self.num_layers # for each style conv layer
88
+ else: # use the stored noise
89
+ noise = [
90
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
91
+ ]
92
+ # style truncation
93
+ if truncation < 1:
94
+ style_truncation = []
95
+ for style in styles:
96
+ style_truncation.append(
97
+ truncation_latent + truncation * (style - truncation_latent)
98
+ )
99
+ styles = style_truncation
100
+ # get style latents with injection
101
+ if len(styles) == 1:
102
+ inject_index = self.num_latent
103
+
104
+ if styles[0].ndim < 3:
105
+ # repeat latent code for all the layers
106
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
107
+ else: # used for encoder with different latent code for each layer
108
+ latent = styles[0]
109
+ elif len(styles) == 2: # mixing noises
110
+ if inject_index is None:
111
+ inject_index = random.randint(1, self.num_latent - 1)
112
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
113
+ latent2 = (
114
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
115
+ )
116
+ latent = torch.cat([latent1, latent2], 1)
117
+
118
+ # main generation
119
+ out = self.constant_input(latent.shape[0])
120
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
121
+ skip = self.to_rgb1(out, latent[:, 1])
122
+
123
+ i = 1
124
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
125
+ self.style_convs[::2],
126
+ self.style_convs[1::2],
127
+ noise[1::2],
128
+ noise[2::2],
129
+ self.to_rgbs,
130
+ ):
131
+ out = conv1(out, latent[:, i], noise=noise1)
132
+
133
+ # the conditions may have fewer levels
134
+ if i < len(conditions):
135
+ # SFT part to combine the conditions
136
+ if self.sft_half: # only apply SFT to half of the channels
137
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
138
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
139
+ out = torch.cat([out_same, out_sft], dim=1)
140
+ else: # apply SFT to all the channels
141
+ out = out * conditions[i - 1] + conditions[i]
142
+
143
+ out = conv2(out, latent[:, i + 1], noise=noise2)
144
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
145
+ i += 2
146
+
147
+ image = skip
148
+
149
+ if return_latents:
150
+ return image, latent
151
+ else:
152
+ return image, None
153
+
154
+
155
+ class ConvUpLayer(nn.Module):
156
+ """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
157
+ Args:
158
+ in_channels (int): Channel number of the input.
159
+ out_channels (int): Channel number of the output.
160
+ kernel_size (int): Size of the convolving kernel.
161
+ stride (int): Stride of the convolution. Default: 1
162
+ padding (int): Zero-padding added to both sides of the input. Default: 0.
163
+ bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
164
+ bias_init_val (float): Bias initialized value. Default: 0.
165
+ activate (bool): Whether to use activation. Default: True.
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ in_channels,
171
+ out_channels,
172
+ kernel_size,
173
+ stride=1,
174
+ padding=0,
175
+ bias=True,
176
+ bias_init_val=0,
177
+ activate=True,
178
+ ):
179
+ super(ConvUpLayer, self).__init__()
180
+ self.in_channels = in_channels
181
+ self.out_channels = out_channels
182
+ self.kernel_size = kernel_size
183
+ self.stride = stride
184
+ self.padding = padding
185
+ # self.scale is used to scale the convolution weights, which is related to the common initializations.
186
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
187
+
188
+ self.weight = nn.Parameter(
189
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
190
+ )
191
+
192
+ if bias and not activate:
193
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
194
+ else:
195
+ self.register_parameter("bias", None)
196
+
197
+ # activation
198
+ if activate:
199
+ if bias:
200
+ self.activation = FusedLeakyReLU(out_channels)
201
+ else:
202
+ self.activation = ScaledLeakyReLU(0.2)
203
+ else:
204
+ self.activation = None
205
+
206
+ def forward(self, x):
207
+ # bilinear upsample
208
+ out = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
209
+ # conv
210
+ out = F.conv2d(
211
+ out,
212
+ self.weight * self.scale,
213
+ bias=self.bias,
214
+ stride=self.stride,
215
+ padding=self.padding,
216
+ )
217
+ # activation
218
+ if self.activation is not None:
219
+ out = self.activation(out)
220
+ return out
221
+
222
+
223
+ class ResUpBlock(nn.Module):
224
+ """Residual block with upsampling.
225
+ Args:
226
+ in_channels (int): Channel number of the input.
227
+ out_channels (int): Channel number of the output.
228
+ """
229
+
230
+ def __init__(self, in_channels, out_channels):
231
+ super(ResUpBlock, self).__init__()
232
+
233
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
234
+ self.conv2 = ConvUpLayer(
235
+ in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True
236
+ )
237
+ self.skip = ConvUpLayer(
238
+ in_channels, out_channels, 1, bias=False, activate=False
239
+ )
240
+
241
+ def forward(self, x):
242
+ out = self.conv1(x)
243
+ out = self.conv2(out)
244
+ skip = self.skip(x)
245
+ out = (out + skip) / math.sqrt(2)
246
+ return out
247
+
248
+
249
+ class GFPGANv1(nn.Module):
250
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
251
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
252
+ Args:
253
+ out_size (int): The spatial size of outputs.
254
+ num_style_feat (int): Channel number of style features. Default: 512.
255
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
256
+ resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. An outer product will be
257
+ applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
258
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
259
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
260
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
261
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
262
+ input_is_latent (bool): Whether input is latent style. Default: False.
263
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
264
+ narrow (float): The narrow ratio for channels. Default: 1.
265
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ out_size,
271
+ num_style_feat=512,
272
+ channel_multiplier=1,
273
+ resample_kernel=(1, 3, 3, 1),
274
+ decoder_load_path=None,
275
+ fix_decoder=True,
276
+ # for stylegan decoder
277
+ num_mlp=8,
278
+ lr_mlp=0.01,
279
+ input_is_latent=False,
280
+ different_w=False,
281
+ narrow=1,
282
+ sft_half=False,
283
+ ):
284
+ super(GFPGANv1, self).__init__()
285
+ self.input_is_latent = input_is_latent
286
+ self.different_w = different_w
287
+ self.num_style_feat = num_style_feat
288
+
289
+ unet_narrow = narrow * 0.5 # by default, use half of the input channels
290
+ channels = {
291
+ "4": int(512 * unet_narrow),
292
+ "8": int(512 * unet_narrow),
293
+ "16": int(512 * unet_narrow),
294
+ "32": int(512 * unet_narrow),
295
+ "64": int(256 * channel_multiplier * unet_narrow),
296
+ "128": int(128 * channel_multiplier * unet_narrow),
297
+ "256": int(64 * channel_multiplier * unet_narrow),
298
+ "512": int(32 * channel_multiplier * unet_narrow),
299
+ "1024": int(16 * channel_multiplier * unet_narrow),
300
+ }
301
+
302
+ self.log_size = int(math.log(out_size, 2))
303
+ first_out_size = 2 ** (int(math.log(out_size, 2)))
304
+
305
+ self.conv_body_first = ConvLayer(
306
+ 3, channels[f"{first_out_size}"], 1, bias=True, activate=True
307
+ )
308
+
309
+ # downsample
310
+ in_channels = channels[f"{first_out_size}"]
311
+ self.conv_body_down = nn.ModuleList()
312
+ for i in range(self.log_size, 2, -1):
313
+ out_channels = channels[f"{2**(i - 1)}"]
314
+ self.conv_body_down.append(
315
+ ResBlock(in_channels, out_channels, resample_kernel)
316
+ )
317
+ in_channels = out_channels
318
+
319
+ self.final_conv = ConvLayer(
320
+ in_channels, channels["4"], 3, bias=True, activate=True
321
+ )
322
+
323
+ # upsample
324
+ in_channels = channels["4"]
325
+ self.conv_body_up = nn.ModuleList()
326
+ for i in range(3, self.log_size + 1):
327
+ out_channels = channels[f"{2**i}"]
328
+ self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
329
+ in_channels = out_channels
330
+
331
+ # to RGB
332
+ self.toRGB = nn.ModuleList()
333
+ for i in range(3, self.log_size + 1):
334
+ self.toRGB.append(
335
+ EqualConv2d(
336
+ channels[f"{2**i}"],
337
+ 3,
338
+ 1,
339
+ stride=1,
340
+ padding=0,
341
+ bias=True,
342
+ bias_init_val=0,
343
+ )
344
+ )
345
+
346
+ if different_w:
347
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
348
+ else:
349
+ linear_out_channel = num_style_feat
350
+
351
+ self.final_linear = EqualLinear(
352
+ channels["4"] * 4 * 4,
353
+ linear_out_channel,
354
+ bias=True,
355
+ bias_init_val=0,
356
+ lr_mul=1,
357
+ activation=None,
358
+ )
359
+
360
+ # the decoder: stylegan2 generator with SFT modulations
361
+ self.stylegan_decoder = StyleGAN2GeneratorSFT(
362
+ out_size=out_size,
363
+ num_style_feat=num_style_feat,
364
+ num_mlp=num_mlp,
365
+ channel_multiplier=channel_multiplier,
366
+ resample_kernel=resample_kernel,
367
+ lr_mlp=lr_mlp,
368
+ narrow=narrow,
369
+ sft_half=sft_half,
370
+ )
371
+
372
+ # load pre-trained stylegan2 model if necessary
373
+ if decoder_load_path:
374
+ self.stylegan_decoder.load_state_dict(
375
+ torch.load(
376
+ decoder_load_path, map_location=lambda storage, loc: storage
377
+ )["params_ema"]
378
+ )
379
+ # fix decoder without updating params
380
+ if fix_decoder:
381
+ for _, param in self.stylegan_decoder.named_parameters():
382
+ param.requires_grad = False
383
+
384
+ # for SFT modulations (scale and shift)
385
+ self.condition_scale = nn.ModuleList()
386
+ self.condition_shift = nn.ModuleList()
387
+ for i in range(3, self.log_size + 1):
388
+ out_channels = channels[f"{2**i}"]
389
+ if sft_half:
390
+ sft_out_channels = out_channels
391
+ else:
392
+ sft_out_channels = out_channels * 2
393
+ self.condition_scale.append(
394
+ nn.Sequential(
395
+ EqualConv2d(
396
+ out_channels,
397
+ out_channels,
398
+ 3,
399
+ stride=1,
400
+ padding=1,
401
+ bias=True,
402
+ bias_init_val=0,
403
+ ),
404
+ ScaledLeakyReLU(0.2),
405
+ EqualConv2d(
406
+ out_channels,
407
+ sft_out_channels,
408
+ 3,
409
+ stride=1,
410
+ padding=1,
411
+ bias=True,
412
+ bias_init_val=1,
413
+ ),
414
+ )
415
+ )
416
+ self.condition_shift.append(
417
+ nn.Sequential(
418
+ EqualConv2d(
419
+ out_channels,
420
+ out_channels,
421
+ 3,
422
+ stride=1,
423
+ padding=1,
424
+ bias=True,
425
+ bias_init_val=0,
426
+ ),
427
+ ScaledLeakyReLU(0.2),
428
+ EqualConv2d(
429
+ out_channels,
430
+ sft_out_channels,
431
+ 3,
432
+ stride=1,
433
+ padding=1,
434
+ bias=True,
435
+ bias_init_val=0,
436
+ ),
437
+ )
438
+ )
439
+
440
+ def forward(
441
+ self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
442
+ ):
443
+ """Forward function for GFPGANv1.
444
+ Args:
445
+ x (Tensor): Input images.
446
+ return_latents (bool): Whether to return style latents. Default: False.
447
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
448
+ randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
449
+ """
450
+ conditions = []
451
+ unet_skips = []
452
+ out_rgbs = []
453
+
454
+ # encoder
455
+ feat = self.conv_body_first(x)
456
+ for i in range(self.log_size - 2):
457
+ feat = self.conv_body_down[i](feat)
458
+ unet_skips.insert(0, feat)
459
+
460
+ feat = self.final_conv(feat)
461
+
462
+ # style code
463
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
464
+ if self.different_w:
465
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
466
+
467
+ # decode
468
+ for i in range(self.log_size - 2):
469
+ # add unet skip
470
+ feat = feat + unet_skips[i]
471
+ # ResUpLayer
472
+ feat = self.conv_body_up[i](feat)
473
+ # generate scale and shift for SFT layers
474
+ scale = self.condition_scale[i](feat)
475
+ conditions.append(scale.clone())
476
+ shift = self.condition_shift[i](feat)
477
+ conditions.append(shift.clone())
478
+ # generate rgb images
479
+ if return_rgb:
480
+ out_rgbs.append(self.toRGB[i](feat))
481
+
482
+ # decoder
483
+ image, _ = self.stylegan_decoder(
484
+ [style_code],
485
+ conditions,
486
+ return_latents=return_latents,
487
+ input_is_latent=self.input_is_latent,
488
+ randomize_noise=randomize_noise,
489
+ )
490
+
491
+ return image, out_rgbs
492
+
493
+
494
+ class FacialComponentDiscriminator(nn.Module):
495
+ """Facial component (eyes, mouth, noise) discriminator used in GFPGAN."""
496
+
497
+ def __init__(self):
498
+ super(FacialComponentDiscriminator, self).__init__()
499
+ # It now uses a VGG-style architecture with fixed model size
500
+ self.conv1 = ConvLayer(
501
+ 3,
502
+ 64,
503
+ 3,
504
+ downsample=False,
505
+ resample_kernel=(1, 3, 3, 1),
506
+ bias=True,
507
+ activate=True,
508
+ )
509
+ self.conv2 = ConvLayer(
510
+ 64,
511
+ 128,
512
+ 3,
513
+ downsample=True,
514
+ resample_kernel=(1, 3, 3, 1),
515
+ bias=True,
516
+ activate=True,
517
+ )
518
+ self.conv3 = ConvLayer(
519
+ 128,
520
+ 128,
521
+ 3,
522
+ downsample=False,
523
+ resample_kernel=(1, 3, 3, 1),
524
+ bias=True,
525
+ activate=True,
526
+ )
527
+ self.conv4 = ConvLayer(
528
+ 128,
529
+ 256,
530
+ 3,
531
+ downsample=True,
532
+ resample_kernel=(1, 3, 3, 1),
533
+ bias=True,
534
+ activate=True,
535
+ )
536
+ self.conv5 = ConvLayer(
537
+ 256,
538
+ 256,
539
+ 3,
540
+ downsample=False,
541
+ resample_kernel=(1, 3, 3, 1),
542
+ bias=True,
543
+ activate=True,
544
+ )
545
+ self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
546
+
547
+ def forward(self, x, return_feats=False, **kwargs):
548
+ """Forward function for FacialComponentDiscriminator.
549
+ Args:
550
+ x (Tensor): Input images.
551
+ return_feats (bool): Whether to return intermediate features. Default: False.
552
+ """
553
+ feat = self.conv1(x)
554
+ feat = self.conv3(self.conv2(feat))
555
+ rlt_feats = []
556
+ if return_feats:
557
+ rlt_feats.append(feat.clone())
558
+ feat = self.conv5(self.conv4(feat))
559
+ if return_feats:
560
+ rlt_feats.append(feat.clone())
561
+ out = self.final_conv(feat)
562
+
563
+ if return_feats:
564
+ return out, rlt_feats
565
+ else:
566
+ return out, None
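A small illustrative sketch (not part of this file) of the SFT step used in StyleGAN2GeneratorSFT.forward above: the scale/shift conditions act as a per-pixel affine modulation of the decoder features, optionally applied to only half of the channels. Tensor shapes below are toy values chosen for illustration.

    import torch

    feat = torch.randn(1, 64, 32, 32)    # decoder feature map at one resolution
    scale = torch.randn(1, 32, 32, 32)   # condition_scale output (sft_half=True -> half the channels)
    shift = torch.randn(1, 32, 32, 32)   # condition_shift output

    # sft_half=True: modulate only the second half of the channels
    same, sft = torch.split(feat, feat.size(1) // 2, dim=1)
    out = torch.cat([same, sft * scale + shift], dim=1)   # same shape as feat

    # sft_half=False would use full-width conditions instead: out = feat * scale + shift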
comfy_extras/chainner_models/architecture/face/gfpganv1_clean_arch.py ADDED
@@ -0,0 +1,370 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
11
+
12
+
13
+ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
14
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
15
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
16
+ Args:
17
+ out_size (int): The spatial size of outputs.
18
+ num_style_feat (int): Channel number of style features. Default: 512.
19
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
20
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
21
+ narrow (float): The narrow ratio for channels. Default: 1.
22
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ out_size,
28
+ num_style_feat=512,
29
+ num_mlp=8,
30
+ channel_multiplier=2,
31
+ narrow=1,
32
+ sft_half=False,
33
+ ):
34
+ super(StyleGAN2GeneratorCSFT, self).__init__(
35
+ out_size,
36
+ num_style_feat=num_style_feat,
37
+ num_mlp=num_mlp,
38
+ channel_multiplier=channel_multiplier,
39
+ narrow=narrow,
40
+ )
41
+ self.sft_half = sft_half
42
+
43
+ def forward(
44
+ self,
45
+ styles,
46
+ conditions,
47
+ input_is_latent=False,
48
+ noise=None,
49
+ randomize_noise=True,
50
+ truncation=1,
51
+ truncation_latent=None,
52
+ inject_index=None,
53
+ return_latents=False,
54
+ ):
55
+ """Forward function for StyleGAN2GeneratorCSFT.
56
+ Args:
57
+ styles (list[Tensor]): Sample codes of styles.
58
+ conditions (list[Tensor]): SFT conditions to generators.
59
+ input_is_latent (bool): Whether input is latent style. Default: False.
60
+ noise (Tensor | None): Input noise or None. Default: None.
61
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
62
+ truncation (float): The truncation ratio. Default: 1.
63
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
64
+ inject_index (int | None): The injection index for mixing noise. Default: None.
65
+ return_latents (bool): Whether to return style latents. Default: False.
66
+ """
67
+ # style codes -> latents with Style MLP layer
68
+ if not input_is_latent:
69
+ styles = [self.style_mlp(s) for s in styles]
70
+ # noises
71
+ if noise is None:
72
+ if randomize_noise:
73
+ noise = [None] * self.num_layers # for each style conv layer
74
+ else: # use the stored noise
75
+ noise = [
76
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
77
+ ]
78
+ # style truncation
79
+ if truncation < 1:
80
+ style_truncation = []
81
+ for style in styles:
82
+ style_truncation.append(
83
+ truncation_latent + truncation * (style - truncation_latent)
84
+ )
85
+ styles = style_truncation
86
+ # get style latents with injection
87
+ if len(styles) == 1:
88
+ inject_index = self.num_latent
89
+
90
+ if styles[0].ndim < 3:
91
+ # repeat latent code for all the layers
92
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
93
+ else: # used for encoder with different latent code for each layer
94
+ latent = styles[0]
95
+ elif len(styles) == 2: # mixing noises
96
+ if inject_index is None:
97
+ inject_index = random.randint(1, self.num_latent - 1)
98
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
99
+ latent2 = (
100
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
101
+ )
102
+ latent = torch.cat([latent1, latent2], 1)
103
+
104
+ # main generation
105
+ out = self.constant_input(latent.shape[0])
106
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
107
+ skip = self.to_rgb1(out, latent[:, 1])
108
+
109
+ i = 1
110
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
111
+ self.style_convs[::2],
112
+ self.style_convs[1::2],
113
+ noise[1::2],
114
+ noise[2::2],
115
+ self.to_rgbs,
116
+ ):
117
+ out = conv1(out, latent[:, i], noise=noise1)
118
+
119
+ # the conditions may have fewer levels
120
+ if i < len(conditions):
121
+ # SFT part to combine the conditions
122
+ if self.sft_half: # only apply SFT to half of the channels
123
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
124
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
125
+ out = torch.cat([out_same, out_sft], dim=1)
126
+ else: # apply SFT to all the channels
127
+ out = out * conditions[i - 1] + conditions[i]
128
+
129
+ out = conv2(out, latent[:, i + 1], noise=noise2)
130
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
131
+ i += 2
132
+
133
+ image = skip
134
+
135
+ if return_latents:
136
+ return image, latent
137
+ else:
138
+ return image, None
139
+
140
+
141
+ class ResBlock(nn.Module):
142
+ """Residual block with bilinear upsampling/downsampling.
143
+ Args:
144
+ in_channels (int): Channel number of the input.
145
+ out_channels (int): Channel number of the output.
146
+ mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
147
+ """
148
+
149
+ def __init__(self, in_channels, out_channels, mode="down"):
150
+ super(ResBlock, self).__init__()
151
+
152
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
153
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
154
+ self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
155
+ if mode == "down":
156
+ self.scale_factor = 0.5
157
+ elif mode == "up":
158
+ self.scale_factor = 2
159
+
160
+ def forward(self, x):
161
+ out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
162
+ # upsample/downsample
163
+ out = F.interpolate(
164
+ out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
165
+ )
166
+ out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
167
+ # skip
168
+ x = F.interpolate(
169
+ x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
170
+ )
171
+ skip = self.skip(x)
172
+ out = out + skip
173
+ return out
174
+
175
+
176
+ class GFPGANv1Clean(nn.Module):
177
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
178
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
179
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
180
+ Args:
181
+ out_size (int): The spatial size of outputs.
182
+ num_style_feat (int): Channel number of style features. Default: 512.
183
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
184
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
185
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
186
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
187
+ input_is_latent (bool): Whether input is latent style. Default: False.
188
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
189
+ narrow (float): The narrow ratio for channels. Default: 1.
190
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ state_dict,
196
+ ):
197
+ super(GFPGANv1Clean, self).__init__()
198
+
199
+ out_size = 512
200
+ num_style_feat = 512
201
+ channel_multiplier = 2
202
+ decoder_load_path = None
203
+ fix_decoder = False
204
+ num_mlp = 8
205
+ input_is_latent = True
206
+ different_w = True
207
+ narrow = 1
208
+ sft_half = True
209
+
210
+ self.model_arch = "GFPGAN"
211
+ self.sub_type = "Face SR"
212
+ self.scale = 8
213
+ self.in_nc = 3
214
+ self.out_nc = 3
215
+ self.state = state_dict
216
+
217
+ self.supports_fp16 = False
218
+ self.supports_bf16 = True
219
+ self.min_size_restriction = 512
220
+
221
+ self.input_is_latent = input_is_latent
222
+ self.different_w = different_w
223
+ self.num_style_feat = num_style_feat
224
+
225
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
226
+ channels = {
227
+ "4": int(512 * unet_narrow),
228
+ "8": int(512 * unet_narrow),
229
+ "16": int(512 * unet_narrow),
230
+ "32": int(512 * unet_narrow),
231
+ "64": int(256 * channel_multiplier * unet_narrow),
232
+ "128": int(128 * channel_multiplier * unet_narrow),
233
+ "256": int(64 * channel_multiplier * unet_narrow),
234
+ "512": int(32 * channel_multiplier * unet_narrow),
235
+ "1024": int(16 * channel_multiplier * unet_narrow),
236
+ }
237
+
238
+ self.log_size = int(math.log(out_size, 2))
239
+ first_out_size = 2 ** (int(math.log(out_size, 2)))
240
+
241
+ self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1)
242
+
243
+ # downsample
244
+ in_channels = channels[f"{first_out_size}"]
245
+ self.conv_body_down = nn.ModuleList()
246
+ for i in range(self.log_size, 2, -1):
247
+ out_channels = channels[f"{2**(i - 1)}"]
248
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down"))
249
+ in_channels = out_channels
250
+
251
+ self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1)
252
+
253
+ # upsample
254
+ in_channels = channels["4"]
255
+ self.conv_body_up = nn.ModuleList()
256
+ for i in range(3, self.log_size + 1):
257
+ out_channels = channels[f"{2**i}"]
258
+ self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up"))
259
+ in_channels = out_channels
260
+
261
+ # to RGB
262
+ self.toRGB = nn.ModuleList()
263
+ for i in range(3, self.log_size + 1):
264
+ self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1))
265
+
266
+ if different_w:
267
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
268
+ else:
269
+ linear_out_channel = num_style_feat
270
+
271
+ self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel)
272
+
273
+ # the decoder: stylegan2 generator with SFT modulations
274
+ self.stylegan_decoder = StyleGAN2GeneratorCSFT(
275
+ out_size=out_size,
276
+ num_style_feat=num_style_feat,
277
+ num_mlp=num_mlp,
278
+ channel_multiplier=channel_multiplier,
279
+ narrow=narrow,
280
+ sft_half=sft_half,
281
+ )
282
+
283
+ # load pre-trained stylegan2 model if necessary
284
+ if decoder_load_path:
285
+ self.stylegan_decoder.load_state_dict(
286
+ torch.load(
287
+ decoder_load_path, map_location=lambda storage, loc: storage
288
+ )["params_ema"]
289
+ )
290
+ # fix decoder without updating params
291
+ if fix_decoder:
292
+ for _, param in self.stylegan_decoder.named_parameters():
293
+ param.requires_grad = False
294
+
295
+ # for SFT modulations (scale and shift)
296
+ self.condition_scale = nn.ModuleList()
297
+ self.condition_shift = nn.ModuleList()
298
+ for i in range(3, self.log_size + 1):
299
+ out_channels = channels[f"{2**i}"]
300
+ if sft_half:
301
+ sft_out_channels = out_channels
302
+ else:
303
+ sft_out_channels = out_channels * 2
304
+ self.condition_scale.append(
305
+ nn.Sequential(
306
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
307
+ nn.LeakyReLU(0.2, True),
308
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
309
+ )
310
+ )
311
+ self.condition_shift.append(
312
+ nn.Sequential(
313
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
314
+ nn.LeakyReLU(0.2, True),
315
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
316
+ )
317
+ )
318
+ self.load_state_dict(state_dict)
319
+
320
+ def forward(
321
+ self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
322
+ ):
323
+ """Forward function for GFPGANv1Clean.
324
+ Args:
325
+ x (Tensor): Input images.
326
+ return_latents (bool): Whether to return style latents. Default: False.
327
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
328
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
329
+ """
330
+ conditions = []
331
+ unet_skips = []
332
+ out_rgbs = []
333
+
334
+ # encoder
335
+ feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
336
+ for i in range(self.log_size - 2):
337
+ feat = self.conv_body_down[i](feat)
338
+ unet_skips.insert(0, feat)
339
+ feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
340
+
341
+ # style code
342
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
343
+ if self.different_w:
344
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
345
+
346
+ # decode
347
+ for i in range(self.log_size - 2):
348
+ # add unet skip
349
+ feat = feat + unet_skips[i]
350
+ # ResUpLayer
351
+ feat = self.conv_body_up[i](feat)
352
+ # generate scale and shift for SFT layers
353
+ scale = self.condition_scale[i](feat)
354
+ conditions.append(scale.clone())
355
+ shift = self.condition_shift[i](feat)
356
+ conditions.append(shift.clone())
357
+ # generate rgb images
358
+ if return_rgb:
359
+ out_rgbs.append(self.toRGB[i](feat))
360
+
361
+ # decoder
362
+ image, _ = self.stylegan_decoder(
363
+ [style_code],
364
+ conditions,
365
+ return_latents=return_latents,
366
+ input_is_latent=self.input_is_latent,
367
+ randomize_noise=randomize_noise,
368
+ )
369
+
370
+ return image, out_rgbs
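For orientation, a quick arithmetic check (illustrative only) of the configuration hard-coded in GFPGANv1Clean.__init__ above (out_size=512, narrow=1, channel_multiplier=2, num_style_feat=512, different_w=True); this simply restates the constructor's math.

    out_size, narrow, channel_multiplier, num_style_feat = 512, 1, 2, 512
    unet_narrow = narrow * 0.5                                  # 0.5
    channels_4 = int(512 * unet_narrow)                         # 256 features at 4x4
    channels_64 = int(256 * channel_multiplier * unet_narrow)   # 256 features at 64x64
    channels_512 = int(32 * channel_multiplier * unet_narrow)   # 32 features at 512x512
    log_size = out_size.bit_length() - 1                        # 9, same as int(math.log(512, 2))
    num_blocks = log_size - 2                                   # 7 down blocks, 7 up blocks, 7 toRGB convs
    linear_out_channel = (log_size * 2 - 2) * num_style_feat    # 16 * 512 = 8192 (different_w=True)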
comfy_extras/chainner_models/architecture/face/restoreformer_arch.py ADDED
@@ -0,0 +1,776 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ """Modified from https://github.com/wzhouxiff/RestoreFormer
4
+ """
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class VectorQuantizer(nn.Module):
12
+ """
13
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
14
+ ____________________________________________
15
+ Discretization bottleneck part of the VQ-VAE.
16
+ Inputs:
17
+ - n_e : number of embeddings
18
+ - e_dim : dimension of embedding
19
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
20
+ _____________________________________________
21
+ """
22
+
23
+ def __init__(self, n_e, e_dim, beta):
24
+ super(VectorQuantizer, self).__init__()
25
+ self.n_e = n_e
26
+ self.e_dim = e_dim
27
+ self.beta = beta
28
+
29
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
30
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
31
+
32
+ def forward(self, z):
33
+ """
34
+ Inputs the output of the encoder network z and maps it to a discrete
35
+ one-hot vector that is the index of the closest embedding vector e_j
36
+ z (continuous) -> z_q (discrete)
37
+ z.shape = (batch, channel, height, width)
38
+ quantization pipeline:
39
+ 1. get encoder input (B,C,H,W)
40
+ 2. flatten input to (B*H*W,C)
41
+ """
42
+ # reshape z -> (batch, height, width, channel) and flatten
43
+ z = z.permute(0, 2, 3, 1).contiguous()
44
+ z_flattened = z.view(-1, self.e_dim)
45
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
46
+
47
+ d = (
48
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
49
+ + torch.sum(self.embedding.weight**2, dim=1)
50
+ - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
51
+ )
52
+
53
+ # could possibly replace this here
54
+ # #\start...
55
+ # find closest encodings
56
+
57
+ min_value, min_encoding_indices = torch.min(d, dim=1)
58
+
59
+ min_encoding_indices = min_encoding_indices.unsqueeze(1)
60
+
61
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
62
+ min_encodings.scatter_(1, min_encoding_indices, 1)
63
+
64
+ # dtype min encodings: torch.float32
65
+ # min_encodings shape: torch.Size([2048, 512])
66
+ # min_encoding_indices.shape: torch.Size([2048, 1])
67
+
68
+ # get quantized latent vectors
69
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
70
+ # .........\end
71
+
72
+ # with:
73
+ # .........\start
74
+ # min_encoding_indices = torch.argmin(d, dim=1)
75
+ # z_q = self.embedding(min_encoding_indices)
76
+ # ......\end......... (TODO)
77
+
78
+ # compute loss for embedding
79
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
80
+ (z_q - z.detach()) ** 2
81
+ )
82
+
83
+ # preserve gradients
84
+ z_q = z + (z_q - z).detach()
85
+
86
+ # perplexity
87
+
88
+ e_mean = torch.mean(min_encodings, dim=0)
89
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
90
+
91
+ # reshape back to match original input shape
92
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
93
+
94
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
95
+
96
+ def get_codebook_entry(self, indices, shape):
97
+ # shape specifying (batch, height, width, channel)
98
+ # TODO: check for more easy handling with nn.Embedding
99
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
100
+ min_encodings.scatter_(1, indices[:, None], 1)
101
+
102
+ # get quantized latent vectors
103
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
104
+
105
+ if shape is not None:
106
+ z_q = z_q.view(shape)
107
+
108
+ # reshape back to match original input shape
109
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
110
+
111
+ return z_q
112
+
113
+
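An illustrative check (not from the original file) of the codebook distance used in VectorQuantizer.forward: the expansion ||z||^2 + ||e||^2 - 2 z.e is the squared Euclidean distance to each embedding, so the argmin picks the nearest code.

    import torch

    n_e, e_dim = 16, 4
    codebook = torch.randn(n_e, e_dim)        # plays the role of embedding.weight
    z_flat = torch.randn(8, e_dim)            # flattened encoder output (B*H*W, C)

    d = (z_flat.pow(2).sum(dim=1, keepdim=True)
         + codebook.pow(2).sum(dim=1)
         - 2 * z_flat @ codebook.t())         # (8, 16) squared distances

    print(torch.allclose(d, torch.cdist(z_flat, codebook) ** 2, atol=1e-4))  # True
    min_idx = d.argmin(dim=1)                 # nearest codebook entry per vector
    z_q = codebook[min_idx]                   # equivalent to the one-hot matmul above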
114
+ # pytorch_diffusion + derived encoder decoder
115
+ def nonlinearity(x):
116
+ # swish
117
+ return x * torch.sigmoid(x)
118
+
119
+
120
+ def Normalize(in_channels):
121
+ return torch.nn.GroupNorm(
122
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
+ )
124
+
125
+
126
+ class Upsample(nn.Module):
127
+ def __init__(self, in_channels, with_conv):
128
+ super().__init__()
129
+ self.with_conv = with_conv
130
+ if self.with_conv:
131
+ self.conv = torch.nn.Conv2d(
132
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
133
+ )
134
+
135
+ def forward(self, x):
136
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
137
+ if self.with_conv:
138
+ x = self.conv(x)
139
+ return x
140
+
141
+
142
+ class Downsample(nn.Module):
143
+ def __init__(self, in_channels, with_conv):
144
+ super().__init__()
145
+ self.with_conv = with_conv
146
+ if self.with_conv:
147
+ # no asymmetric padding in torch conv, must do it ourselves
148
+ self.conv = torch.nn.Conv2d(
149
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
150
+ )
151
+
152
+ def forward(self, x):
153
+ if self.with_conv:
154
+ pad = (0, 1, 0, 1)
155
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
156
+ x = self.conv(x)
157
+ else:
158
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
159
+ return x
160
+
161
+
162
+ class ResnetBlock(nn.Module):
163
+ def __init__(
164
+ self,
165
+ *,
166
+ in_channels,
167
+ out_channels=None,
168
+ conv_shortcut=False,
169
+ dropout,
170
+ temb_channels=512
171
+ ):
172
+ super().__init__()
173
+ self.in_channels = in_channels
174
+ out_channels = in_channels if out_channels is None else out_channels
175
+ self.out_channels = out_channels
176
+ self.use_conv_shortcut = conv_shortcut
177
+
178
+ self.norm1 = Normalize(in_channels)
179
+ self.conv1 = torch.nn.Conv2d(
180
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
181
+ )
182
+ if temb_channels > 0:
183
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
184
+ self.norm2 = Normalize(out_channels)
185
+ self.dropout = torch.nn.Dropout(dropout)
186
+ self.conv2 = torch.nn.Conv2d(
187
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
188
+ )
189
+ if self.in_channels != self.out_channels:
190
+ if self.use_conv_shortcut:
191
+ self.conv_shortcut = torch.nn.Conv2d(
192
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
193
+ )
194
+ else:
195
+ self.nin_shortcut = torch.nn.Conv2d(
196
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
197
+ )
198
+
199
+ def forward(self, x, temb):
200
+ h = x
201
+ h = self.norm1(h)
202
+ h = nonlinearity(h)
203
+ h = self.conv1(h)
204
+
205
+ if temb is not None:
206
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
207
+
208
+ h = self.norm2(h)
209
+ h = nonlinearity(h)
210
+ h = self.dropout(h)
211
+ h = self.conv2(h)
212
+
213
+ if self.in_channels != self.out_channels:
214
+ if self.use_conv_shortcut:
215
+ x = self.conv_shortcut(x)
216
+ else:
217
+ x = self.nin_shortcut(x)
218
+
219
+ return x + h
220
+
221
+
222
+ class MultiHeadAttnBlock(nn.Module):
223
+ def __init__(self, in_channels, head_size=1):
224
+ super().__init__()
225
+ self.in_channels = in_channels
226
+ self.head_size = head_size
227
+ self.att_size = in_channels // head_size
228
+ assert (
229
+ in_channels % head_size == 0
230
+ ), "The size of head should be divided by the number of channels."
231
+
232
+ self.norm1 = Normalize(in_channels)
233
+ self.norm2 = Normalize(in_channels)
234
+
235
+ self.q = torch.nn.Conv2d(
236
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
237
+ )
238
+ self.k = torch.nn.Conv2d(
239
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
240
+ )
241
+ self.v = torch.nn.Conv2d(
242
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
243
+ )
244
+ self.proj_out = torch.nn.Conv2d(
245
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
246
+ )
247
+ self.num = 0
248
+
249
+ def forward(self, x, y=None):
250
+ h_ = x
251
+ h_ = self.norm1(h_)
252
+ if y is None:
253
+ y = h_
254
+ else:
255
+ y = self.norm2(y)
256
+
257
+ q = self.q(y)
258
+ k = self.k(h_)
259
+ v = self.v(h_)
260
+
261
+ # compute attention
262
+ b, c, h, w = q.shape
263
+ q = q.reshape(b, self.head_size, self.att_size, h * w)
264
+ q = q.permute(0, 3, 1, 2) # b, hw, head, att
265
+
266
+ k = k.reshape(b, self.head_size, self.att_size, h * w)
267
+ k = k.permute(0, 3, 1, 2)
268
+
269
+ v = v.reshape(b, self.head_size, self.att_size, h * w)
270
+ v = v.permute(0, 3, 1, 2)
271
+
272
+ q = q.transpose(1, 2)
273
+ v = v.transpose(1, 2)
274
+ k = k.transpose(1, 2).transpose(2, 3)
275
+
276
+ scale = int(self.att_size) ** (-0.5)
277
+ q.mul_(scale)
278
+ w_ = torch.matmul(q, k)
279
+ w_ = F.softmax(w_, dim=3)
280
+
281
+ w_ = w_.matmul(v)
282
+
283
+ w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
284
+ w_ = w_.view(b, h, w, -1)
285
+ w_ = w_.permute(0, 3, 1, 2)
286
+
287
+ w_ = self.proj_out(w_)
288
+
289
+ return x + w_
290
+
291
+
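A toy shape check (illustrative only) of the head split performed in MultiHeadAttnBlock.forward: each 1x1-projected map is reshaped from (b, c, h, w) into (b, head, h*w, att) before the attention matmul.

    import torch

    b, c, h, w, head_size = 2, 64, 8, 8, 8
    att_size = c // head_size                                          # 8 channels per head
    q = torch.randn(b, c, h, w)
    q = q.reshape(b, head_size, att_size, h * w).permute(0, 3, 1, 2)   # (b, h*w, head, att)
    q = q.transpose(1, 2)                                              # (b, head, h*w, att)
    assert q.shape == (b, head_size, h * w, att_size)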
292
+ class MultiHeadEncoder(nn.Module):
293
+ def __init__(
294
+ self,
295
+ ch,
296
+ out_ch,
297
+ ch_mult=(1, 2, 4, 8),
298
+ num_res_blocks=2,
299
+ attn_resolutions=(16,),
300
+ dropout=0.0,
301
+ resamp_with_conv=True,
302
+ in_channels=3,
303
+ resolution=512,
304
+ z_channels=256,
305
+ double_z=True,
306
+ enable_mid=True,
307
+ head_size=1,
308
+ **ignore_kwargs
309
+ ):
310
+ super().__init__()
311
+ self.ch = ch
312
+ self.temb_ch = 0
313
+ self.num_resolutions = len(ch_mult)
314
+ self.num_res_blocks = num_res_blocks
315
+ self.resolution = resolution
316
+ self.in_channels = in_channels
317
+ self.enable_mid = enable_mid
318
+
319
+ # downsampling
320
+ self.conv_in = torch.nn.Conv2d(
321
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
322
+ )
323
+
324
+ curr_res = resolution
325
+ in_ch_mult = (1,) + tuple(ch_mult)
326
+ self.down = nn.ModuleList()
327
+ for i_level in range(self.num_resolutions):
328
+ block = nn.ModuleList()
329
+ attn = nn.ModuleList()
330
+ block_in = ch * in_ch_mult[i_level]
331
+ block_out = ch * ch_mult[i_level]
332
+ for i_block in range(self.num_res_blocks):
333
+ block.append(
334
+ ResnetBlock(
335
+ in_channels=block_in,
336
+ out_channels=block_out,
337
+ temb_channels=self.temb_ch,
338
+ dropout=dropout,
339
+ )
340
+ )
341
+ block_in = block_out
342
+ if curr_res in attn_resolutions:
343
+ attn.append(MultiHeadAttnBlock(block_in, head_size))
344
+ down = nn.Module()
345
+ down.block = block
346
+ down.attn = attn
347
+ if i_level != self.num_resolutions - 1:
348
+ down.downsample = Downsample(block_in, resamp_with_conv)
349
+ curr_res = curr_res // 2
350
+ self.down.append(down)
351
+
352
+ # middle
353
+ if self.enable_mid:
354
+ self.mid = nn.Module()
355
+ self.mid.block_1 = ResnetBlock(
356
+ in_channels=block_in,
357
+ out_channels=block_in,
358
+ temb_channels=self.temb_ch,
359
+ dropout=dropout,
360
+ )
361
+ self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
362
+ self.mid.block_2 = ResnetBlock(
363
+ in_channels=block_in,
364
+ out_channels=block_in,
365
+ temb_channels=self.temb_ch,
366
+ dropout=dropout,
367
+ )
368
+
369
+ # end
370
+ self.norm_out = Normalize(block_in)
371
+ self.conv_out = torch.nn.Conv2d(
372
+ block_in,
373
+ 2 * z_channels if double_z else z_channels,
374
+ kernel_size=3,
375
+ stride=1,
376
+ padding=1,
377
+ )
378
+
379
+ def forward(self, x):
380
+ hs = {}
381
+ # timestep embedding
382
+ temb = None
383
+
384
+ # downsampling
385
+ h = self.conv_in(x)
386
+ hs["in"] = h
387
+ for i_level in range(self.num_resolutions):
388
+ for i_block in range(self.num_res_blocks):
389
+ h = self.down[i_level].block[i_block](h, temb)
390
+ if len(self.down[i_level].attn) > 0:
391
+ h = self.down[i_level].attn[i_block](h)
392
+
393
+ if i_level != self.num_resolutions - 1:
394
+ # hs.append(h)
395
+ hs["block_" + str(i_level)] = h
396
+ h = self.down[i_level].downsample(h)
397
+
398
+ # middle
399
+ # h = hs[-1]
400
+ if self.enable_mid:
401
+ h = self.mid.block_1(h, temb)
402
+ hs["block_" + str(i_level) + "_atten"] = h
403
+ h = self.mid.attn_1(h)
404
+ h = self.mid.block_2(h, temb)
405
+ hs["mid_atten"] = h
406
+
407
+ # end
408
+ h = self.norm_out(h)
409
+ h = nonlinearity(h)
410
+ h = self.conv_out(h)
411
+ # hs.append(h)
412
+ hs["out"] = h
413
+
414
+ return hs
415
+
416
+
417
+ class MultiHeadDecoder(nn.Module):
418
+ def __init__(
419
+ self,
420
+ ch,
421
+ out_ch,
422
+ ch_mult=(1, 2, 4, 8),
423
+ num_res_blocks=2,
424
+ attn_resolutions=(16,),
425
+ dropout=0.0,
426
+ resamp_with_conv=True,
427
+ in_channels=3,
428
+ resolution=512,
429
+ z_channels=256,
430
+ give_pre_end=False,
431
+ enable_mid=True,
432
+ head_size=1,
433
+ **ignorekwargs
434
+ ):
435
+ super().__init__()
436
+ self.ch = ch
437
+ self.temb_ch = 0
438
+ self.num_resolutions = len(ch_mult)
439
+ self.num_res_blocks = num_res_blocks
440
+ self.resolution = resolution
441
+ self.in_channels = in_channels
442
+ self.give_pre_end = give_pre_end
443
+ self.enable_mid = enable_mid
444
+
445
+ # compute in_ch_mult, block_in and curr_res at lowest res
446
+ block_in = ch * ch_mult[self.num_resolutions - 1]
447
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
448
+ self.z_shape = (1, z_channels, curr_res, curr_res)
449
+ print(
450
+ "Working with z of shape {} = {} dimensions.".format(
451
+ self.z_shape, np.prod(self.z_shape)
452
+ )
453
+ )
454
+
455
+ # z to block_in
456
+ self.conv_in = torch.nn.Conv2d(
457
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
458
+ )
459
+
460
+ # middle
461
+ if self.enable_mid:
462
+ self.mid = nn.Module()
463
+ self.mid.block_1 = ResnetBlock(
464
+ in_channels=block_in,
465
+ out_channels=block_in,
466
+ temb_channels=self.temb_ch,
467
+ dropout=dropout,
468
+ )
469
+ self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
470
+ self.mid.block_2 = ResnetBlock(
471
+ in_channels=block_in,
472
+ out_channels=block_in,
473
+ temb_channels=self.temb_ch,
474
+ dropout=dropout,
475
+ )
476
+
477
+ # upsampling
478
+ self.up = nn.ModuleList()
479
+ for i_level in reversed(range(self.num_resolutions)):
480
+ block = nn.ModuleList()
481
+ attn = nn.ModuleList()
482
+ block_out = ch * ch_mult[i_level]
483
+ for i_block in range(self.num_res_blocks + 1):
484
+ block.append(
485
+ ResnetBlock(
486
+ in_channels=block_in,
487
+ out_channels=block_out,
488
+ temb_channels=self.temb_ch,
489
+ dropout=dropout,
490
+ )
491
+ )
492
+ block_in = block_out
493
+ if curr_res in attn_resolutions:
494
+ attn.append(MultiHeadAttnBlock(block_in, head_size))
495
+ up = nn.Module()
496
+ up.block = block
497
+ up.attn = attn
498
+ if i_level != 0:
499
+ up.upsample = Upsample(block_in, resamp_with_conv)
500
+ curr_res = curr_res * 2
501
+ self.up.insert(0, up) # prepend to get consistent order
502
+
503
+ # end
504
+ self.norm_out = Normalize(block_in)
505
+ self.conv_out = torch.nn.Conv2d(
506
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
507
+ )
508
+
509
+ def forward(self, z):
510
+ # assert z.shape[1:] == self.z_shape[1:]
511
+ self.last_z_shape = z.shape
512
+
513
+ # timestep embedding
514
+ temb = None
515
+
516
+ # z to block_in
517
+ h = self.conv_in(z)
518
+
519
+ # middle
520
+ if self.enable_mid:
521
+ h = self.mid.block_1(h, temb)
522
+ h = self.mid.attn_1(h)
523
+ h = self.mid.block_2(h, temb)
524
+
525
+ # upsampling
526
+ for i_level in reversed(range(self.num_resolutions)):
527
+ for i_block in range(self.num_res_blocks + 1):
528
+ h = self.up[i_level].block[i_block](h, temb)
529
+ if len(self.up[i_level].attn) > 0:
530
+ h = self.up[i_level].attn[i_block](h)
531
+ if i_level != 0:
532
+ h = self.up[i_level].upsample(h)
533
+
534
+ # end
535
+ if self.give_pre_end:
536
+ return h
537
+
538
+ h = self.norm_out(h)
539
+ h = nonlinearity(h)
540
+ h = self.conv_out(h)
541
+ return h
542
+
543
+
544
+ class MultiHeadDecoderTransformer(nn.Module):
545
+ def __init__(
546
+ self,
547
+ ch,
548
+ out_ch,
549
+ ch_mult=(1, 2, 4, 8),
550
+ num_res_blocks=2,
551
+ attn_resolutions=(16,),
552
+ dropout=0.0,
553
+ resamp_with_conv=True,
554
+ in_channels=3,
555
+ resolution=512,
556
+ z_channels=256,
557
+ give_pre_end=False,
558
+ enable_mid=True,
559
+ head_size=1,
560
+ **ignorekwargs
561
+ ):
562
+ super().__init__()
563
+ self.ch = ch
564
+ self.temb_ch = 0
565
+ self.num_resolutions = len(ch_mult)
566
+ self.num_res_blocks = num_res_blocks
567
+ self.resolution = resolution
568
+ self.in_channels = in_channels
569
+ self.give_pre_end = give_pre_end
570
+ self.enable_mid = enable_mid
571
+
572
+ # compute in_ch_mult, block_in and curr_res at lowest res
573
+ block_in = ch * ch_mult[self.num_resolutions - 1]
574
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
575
+ self.z_shape = (1, z_channels, curr_res, curr_res)
576
+ print(
577
+ "Working with z of shape {} = {} dimensions.".format(
578
+ self.z_shape, np.prod(self.z_shape)
579
+ )
580
+ )
581
+
582
+ # z to block_in
583
+ self.conv_in = torch.nn.Conv2d(
584
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
585
+ )
586
+
587
+ # middle
588
+ if self.enable_mid:
589
+ self.mid = nn.Module()
590
+ self.mid.block_1 = ResnetBlock(
591
+ in_channels=block_in,
592
+ out_channels=block_in,
593
+ temb_channels=self.temb_ch,
594
+ dropout=dropout,
595
+ )
596
+ self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
597
+ self.mid.block_2 = ResnetBlock(
598
+ in_channels=block_in,
599
+ out_channels=block_in,
600
+ temb_channels=self.temb_ch,
601
+ dropout=dropout,
602
+ )
603
+
604
+ # upsampling
605
+ self.up = nn.ModuleList()
606
+ for i_level in reversed(range(self.num_resolutions)):
607
+ block = nn.ModuleList()
608
+ attn = nn.ModuleList()
609
+ block_out = ch * ch_mult[i_level]
610
+ for i_block in range(self.num_res_blocks + 1):
611
+ block.append(
612
+ ResnetBlock(
613
+ in_channels=block_in,
614
+ out_channels=block_out,
615
+ temb_channels=self.temb_ch,
616
+ dropout=dropout,
617
+ )
618
+ )
619
+ block_in = block_out
620
+ if curr_res in attn_resolutions:
621
+ attn.append(MultiHeadAttnBlock(block_in, head_size))
622
+ up = nn.Module()
623
+ up.block = block
624
+ up.attn = attn
625
+ if i_level != 0:
626
+ up.upsample = Upsample(block_in, resamp_with_conv)
627
+ curr_res = curr_res * 2
628
+ self.up.insert(0, up) # prepend to get consistent order
629
+
630
+ # end
631
+ self.norm_out = Normalize(block_in)
632
+ self.conv_out = torch.nn.Conv2d(
633
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
634
+ )
635
+
636
+ def forward(self, z, hs):
637
+ # assert z.shape[1:] == self.z_shape[1:]
638
+ # self.last_z_shape = z.shape
639
+
640
+ # timestep embedding
641
+ temb = None
642
+
643
+ # z to block_in
644
+ h = self.conv_in(z)
645
+
646
+ # middle
647
+ if self.enable_mid:
648
+ h = self.mid.block_1(h, temb)
649
+ h = self.mid.attn_1(h, hs["mid_atten"])
650
+ h = self.mid.block_2(h, temb)
651
+
652
+ # upsampling
653
+ for i_level in reversed(range(self.num_resolutions)):
654
+ for i_block in range(self.num_res_blocks + 1):
655
+ h = self.up[i_level].block[i_block](h, temb)
656
+ if len(self.up[i_level].attn) > 0:
657
+ h = self.up[i_level].attn[i_block](
658
+ h, hs["block_" + str(i_level) + "_atten"]
659
+ )
660
+ # hfeature = h.clone()
661
+ if i_level != 0:
662
+ h = self.up[i_level].upsample(h)
663
+
664
+ # end
665
+ if self.give_pre_end:
666
+ return h
667
+
668
+ h = self.norm_out(h)
669
+ h = nonlinearity(h)
670
+ h = self.conv_out(h)
671
+ return h
672
+
673
+
674
+ class RestoreFormer(nn.Module):
675
+ def __init__(
676
+ self,
677
+ state_dict,
678
+ ):
679
+ super(RestoreFormer, self).__init__()
680
+
681
+ n_embed = 1024
682
+ embed_dim = 256
683
+ ch = 64
684
+ out_ch = 3
685
+ ch_mult = (1, 2, 2, 4, 4, 8)
686
+ num_res_blocks = 2
687
+ attn_resolutions = (16,)
688
+ dropout = 0.0
689
+ in_channels = 3
690
+ resolution = 512
691
+ z_channels = 256
692
+ double_z = False
693
+ enable_mid = True
694
+ fix_decoder = False
695
+ fix_codebook = True
696
+ fix_encoder = False
697
+ head_size = 8
698
+
699
+ self.model_arch = "RestoreFormer"
700
+ self.sub_type = "Face SR"
701
+ self.scale = 8
702
+ self.in_nc = 3
703
+ self.out_nc = out_ch
704
+ self.state = state_dict
705
+
706
+ self.supports_fp16 = False
707
+ self.supports_bf16 = True
708
+ self.min_size_restriction = 16
709
+
710
+ self.encoder = MultiHeadEncoder(
711
+ ch=ch,
712
+ out_ch=out_ch,
713
+ ch_mult=ch_mult,
714
+ num_res_blocks=num_res_blocks,
715
+ attn_resolutions=attn_resolutions,
716
+ dropout=dropout,
717
+ in_channels=in_channels,
718
+ resolution=resolution,
719
+ z_channels=z_channels,
720
+ double_z=double_z,
721
+ enable_mid=enable_mid,
722
+ head_size=head_size,
723
+ )
724
+ self.decoder = MultiHeadDecoderTransformer(
725
+ ch=ch,
726
+ out_ch=out_ch,
727
+ ch_mult=ch_mult,
728
+ num_res_blocks=num_res_blocks,
729
+ attn_resolutions=attn_resolutions,
730
+ dropout=dropout,
731
+ in_channels=in_channels,
732
+ resolution=resolution,
733
+ z_channels=z_channels,
734
+ enable_mid=enable_mid,
735
+ head_size=head_size,
736
+ )
737
+
738
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
739
+
740
+ self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
741
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
742
+
743
+ if fix_decoder:
744
+ for _, param in self.decoder.named_parameters():
745
+ param.requires_grad = False
746
+ for _, param in self.post_quant_conv.named_parameters():
747
+ param.requires_grad = False
748
+ for _, param in self.quantize.named_parameters():
749
+ param.requires_grad = False
750
+ elif fix_codebook:
751
+ for _, param in self.quantize.named_parameters():
752
+ param.requires_grad = False
753
+
754
+ if fix_encoder:
755
+ for _, param in self.encoder.named_parameters():
756
+ param.requires_grad = False
757
+
758
+ self.load_state_dict(state_dict)
759
+
760
+ def encode(self, x):
761
+ hs = self.encoder(x)
762
+ h = self.quant_conv(hs["out"])
763
+ quant, emb_loss, info = self.quantize(h)
764
+ return quant, emb_loss, info, hs
765
+
766
+ def decode(self, quant, hs):
767
+ quant = self.post_quant_conv(quant)
768
+ dec = self.decoder(quant, hs)
769
+
770
+ return dec
771
+
772
+ def forward(self, input, **kwargs):
773
+ quant, diff, info, hs = self.encode(input)
774
+ dec = self.decode(quant, hs)
775
+
776
+ return dec, None
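A hedged usage sketch (not part of the commit): RestoreFormer.forward runs encode -> vector-quantize -> cross-attended decode on an aligned 512x512 face crop. The checkpoint path and preprocessing below are placeholders; actual loading in this repo is expected to go through model_loading.py.

    import torch
    from comfy_extras.chainner_models.architecture.face.restoreformer_arch import RestoreFormer

    state_dict = torch.load("RestoreFormer.pth", map_location="cpu")  # placeholder path
    model = RestoreFormer(state_dict).eval()

    face = torch.rand(1, 3, 512, 512)      # aligned face crop; normalization is assumed elsewhere
    with torch.no_grad():
        restored, _ = model(face)          # forward returns (dec, None)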
comfy_extras/chainner_models/architecture/face/stylegan2_arch.py ADDED
@@ -0,0 +1,865 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
11
+ from .upfirdn2d import upfirdn2d
12
+
13
+
14
+ class NormStyleCode(nn.Module):
15
+ def forward(self, x):
16
+ """Normalize the style codes.
17
+
18
+ Args:
19
+ x (Tensor): Style codes with shape (b, c).
20
+
21
+ Returns:
22
+ Tensor: Normalized tensor.
23
+ """
24
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
25
+
26
+
27
+ def make_resample_kernel(k):
28
+ """Make resampling kernel for UpFirDn.
29
+
30
+ Args:
31
+ k (list[int]): A list indicating the 1D resample kernel magnitude.
32
+
33
+ Returns:
34
+ Tensor: 2D resampled kernel.
35
+ """
36
+ k = torch.tensor(k, dtype=torch.float32)
37
+ if k.ndim == 1:
38
+ k = k[None, :] * k[:, None] # to 2D kernel, outer product
39
+ # normalize
40
+ k /= k.sum()
41
+ return k
42
+
43
+
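A worked example (illustrative) of make_resample_kernel for the default (1, 3, 3, 1) kernel: the outer product gives a 4x4 binomial smoothing filter, which is then normalised to unit sum.

    import torch

    k = torch.tensor([1.0, 3.0, 3.0, 1.0])
    k2d = k[None, :] * k[:, None]   # 4x4 outer product, rows are 1*k, 3*k, 3*k, 1*k
    k2d = k2d / k2d.sum()           # total weight 64 -> normalised to sum to 1
    print(k2d.sum())                # tensor(1.)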
44
+ class UpFirDnUpsample(nn.Module):
45
+ """Upsample, FIR filter, and downsample (upsampole version).
46
+
47
+ References:
48
+ 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
49
+ 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
50
+
51
+ Args:
52
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
53
+ magnitude.
54
+ factor (int): Upsampling scale factor. Default: 2.
55
+ """
56
+
57
+ def __init__(self, resample_kernel, factor=2):
58
+ super(UpFirDnUpsample, self).__init__()
59
+ self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
60
+ self.factor = factor
61
+
62
+ pad = self.kernel.shape[0] - factor
63
+ self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
64
+
65
+ def forward(self, x):
66
+ out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
67
+ return out
68
+
69
+ def __repr__(self):
70
+ return f"{self.__class__.__name__}(factor={self.factor})"
71
+
72
+
73
+ class UpFirDnDownsample(nn.Module):
74
+ """Upsample, FIR filter, and downsample (downsampole version).
75
+
76
+ Args:
77
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
78
+ magnitude.
79
+ factor (int): Downsampling scale factor. Default: 2.
80
+ """
81
+
82
+ def __init__(self, resample_kernel, factor=2):
83
+ super(UpFirDnDownsample, self).__init__()
84
+ self.kernel = make_resample_kernel(resample_kernel)
85
+ self.factor = factor
86
+
87
+ pad = self.kernel.shape[0] - factor
88
+ self.pad = ((pad + 1) // 2, pad // 2)
89
+
90
+ def forward(self, x):
91
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
92
+ return out
93
+
94
+ def __repr__(self):
95
+ return f"{self.__class__.__name__}(factor={self.factor})"
96
+
97
+
98
+ class UpFirDnSmooth(nn.Module):
99
+ """Upsample, FIR filter, and downsample (smooth version).
100
+
101
+ Args:
102
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
103
+ magnitude.
104
+ upsample_factor (int): Upsampling scale factor. Default: 1.
105
+ downsample_factor (int): Downsampling scale factor. Default: 1.
106
+ kernel_size (int): Kernel size. Default: 1.
107
+ """
108
+
109
+ def __init__(
110
+ self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1
111
+ ):
112
+ super(UpFirDnSmooth, self).__init__()
113
+ self.upsample_factor = upsample_factor
114
+ self.downsample_factor = downsample_factor
115
+ self.kernel = make_resample_kernel(resample_kernel)
116
+ if upsample_factor > 1:
117
+ self.kernel = self.kernel * (upsample_factor**2)
118
+
119
+ if upsample_factor > 1:
120
+ pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
121
+ self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
122
+ elif downsample_factor > 1:
123
+ pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
124
+ self.pad = ((pad + 1) // 2, pad // 2)
125
+ else:
126
+ raise NotImplementedError
127
+
128
+ def forward(self, x):
129
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
130
+ return out
131
+
132
+ def __repr__(self):
133
+ return (
134
+ f"{self.__class__.__name__}(upsample_factor={self.upsample_factor}"
135
+ f", downsample_factor={self.downsample_factor})"
136
+ )
137
+
138
+
139
+ class EqualLinear(nn.Module):
140
+ """Equalized Linear as StyleGAN2.
141
+
142
+ Args:
143
+ in_channels (int): Size of each sample.
144
+ out_channels (int): Size of each output sample.
145
+ bias (bool): If set to ``False``, the layer will not learn an additive
146
+ bias. Default: ``True``.
147
+ bias_init_val (float): Bias initialized value. Default: 0.
148
+ lr_mul (float): Learning rate multiplier. Default: 1.
149
+ activation (None | str): The activation after ``linear`` operation.
150
+ Supported: 'fused_lrelu', None. Default: None.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ in_channels,
156
+ out_channels,
157
+ bias=True,
158
+ bias_init_val=0,
159
+ lr_mul=1,
160
+ activation=None,
161
+ ):
162
+ super(EqualLinear, self).__init__()
163
+ self.in_channels = in_channels
164
+ self.out_channels = out_channels
165
+ self.lr_mul = lr_mul
166
+ self.activation = activation
167
+ if self.activation not in ["fused_lrelu", None]:
168
+ raise ValueError(
169
+ f"Wrong activation value in EqualLinear: {activation}"
170
+ "Supported ones are: ['fused_lrelu', None]."
171
+ )
172
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
173
+
174
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
175
+ if bias:
176
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
177
+ else:
178
+ self.register_parameter("bias", None)
179
+
180
+ def forward(self, x):
181
+ if self.bias is None:
182
+ bias = None
183
+ else:
184
+ bias = self.bias * self.lr_mul
185
+ if self.activation == "fused_lrelu":
186
+ out = F.linear(x, self.weight * self.scale)
187
+ out = fused_leaky_relu(out, bias)
188
+ else:
189
+ out = F.linear(x, self.weight * self.scale, bias=bias)
190
+ return out
191
+
192
+ def __repr__(self):
193
+ return (
194
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
195
+ f"out_channels={self.out_channels}, bias={self.bias is not None})"
196
+ )
197
+
198
+
199
+ class ModulatedConv2d(nn.Module):
200
+ """Modulated Conv2d used in StyleGAN2.
201
+
202
+ There is no bias in ModulatedConv2d.
203
+
204
+ Args:
205
+ in_channels (int): Channel number of the input.
206
+ out_channels (int): Channel number of the output.
207
+ kernel_size (int): Size of the convolving kernel.
208
+ num_style_feat (int): Channel number of style features.
209
+ demodulate (bool): Whether to demodulate in the conv layer.
210
+ Default: True.
211
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
212
+ Default: None.
213
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
214
+ magnitude. Default: (1, 3, 3, 1).
215
+ eps (float): A value added to the denominator for numerical stability.
216
+ Default: 1e-8.
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ in_channels,
222
+ out_channels,
223
+ kernel_size,
224
+ num_style_feat,
225
+ demodulate=True,
226
+ sample_mode=None,
227
+ resample_kernel=(1, 3, 3, 1),
228
+ eps=1e-8,
229
+ ):
230
+ super(ModulatedConv2d, self).__init__()
231
+ self.in_channels = in_channels
232
+ self.out_channels = out_channels
233
+ self.kernel_size = kernel_size
234
+ self.demodulate = demodulate
235
+ self.sample_mode = sample_mode
236
+ self.eps = eps
237
+
238
+ if self.sample_mode == "upsample":
239
+ self.smooth = UpFirDnSmooth(
240
+ resample_kernel,
241
+ upsample_factor=2,
242
+ downsample_factor=1,
243
+ kernel_size=kernel_size,
244
+ )
245
+ elif self.sample_mode == "downsample":
246
+ self.smooth = UpFirDnSmooth(
247
+ resample_kernel,
248
+ upsample_factor=1,
249
+ downsample_factor=2,
250
+ kernel_size=kernel_size,
251
+ )
252
+ elif self.sample_mode is None:
253
+ pass
254
+ else:
255
+ raise ValueError(
256
+ f"Wrong sample mode {self.sample_mode}, "
257
+ "supported ones are ['upsample', 'downsample', None]."
258
+ )
259
+
260
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
261
+ # modulation inside each modulated conv
262
+ self.modulation = EqualLinear(
263
+ num_style_feat,
264
+ in_channels,
265
+ bias=True,
266
+ bias_init_val=1,
267
+ lr_mul=1,
268
+ activation=None,
269
+ )
270
+
271
+ self.weight = nn.Parameter(
272
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
273
+ )
274
+ self.padding = kernel_size // 2
275
+
276
+ def forward(self, x, style):
277
+ """Forward function.
278
+
279
+ Args:
280
+ x (Tensor): Tensor with shape (b, c, h, w).
281
+ style (Tensor): Tensor with shape (b, num_style_feat).
282
+
283
+ Returns:
284
+ Tensor: Modulated tensor after convolution.
285
+ """
286
+ b, c, h, w = x.shape # c = c_in
287
+ # weight modulation
288
+ style = self.modulation(style).view(b, 1, c, 1, 1)
289
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
290
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
291
+
292
+ if self.demodulate:
293
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
294
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
295
+
296
+ weight = weight.view(
297
+ b * self.out_channels, c, self.kernel_size, self.kernel_size
298
+ )
299
+
300
+ if self.sample_mode == "upsample":
301
+ x = x.view(1, b * c, h, w)
302
+ weight = weight.view(
303
+ b, self.out_channels, c, self.kernel_size, self.kernel_size
304
+ )
305
+ weight = weight.transpose(1, 2).reshape(
306
+ b * c, self.out_channels, self.kernel_size, self.kernel_size
307
+ )
308
+ out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
309
+ out = out.view(b, self.out_channels, *out.shape[2:4])
310
+ out = self.smooth(out)
311
+ elif self.sample_mode == "downsample":
312
+ x = self.smooth(x)
313
+ x = x.view(1, b * c, *x.shape[2:4])
314
+ out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
315
+ out = out.view(b, self.out_channels, *out.shape[2:4])
316
+ else:
317
+ x = x.view(1, b * c, h, w)
318
+ # weight: (b*c_out, c_in, k, k), groups=b
319
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
320
+ out = out.view(b, self.out_channels, *out.shape[2:4])
321
+
322
+ return out
323
+
324
+ def __repr__(self):
325
+ return (
326
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
327
+ f"out_channels={self.out_channels}, "
328
+ f"kernel_size={self.kernel_size}, "
329
+ f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
330
+ )
331
+
332
+
333
+ class StyleConv(nn.Module):
334
+ """Style conv.
335
+
336
+ Args:
337
+ in_channels (int): Channel number of the input.
338
+ out_channels (int): Channel number of the output.
339
+ kernel_size (int): Size of the convolving kernel.
340
+ num_style_feat (int): Channel number of style features.
341
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
342
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
343
+ Default: None.
344
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
345
+ magnitude. Default: (1, 3, 3, 1).
346
+ """
347
+
348
+ def __init__(
349
+ self,
350
+ in_channels,
351
+ out_channels,
352
+ kernel_size,
353
+ num_style_feat,
354
+ demodulate=True,
355
+ sample_mode=None,
356
+ resample_kernel=(1, 3, 3, 1),
357
+ ):
358
+ super(StyleConv, self).__init__()
359
+ self.modulated_conv = ModulatedConv2d(
360
+ in_channels,
361
+ out_channels,
362
+ kernel_size,
363
+ num_style_feat,
364
+ demodulate=demodulate,
365
+ sample_mode=sample_mode,
366
+ resample_kernel=resample_kernel,
367
+ )
368
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
369
+ self.activate = FusedLeakyReLU(out_channels)
370
+
371
+ def forward(self, x, style, noise=None):
372
+ # modulate
373
+ out = self.modulated_conv(x, style)
374
+ # noise injection
375
+ if noise is None:
376
+ b, _, h, w = out.shape
377
+ noise = out.new_empty(b, 1, h, w).normal_()
378
+ out = out + self.weight * noise
379
+ # activation (with bias)
380
+ out = self.activate(out)
381
+ return out
382
+
383
+
384
+ class ToRGB(nn.Module):
385
+ """To RGB from features.
386
+
387
+ Args:
388
+ in_channels (int): Channel number of input.
389
+ num_style_feat (int): Channel number of style features.
390
+ upsample (bool): Whether to upsample. Default: True.
391
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
392
+ magnitude. Default: (1, 3, 3, 1).
393
+ """
394
+
395
+ def __init__(
396
+ self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)
397
+ ):
398
+ super(ToRGB, self).__init__()
399
+ if upsample:
400
+ self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
401
+ else:
402
+ self.upsample = None
403
+ self.modulated_conv = ModulatedConv2d(
404
+ in_channels,
405
+ 3,
406
+ kernel_size=1,
407
+ num_style_feat=num_style_feat,
408
+ demodulate=False,
409
+ sample_mode=None,
410
+ )
411
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
412
+
413
+ def forward(self, x, style, skip=None):
414
+ """Forward function.
415
+
416
+ Args:
417
+ x (Tensor): Feature tensor with shape (b, c, h, w).
418
+ style (Tensor): Tensor with shape (b, num_style_feat).
419
+ skip (Tensor): Base/skip tensor. Default: None.
420
+
421
+ Returns:
422
+ Tensor: RGB images.
423
+ """
424
+ out = self.modulated_conv(x, style)
425
+ out = out + self.bias
426
+ if skip is not None:
427
+ if self.upsample:
428
+ skip = self.upsample(skip)
429
+ out = out + skip
430
+ return out
431
+
432
+
433
+ class ConstantInput(nn.Module):
434
+ """Constant input.
435
+
436
+ Args:
437
+ num_channel (int): Channel number of constant input.
438
+ size (int): Spatial size of constant input.
439
+ """
440
+
441
+ def __init__(self, num_channel, size):
442
+ super(ConstantInput, self).__init__()
443
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
444
+
445
+ def forward(self, batch):
446
+ out = self.weight.repeat(batch, 1, 1, 1)
447
+ return out
448
+
449
+
450
+ class StyleGAN2Generator(nn.Module):
451
+ """StyleGAN2 Generator.
452
+
453
+ Args:
454
+ out_size (int): The spatial size of outputs.
455
+ num_style_feat (int): Channel number of style features. Default: 512.
456
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
457
+ channel_multiplier (int): Channel multiplier for large networks of
458
+ StyleGAN2. Default: 2.
459
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
460
+ magnitude. A cross product will be applied to extend the 1D resample
461
+ kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
462
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
463
+ narrow (float): Narrow ratio for channels. Default: 1.0.
464
+ """
465
+
466
+ def __init__(
467
+ self,
468
+ out_size,
469
+ num_style_feat=512,
470
+ num_mlp=8,
471
+ channel_multiplier=2,
472
+ resample_kernel=(1, 3, 3, 1),
473
+ lr_mlp=0.01,
474
+ narrow=1,
475
+ ):
476
+ super(StyleGAN2Generator, self).__init__()
477
+ # Style MLP layers
478
+ self.num_style_feat = num_style_feat
479
+ style_mlp_layers = [NormStyleCode()]
480
+ for i in range(num_mlp):
481
+ style_mlp_layers.append(
482
+ EqualLinear(
483
+ num_style_feat,
484
+ num_style_feat,
485
+ bias=True,
486
+ bias_init_val=0,
487
+ lr_mul=lr_mlp,
488
+ activation="fused_lrelu",
489
+ )
490
+ )
491
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
492
+
493
+ channels = {
494
+ "4": int(512 * narrow),
495
+ "8": int(512 * narrow),
496
+ "16": int(512 * narrow),
497
+ "32": int(512 * narrow),
498
+ "64": int(256 * channel_multiplier * narrow),
499
+ "128": int(128 * channel_multiplier * narrow),
500
+ "256": int(64 * channel_multiplier * narrow),
501
+ "512": int(32 * channel_multiplier * narrow),
502
+ "1024": int(16 * channel_multiplier * narrow),
503
+ }
504
+ self.channels = channels
505
+
506
+ self.constant_input = ConstantInput(channels["4"], size=4)
507
+ self.style_conv1 = StyleConv(
508
+ channels["4"],
509
+ channels["4"],
510
+ kernel_size=3,
511
+ num_style_feat=num_style_feat,
512
+ demodulate=True,
513
+ sample_mode=None,
514
+ resample_kernel=resample_kernel,
515
+ )
516
+ self.to_rgb1 = ToRGB(
517
+ channels["4"],
518
+ num_style_feat,
519
+ upsample=False,
520
+ resample_kernel=resample_kernel,
521
+ )
522
+
523
+ self.log_size = int(math.log(out_size, 2))
524
+ self.num_layers = (self.log_size - 2) * 2 + 1
525
+ self.num_latent = self.log_size * 2 - 2
526
+
527
+ self.style_convs = nn.ModuleList()
528
+ self.to_rgbs = nn.ModuleList()
529
+ self.noises = nn.Module()
530
+
531
+ in_channels = channels["4"]
532
+ # noise
533
+ for layer_idx in range(self.num_layers):
534
+ resolution = 2 ** ((layer_idx + 5) // 2)
535
+ shape = [1, 1, resolution, resolution]
536
+ self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
537
+ # style convs and to_rgbs
538
+ for i in range(3, self.log_size + 1):
539
+ out_channels = channels[f"{2**i}"]
540
+ self.style_convs.append(
541
+ StyleConv(
542
+ in_channels,
543
+ out_channels,
544
+ kernel_size=3,
545
+ num_style_feat=num_style_feat,
546
+ demodulate=True,
547
+ sample_mode="upsample",
548
+ resample_kernel=resample_kernel,
549
+ )
550
+ )
551
+ self.style_convs.append(
552
+ StyleConv(
553
+ out_channels,
554
+ out_channels,
555
+ kernel_size=3,
556
+ num_style_feat=num_style_feat,
557
+ demodulate=True,
558
+ sample_mode=None,
559
+ resample_kernel=resample_kernel,
560
+ )
561
+ )
562
+ self.to_rgbs.append(
563
+ ToRGB(
564
+ out_channels,
565
+ num_style_feat,
566
+ upsample=True,
567
+ resample_kernel=resample_kernel,
568
+ )
569
+ )
570
+ in_channels = out_channels
571
+
572
+ def make_noise(self):
573
+ """Make noise for noise injection."""
574
+ device = self.constant_input.weight.device
575
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
576
+
577
+ for i in range(3, self.log_size + 1):
578
+ for _ in range(2):
579
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
580
+
581
+ return noises
582
+
583
+ def get_latent(self, x):
584
+ return self.style_mlp(x)
585
+
586
+ def mean_latent(self, num_latent):
587
+ latent_in = torch.randn(
588
+ num_latent, self.num_style_feat, device=self.constant_input.weight.device
589
+ )
590
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
591
+ return latent
592
+
593
+ def forward(
594
+ self,
595
+ styles,
596
+ input_is_latent=False,
597
+ noise=None,
598
+ randomize_noise=True,
599
+ truncation=1,
600
+ truncation_latent=None,
601
+ inject_index=None,
602
+ return_latents=False,
603
+ ):
604
+ """Forward function for StyleGAN2Generator.
605
+
606
+ Args:
607
+ styles (list[Tensor]): Sample codes of styles.
608
+ input_is_latent (bool): Whether input is latent style.
609
+ Default: False.
610
+ noise (Tensor | None): Input noise or None. Default: None.
611
+ randomize_noise (bool): Randomize noise, used when 'noise' is
612
+ None. Default: True.
613
+ truncation (float): Truncation ratio for the style codes; 1 means no truncation. Default: 1.
614
+ truncation_latent (Tensor | None): Latent (e.g. the mean latent) that styles are truncated towards. Default: None.
615
+ inject_index (int | None): The injection index for mixing noise.
616
+ Default: None.
617
+ return_latents (bool): Whether to return style latents.
618
+ Default: False.
619
+ """
620
+ # style codes -> latents with Style MLP layer
621
+ if not input_is_latent:
622
+ styles = [self.style_mlp(s) for s in styles]
623
+ # noises
624
+ if noise is None:
625
+ if randomize_noise:
626
+ noise = [None] * self.num_layers # for each style conv layer
627
+ else: # use the stored noise
628
+ noise = [
629
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
630
+ ]
631
+ # style truncation
632
+ if truncation < 1:
633
+ style_truncation = []
634
+ for style in styles:
635
+ style_truncation.append(
636
+ truncation_latent + truncation * (style - truncation_latent)
637
+ )
638
+ styles = style_truncation
639
+ # get style latent with injection
640
+ if len(styles) == 1:
641
+ inject_index = self.num_latent
642
+
643
+ if styles[0].ndim < 3:
644
+ # repeat latent code for all the layers
645
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
646
+ else: # used for encoder with different latent code for each layer
647
+ latent = styles[0]
648
+ elif len(styles) == 2: # mixing noises
649
+ if inject_index is None:
650
+ inject_index = random.randint(1, self.num_latent - 1)
651
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
652
+ latent2 = (
653
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
654
+ )
655
+ latent = torch.cat([latent1, latent2], 1)
656
+
657
+ # main generation
658
+ out = self.constant_input(latent.shape[0])
659
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
660
+ skip = self.to_rgb1(out, latent[:, 1])
661
+
662
+ i = 1
663
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
664
+ self.style_convs[::2],
665
+ self.style_convs[1::2],
666
+ noise[1::2],
667
+ noise[2::2],
668
+ self.to_rgbs,
669
+ ):
670
+ out = conv1(out, latent[:, i], noise=noise1)
671
+ out = conv2(out, latent[:, i + 1], noise=noise2)
672
+ skip = to_rgb(out, latent[:, i + 2], skip)
673
+ i += 2
674
+
675
+ image = skip
676
+
677
+ if return_latents:
678
+ return image, latent
679
+ else:
680
+ return image, None
681
+
682
+
683
+ class ScaledLeakyReLU(nn.Module):
684
+ """Scaled LeakyReLU.
685
+
686
+ Args:
687
+ negative_slope (float): Negative slope. Default: 0.2.
688
+ """
689
+
690
+ def __init__(self, negative_slope=0.2):
691
+ super(ScaledLeakyReLU, self).__init__()
692
+ self.negative_slope = negative_slope
693
+
694
+ def forward(self, x):
695
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
696
+ return out * math.sqrt(2)
697
+
698
+
699
+ class EqualConv2d(nn.Module):
700
+ """Equalized Linear as StyleGAN2.
701
+
702
+ Args:
703
+ in_channels (int): Channel number of the input.
704
+ out_channels (int): Channel number of the output.
705
+ kernel_size (int): Size of the convolving kernel.
706
+ stride (int): Stride of the convolution. Default: 1
707
+ padding (int): Zero-padding added to both sides of the input.
708
+ Default: 0.
709
+ bias (bool): If ``True``, adds a learnable bias to the output.
710
+ Default: ``True``.
711
+ bias_init_val (float): Bias initialized value. Default: 0.
712
+ """
713
+
714
+ def __init__(
715
+ self,
716
+ in_channels,
717
+ out_channels,
718
+ kernel_size,
719
+ stride=1,
720
+ padding=0,
721
+ bias=True,
722
+ bias_init_val=0,
723
+ ):
724
+ super(EqualConv2d, self).__init__()
725
+ self.in_channels = in_channels
726
+ self.out_channels = out_channels
727
+ self.kernel_size = kernel_size
728
+ self.stride = stride
729
+ self.padding = padding
730
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
731
+
732
+ self.weight = nn.Parameter(
733
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
734
+ )
735
+ if bias:
736
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
737
+ else:
738
+ self.register_parameter("bias", None)
739
+
740
+ def forward(self, x):
741
+ out = F.conv2d(
742
+ x,
743
+ self.weight * self.scale,
744
+ bias=self.bias,
745
+ stride=self.stride,
746
+ padding=self.padding,
747
+ )
748
+
749
+ return out
750
+
751
+ def __repr__(self):
752
+ return (
753
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
754
+ f"out_channels={self.out_channels}, "
755
+ f"kernel_size={self.kernel_size},"
756
+ f" stride={self.stride}, padding={self.padding}, "
757
+ f"bias={self.bias is not None})"
758
+ )
759
+
760
+
761
+ class ConvLayer(nn.Sequential):
762
+ """Conv Layer used in StyleGAN2 Discriminator.
763
+
764
+ Args:
765
+ in_channels (int): Channel number of the input.
766
+ out_channels (int): Channel number of the output.
767
+ kernel_size (int): Kernel size.
768
+ downsample (bool): Whether to downsample by a factor of 2.
769
+ Default: False.
770
+ resample_kernel (list[int]): A list indicating the 1D resample
771
+ kernel magnitude. A cross product will be applied to
772
+ extend the 1D resample kernel to a 2D resample kernel.
773
+ Default: (1, 3, 3, 1).
774
+ bias (bool): Whether to use bias. Default: True.
775
+ activate (bool): Whether to use activation. Default: True.
776
+ """
777
+
778
+ def __init__(
779
+ self,
780
+ in_channels,
781
+ out_channels,
782
+ kernel_size,
783
+ downsample=False,
784
+ resample_kernel=(1, 3, 3, 1),
785
+ bias=True,
786
+ activate=True,
787
+ ):
788
+ layers = []
789
+ # downsample
790
+ if downsample:
791
+ layers.append(
792
+ UpFirDnSmooth(
793
+ resample_kernel,
794
+ upsample_factor=1,
795
+ downsample_factor=2,
796
+ kernel_size=kernel_size,
797
+ )
798
+ )
799
+ stride = 2
800
+ self.padding = 0
801
+ else:
802
+ stride = 1
803
+ self.padding = kernel_size // 2
804
+ # conv
805
+ layers.append(
806
+ EqualConv2d(
807
+ in_channels,
808
+ out_channels,
809
+ kernel_size,
810
+ stride=stride,
811
+ padding=self.padding,
812
+ bias=bias and not activate,
813
+ )
814
+ )
815
+ # activation
816
+ if activate:
817
+ if bias:
818
+ layers.append(FusedLeakyReLU(out_channels))
819
+ else:
820
+ layers.append(ScaledLeakyReLU(0.2))
821
+
822
+ super(ConvLayer, self).__init__(*layers)
823
+
824
+
825
+ class ResBlock(nn.Module):
826
+ """Residual block used in StyleGAN2 Discriminator.
827
+
828
+ Args:
829
+ in_channels (int): Channel number of the input.
830
+ out_channels (int): Channel number of the output.
831
+ resample_kernel (list[int]): A list indicating the 1D resample
832
+ kernel magnitude. A cross product will be applied to
833
+ extend the 1D resample kernel to a 2D resample kernel.
834
+ Default: (1, 3, 3, 1).
835
+ """
836
+
837
+ def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
838
+ super(ResBlock, self).__init__()
839
+
840
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
841
+ self.conv2 = ConvLayer(
842
+ in_channels,
843
+ out_channels,
844
+ 3,
845
+ downsample=True,
846
+ resample_kernel=resample_kernel,
847
+ bias=True,
848
+ activate=True,
849
+ )
850
+ self.skip = ConvLayer(
851
+ in_channels,
852
+ out_channels,
853
+ 1,
854
+ downsample=True,
855
+ resample_kernel=resample_kernel,
856
+ bias=False,
857
+ activate=False,
858
+ )
859
+
860
+ def forward(self, x):
861
+ out = self.conv1(x)
862
+ out = self.conv2(out)
863
+ skip = self.skip(x)
864
+ out = (out + skip) / math.sqrt(2)
865
+ return out
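Below is a minimal, hypothetical usage sketch for the generator added in this file (not part of the upload itself). The import path and the availability of the package's pure-PyTorch `fused_act`/`upfirdn2d` fallbacks are assumptions:

```python
# Hypothetical sketch only -- the module path and CPU fallbacks for
# fused_act/upfirdn2d are assumed, not confirmed by this diff.
import torch

from comfy_extras.chainner_models.architecture.face.stylegan2_arch import (
    StyleGAN2Generator,
)

# Build a 256x256 generator with the default style dimension.
gen = StyleGAN2Generator(out_size=256, num_style_feat=512, num_mlp=8).eval()

# A single style code; forward() expects a list of codes.
z = torch.randn(1, 512)
with torch.no_grad():
    img, _ = gen([z], input_is_latent=False, randomize_noise=True)

print(img.shape)  # expected: torch.Size([1, 3, 256, 256])
```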
comfy_extras/chainner_models/architecture/face/stylegan2_bilinear_arch.py ADDED
@@ -0,0 +1,709 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
11
+
12
+
13
+ class NormStyleCode(nn.Module):
14
+ def forward(self, x):
15
+ """Normalize the style codes.
16
+ Args:
17
+ x (Tensor): Style codes with shape (b, c).
18
+ Returns:
19
+ Tensor: Normalized tensor.
20
+ """
21
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
22
+
23
+
24
+ class EqualLinear(nn.Module):
25
+ """Equalized Linear as StyleGAN2.
26
+ Args:
27
+ in_channels (int): Size of each sample.
28
+ out_channels (int): Size of each output sample.
29
+ bias (bool): If set to ``False``, the layer will not learn an additive
30
+ bias. Default: ``True``.
31
+ bias_init_val (float): Bias initialized value. Default: 0.
32
+ lr_mul (float): Learning rate multiplier. Default: 1.
33
+ activation (None | str): The activation after ``linear`` operation.
34
+ Supported: 'fused_lrelu', None. Default: None.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ in_channels,
40
+ out_channels,
41
+ bias=True,
42
+ bias_init_val=0,
43
+ lr_mul=1,
44
+ activation=None,
45
+ ):
46
+ super(EqualLinear, self).__init__()
47
+ self.in_channels = in_channels
48
+ self.out_channels = out_channels
49
+ self.lr_mul = lr_mul
50
+ self.activation = activation
51
+ if self.activation not in ["fused_lrelu", None]:
52
+ raise ValueError(
53
+ f"Wrong activation value in EqualLinear: {activation}"
54
+ "Supported ones are: ['fused_lrelu', None]."
55
+ )
56
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
57
+
58
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
59
+ if bias:
60
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
61
+ else:
62
+ self.register_parameter("bias", None)
63
+
64
+ def forward(self, x):
65
+ if self.bias is None:
66
+ bias = None
67
+ else:
68
+ bias = self.bias * self.lr_mul
69
+ if self.activation == "fused_lrelu":
70
+ out = F.linear(x, self.weight * self.scale)
71
+ out = fused_leaky_relu(out, bias)
72
+ else:
73
+ out = F.linear(x, self.weight * self.scale, bias=bias)
74
+ return out
75
+
76
+ def __repr__(self):
77
+ return (
78
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
79
+ f"out_channels={self.out_channels}, bias={self.bias is not None})"
80
+ )
81
+
82
+
83
+ class ModulatedConv2d(nn.Module):
84
+ """Modulated Conv2d used in StyleGAN2.
85
+ There is no bias in ModulatedConv2d.
86
+ Args:
87
+ in_channels (int): Channel number of the input.
88
+ out_channels (int): Channel number of the output.
89
+ kernel_size (int): Size of the convolving kernel.
90
+ num_style_feat (int): Channel number of style features.
91
+ demodulate (bool): Whether to demodulate in the conv layer.
92
+ Default: True.
93
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
94
+ Default: None.
95
+ eps (float): A value added to the denominator for numerical stability.
96
+ Default: 1e-8.
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ in_channels,
102
+ out_channels,
103
+ kernel_size,
104
+ num_style_feat,
105
+ demodulate=True,
106
+ sample_mode=None,
107
+ eps=1e-8,
108
+ interpolation_mode="bilinear",
109
+ ):
110
+ super(ModulatedConv2d, self).__init__()
111
+ self.in_channels = in_channels
112
+ self.out_channels = out_channels
113
+ self.kernel_size = kernel_size
114
+ self.demodulate = demodulate
115
+ self.sample_mode = sample_mode
116
+ self.eps = eps
117
+ self.interpolation_mode = interpolation_mode
118
+ if self.interpolation_mode == "nearest":
119
+ self.align_corners = None
120
+ else:
121
+ self.align_corners = False
122
+
123
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
124
+ # modulation inside each modulated conv
125
+ self.modulation = EqualLinear(
126
+ num_style_feat,
127
+ in_channels,
128
+ bias=True,
129
+ bias_init_val=1,
130
+ lr_mul=1,
131
+ activation=None,
132
+ )
133
+
134
+ self.weight = nn.Parameter(
135
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
136
+ )
137
+ self.padding = kernel_size // 2
138
+
139
+ def forward(self, x, style):
140
+ """Forward function.
141
+ Args:
142
+ x (Tensor): Tensor with shape (b, c, h, w).
143
+ style (Tensor): Tensor with shape (b, num_style_feat).
144
+ Returns:
145
+ Tensor: Modulated tensor after convolution.
146
+ """
147
+ b, c, h, w = x.shape # c = c_in
148
+ # weight modulation
149
+ style = self.modulation(style).view(b, 1, c, 1, 1)
150
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
151
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
152
+
153
+ if self.demodulate:
154
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
155
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
156
+
157
+ weight = weight.view(
158
+ b * self.out_channels, c, self.kernel_size, self.kernel_size
159
+ )
160
+
161
+ if self.sample_mode == "upsample":
162
+ x = F.interpolate(
163
+ x,
164
+ scale_factor=2,
165
+ mode=self.interpolation_mode,
166
+ align_corners=self.align_corners,
167
+ )
168
+ elif self.sample_mode == "downsample":
169
+ x = F.interpolate(
170
+ x,
171
+ scale_factor=0.5,
172
+ mode=self.interpolation_mode,
173
+ align_corners=self.align_corners,
174
+ )
175
+
176
+ b, c, h, w = x.shape
177
+ x = x.view(1, b * c, h, w)
178
+ # weight: (b*c_out, c_in, k, k), groups=b
179
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
180
+ out = out.view(b, self.out_channels, *out.shape[2:4])
181
+
182
+ return out
183
+
184
+ def __repr__(self):
185
+ return (
186
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
187
+ f"out_channels={self.out_channels}, "
188
+ f"kernel_size={self.kernel_size}, "
189
+ f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
190
+ )
191
+
192
+
193
+ class StyleConv(nn.Module):
194
+ """Style conv.
195
+ Args:
196
+ in_channels (int): Channel number of the input.
197
+ out_channels (int): Channel number of the output.
198
+ kernel_size (int): Size of the convolving kernel.
199
+ num_style_feat (int): Channel number of style features.
200
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
201
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
202
+ Default: None.
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ in_channels,
208
+ out_channels,
209
+ kernel_size,
210
+ num_style_feat,
211
+ demodulate=True,
212
+ sample_mode=None,
213
+ interpolation_mode="bilinear",
214
+ ):
215
+ super(StyleConv, self).__init__()
216
+ self.modulated_conv = ModulatedConv2d(
217
+ in_channels,
218
+ out_channels,
219
+ kernel_size,
220
+ num_style_feat,
221
+ demodulate=demodulate,
222
+ sample_mode=sample_mode,
223
+ interpolation_mode=interpolation_mode,
224
+ )
225
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
226
+ self.activate = FusedLeakyReLU(out_channels)
227
+
228
+ def forward(self, x, style, noise=None):
229
+ # modulate
230
+ out = self.modulated_conv(x, style)
231
+ # noise injection
232
+ if noise is None:
233
+ b, _, h, w = out.shape
234
+ noise = out.new_empty(b, 1, h, w).normal_()
235
+ out = out + self.weight * noise
236
+ # activation (with bias)
237
+ out = self.activate(out)
238
+ return out
239
+
240
+
241
+ class ToRGB(nn.Module):
242
+ """To RGB from features.
243
+ Args:
244
+ in_channels (int): Channel number of input.
245
+ num_style_feat (int): Channel number of style features.
246
+ upsample (bool): Whether to upsample. Default: True.
247
+ """
248
+
249
+ def __init__(
250
+ self, in_channels, num_style_feat, upsample=True, interpolation_mode="bilinear"
251
+ ):
252
+ super(ToRGB, self).__init__()
253
+ self.upsample = upsample
254
+ self.interpolation_mode = interpolation_mode
255
+ if self.interpolation_mode == "nearest":
256
+ self.align_corners = None
257
+ else:
258
+ self.align_corners = False
259
+ self.modulated_conv = ModulatedConv2d(
260
+ in_channels,
261
+ 3,
262
+ kernel_size=1,
263
+ num_style_feat=num_style_feat,
264
+ demodulate=False,
265
+ sample_mode=None,
266
+ interpolation_mode=interpolation_mode,
267
+ )
268
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
269
+
270
+ def forward(self, x, style, skip=None):
271
+ """Forward function.
272
+ Args:
273
+ x (Tensor): Feature tensor with shape (b, c, h, w).
274
+ style (Tensor): Tensor with shape (b, num_style_feat).
275
+ skip (Tensor): Base/skip tensor. Default: None.
276
+ Returns:
277
+ Tensor: RGB images.
278
+ """
279
+ out = self.modulated_conv(x, style)
280
+ out = out + self.bias
281
+ if skip is not None:
282
+ if self.upsample:
283
+ skip = F.interpolate(
284
+ skip,
285
+ scale_factor=2,
286
+ mode=self.interpolation_mode,
287
+ align_corners=self.align_corners,
288
+ )
289
+ out = out + skip
290
+ return out
291
+
292
+
293
+ class ConstantInput(nn.Module):
294
+ """Constant input.
295
+ Args:
296
+ num_channel (int): Channel number of constant input.
297
+ size (int): Spatial size of constant input.
298
+ """
299
+
300
+ def __init__(self, num_channel, size):
301
+ super(ConstantInput, self).__init__()
302
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
303
+
304
+ def forward(self, batch):
305
+ out = self.weight.repeat(batch, 1, 1, 1)
306
+ return out
307
+
308
+
309
+ class StyleGAN2GeneratorBilinear(nn.Module):
310
+ """StyleGAN2 Generator.
311
+ Args:
312
+ out_size (int): The spatial size of outputs.
313
+ num_style_feat (int): Channel number of style features. Default: 512.
314
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
315
+ channel_multiplier (int): Channel multiplier for large networks of
316
+ StyleGAN2. Default: 2.
317
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
318
+ narrow (float): Narrow ratio for channels. Default: 1.0.
319
+ """
320
+
321
+ def __init__(
322
+ self,
323
+ out_size,
324
+ num_style_feat=512,
325
+ num_mlp=8,
326
+ channel_multiplier=2,
327
+ lr_mlp=0.01,
328
+ narrow=1,
329
+ interpolation_mode="bilinear",
330
+ ):
331
+ super(StyleGAN2GeneratorBilinear, self).__init__()
332
+ # Style MLP layers
333
+ self.num_style_feat = num_style_feat
334
+ style_mlp_layers = [NormStyleCode()]
335
+ for i in range(num_mlp):
336
+ style_mlp_layers.append(
337
+ EqualLinear(
338
+ num_style_feat,
339
+ num_style_feat,
340
+ bias=True,
341
+ bias_init_val=0,
342
+ lr_mul=lr_mlp,
343
+ activation="fused_lrelu",
344
+ )
345
+ )
346
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
347
+
348
+ channels = {
349
+ "4": int(512 * narrow),
350
+ "8": int(512 * narrow),
351
+ "16": int(512 * narrow),
352
+ "32": int(512 * narrow),
353
+ "64": int(256 * channel_multiplier * narrow),
354
+ "128": int(128 * channel_multiplier * narrow),
355
+ "256": int(64 * channel_multiplier * narrow),
356
+ "512": int(32 * channel_multiplier * narrow),
357
+ "1024": int(16 * channel_multiplier * narrow),
358
+ }
359
+ self.channels = channels
360
+
361
+ self.constant_input = ConstantInput(channels["4"], size=4)
362
+ self.style_conv1 = StyleConv(
363
+ channels["4"],
364
+ channels["4"],
365
+ kernel_size=3,
366
+ num_style_feat=num_style_feat,
367
+ demodulate=True,
368
+ sample_mode=None,
369
+ interpolation_mode=interpolation_mode,
370
+ )
371
+ self.to_rgb1 = ToRGB(
372
+ channels["4"],
373
+ num_style_feat,
374
+ upsample=False,
375
+ interpolation_mode=interpolation_mode,
376
+ )
377
+
378
+ self.log_size = int(math.log(out_size, 2))
379
+ self.num_layers = (self.log_size - 2) * 2 + 1
380
+ self.num_latent = self.log_size * 2 - 2
381
+
382
+ self.style_convs = nn.ModuleList()
383
+ self.to_rgbs = nn.ModuleList()
384
+ self.noises = nn.Module()
385
+
386
+ in_channels = channels["4"]
387
+ # noise
388
+ for layer_idx in range(self.num_layers):
389
+ resolution = 2 ** ((layer_idx + 5) // 2)
390
+ shape = [1, 1, resolution, resolution]
391
+ self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
392
+ # style convs and to_rgbs
393
+ for i in range(3, self.log_size + 1):
394
+ out_channels = channels[f"{2**i}"]
395
+ self.style_convs.append(
396
+ StyleConv(
397
+ in_channels,
398
+ out_channels,
399
+ kernel_size=3,
400
+ num_style_feat=num_style_feat,
401
+ demodulate=True,
402
+ sample_mode="upsample",
403
+ interpolation_mode=interpolation_mode,
404
+ )
405
+ )
406
+ self.style_convs.append(
407
+ StyleConv(
408
+ out_channels,
409
+ out_channels,
410
+ kernel_size=3,
411
+ num_style_feat=num_style_feat,
412
+ demodulate=True,
413
+ sample_mode=None,
414
+ interpolation_mode=interpolation_mode,
415
+ )
416
+ )
417
+ self.to_rgbs.append(
418
+ ToRGB(
419
+ out_channels,
420
+ num_style_feat,
421
+ upsample=True,
422
+ interpolation_mode=interpolation_mode,
423
+ )
424
+ )
425
+ in_channels = out_channels
426
+
427
+ def make_noise(self):
428
+ """Make noise for noise injection."""
429
+ device = self.constant_input.weight.device
430
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
431
+
432
+ for i in range(3, self.log_size + 1):
433
+ for _ in range(2):
434
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
435
+
436
+ return noises
437
+
438
+ def get_latent(self, x):
439
+ return self.style_mlp(x)
440
+
441
+ def mean_latent(self, num_latent):
442
+ latent_in = torch.randn(
443
+ num_latent, self.num_style_feat, device=self.constant_input.weight.device
444
+ )
445
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
446
+ return latent
447
+
448
+ def forward(
449
+ self,
450
+ styles,
451
+ input_is_latent=False,
452
+ noise=None,
453
+ randomize_noise=True,
454
+ truncation=1,
455
+ truncation_latent=None,
456
+ inject_index=None,
457
+ return_latents=False,
458
+ ):
459
+ """Forward function for StyleGAN2Generator.
460
+ Args:
461
+ styles (list[Tensor]): Sample codes of styles.
462
+ input_is_latent (bool): Whether input is latent style.
463
+ Default: False.
464
+ noise (Tensor | None): Input noise or None. Default: None.
465
+ randomize_noise (bool): Randomize noise, used when 'noise' is
466
+ None. Default: True.
467
+ truncation (float): Truncation ratio for the style codes; 1 means no truncation. Default: 1.
468
+ truncation_latent (Tensor | None): Latent (e.g. the mean latent) that styles are truncated towards. Default: None.
469
+ inject_index (int | None): The injection index for mixing noise.
470
+ Default: None.
471
+ return_latents (bool): Whether to return style latents.
472
+ Default: False.
473
+ """
474
+ # style codes -> latents with Style MLP layer
475
+ if not input_is_latent:
476
+ styles = [self.style_mlp(s) for s in styles]
477
+ # noises
478
+ if noise is None:
479
+ if randomize_noise:
480
+ noise = [None] * self.num_layers # for each style conv layer
481
+ else: # use the stored noise
482
+ noise = [
483
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
484
+ ]
485
+ # style truncation
486
+ if truncation < 1:
487
+ style_truncation = []
488
+ for style in styles:
489
+ style_truncation.append(
490
+ truncation_latent + truncation * (style - truncation_latent)
491
+ )
492
+ styles = style_truncation
493
+ # get style latent with injection
494
+ if len(styles) == 1:
495
+ inject_index = self.num_latent
496
+
497
+ if styles[0].ndim < 3:
498
+ # repeat latent code for all the layers
499
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
500
+ else: # used for encoder with different latent code for each layer
501
+ latent = styles[0]
502
+ elif len(styles) == 2: # mixing noises
503
+ if inject_index is None:
504
+ inject_index = random.randint(1, self.num_latent - 1)
505
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
506
+ latent2 = (
507
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
508
+ )
509
+ latent = torch.cat([latent1, latent2], 1)
510
+
511
+ # main generation
512
+ out = self.constant_input(latent.shape[0])
513
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
514
+ skip = self.to_rgb1(out, latent[:, 1])
515
+
516
+ i = 1
517
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
518
+ self.style_convs[::2],
519
+ self.style_convs[1::2],
520
+ noise[1::2],
521
+ noise[2::2],
522
+ self.to_rgbs,
523
+ ):
524
+ out = conv1(out, latent[:, i], noise=noise1)
525
+ out = conv2(out, latent[:, i + 1], noise=noise2)
526
+ skip = to_rgb(out, latent[:, i + 2], skip)
527
+ i += 2
528
+
529
+ image = skip
530
+
531
+ if return_latents:
532
+ return image, latent
533
+ else:
534
+ return image, None
535
+
536
+
537
+ class ScaledLeakyReLU(nn.Module):
538
+ """Scaled LeakyReLU.
539
+ Args:
540
+ negative_slope (float): Negative slope. Default: 0.2.
541
+ """
542
+
543
+ def __init__(self, negative_slope=0.2):
544
+ super(ScaledLeakyReLU, self).__init__()
545
+ self.negative_slope = negative_slope
546
+
547
+ def forward(self, x):
548
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
549
+ return out * math.sqrt(2)
550
+
551
+
552
+ class EqualConv2d(nn.Module):
553
+ """Equalized Linear as StyleGAN2.
554
+ Args:
555
+ in_channels (int): Channel number of the input.
556
+ out_channels (int): Channel number of the output.
557
+ kernel_size (int): Size of the convolving kernel.
558
+ stride (int): Stride of the convolution. Default: 1
559
+ padding (int): Zero-padding added to both sides of the input.
560
+ Default: 0.
561
+ bias (bool): If ``True``, adds a learnable bias to the output.
562
+ Default: ``True``.
563
+ bias_init_val (float): Bias initialized value. Default: 0.
564
+ """
565
+
566
+ def __init__(
567
+ self,
568
+ in_channels,
569
+ out_channels,
570
+ kernel_size,
571
+ stride=1,
572
+ padding=0,
573
+ bias=True,
574
+ bias_init_val=0,
575
+ ):
576
+ super(EqualConv2d, self).__init__()
577
+ self.in_channels = in_channels
578
+ self.out_channels = out_channels
579
+ self.kernel_size = kernel_size
580
+ self.stride = stride
581
+ self.padding = padding
582
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
583
+
584
+ self.weight = nn.Parameter(
585
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
586
+ )
587
+ if bias:
588
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
589
+ else:
590
+ self.register_parameter("bias", None)
591
+
592
+ def forward(self, x):
593
+ out = F.conv2d(
594
+ x,
595
+ self.weight * self.scale,
596
+ bias=self.bias,
597
+ stride=self.stride,
598
+ padding=self.padding,
599
+ )
600
+
601
+ return out
602
+
603
+ def __repr__(self):
604
+ return (
605
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
606
+ f"out_channels={self.out_channels}, "
607
+ f"kernel_size={self.kernel_size},"
608
+ f" stride={self.stride}, padding={self.padding}, "
609
+ f"bias={self.bias is not None})"
610
+ )
611
+
612
+
613
+ class ConvLayer(nn.Sequential):
614
+ """Conv Layer used in StyleGAN2 Discriminator.
615
+ Args:
616
+ in_channels (int): Channel number of the input.
617
+ out_channels (int): Channel number of the output.
618
+ kernel_size (int): Kernel size.
619
+ downsample (bool): Whether to downsample by a factor of 2.
620
+ Default: False.
621
+ bias (bool): Whether to use bias. Default: True.
623
+ activate (bool): Whether to use activation. Default: True.
623
+ """
624
+
625
+ def __init__(
626
+ self,
627
+ in_channels,
628
+ out_channels,
629
+ kernel_size,
630
+ downsample=False,
631
+ bias=True,
632
+ activate=True,
633
+ interpolation_mode="bilinear",
634
+ ):
635
+ layers = []
636
+ self.interpolation_mode = interpolation_mode
637
+ # downsample
638
+ if downsample:
639
+ if self.interpolation_mode == "nearest":
640
+ self.align_corners = None
641
+ else:
642
+ self.align_corners = False
643
+
644
+ layers.append(
645
+ torch.nn.Upsample(
646
+ scale_factor=0.5,
647
+ mode=interpolation_mode,
648
+ align_corners=self.align_corners,
649
+ )
650
+ )
651
+ stride = 1
652
+ self.padding = kernel_size // 2
653
+ # conv
654
+ layers.append(
655
+ EqualConv2d(
656
+ in_channels,
657
+ out_channels,
658
+ kernel_size,
659
+ stride=stride,
660
+ padding=self.padding,
661
+ bias=bias and not activate,
662
+ )
663
+ )
664
+ # activation
665
+ if activate:
666
+ if bias:
667
+ layers.append(FusedLeakyReLU(out_channels))
668
+ else:
669
+ layers.append(ScaledLeakyReLU(0.2))
670
+
671
+ super(ConvLayer, self).__init__(*layers)
672
+
673
+
674
+ class ResBlock(nn.Module):
675
+ """Residual block used in StyleGAN2 Discriminator.
676
+ Args:
677
+ in_channels (int): Channel number of the input.
678
+ out_channels (int): Channel number of the output.
679
+ """
680
+
681
+ def __init__(self, in_channels, out_channels, interpolation_mode="bilinear"):
682
+ super(ResBlock, self).__init__()
683
+
684
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
685
+ self.conv2 = ConvLayer(
686
+ in_channels,
687
+ out_channels,
688
+ 3,
689
+ downsample=True,
690
+ interpolation_mode=interpolation_mode,
691
+ bias=True,
692
+ activate=True,
693
+ )
694
+ self.skip = ConvLayer(
695
+ in_channels,
696
+ out_channels,
697
+ 1,
698
+ downsample=True,
699
+ interpolation_mode=interpolation_mode,
700
+ bias=False,
701
+ activate=False,
702
+ )
703
+
704
+ def forward(self, x):
705
+ out = self.conv1(x)
706
+ out = self.conv2(out)
707
+ skip = self.skip(x)
708
+ out = (out + skip) / math.sqrt(2)
709
+ return out
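As with the previous file, a hedged usage sketch for the bilinear variant (not part of the upload). This version resamples with `F.interpolate`, so only the `fused_act` import has to resolve; the module path follows the file header above:

```python
# Hypothetical sketch only -- assumes the fused_act fallback imports cleanly.
import torch

from comfy_extras.chainner_models.architecture.face.stylegan2_bilinear_arch import (
    StyleGAN2GeneratorBilinear,
)

gen = StyleGAN2GeneratorBilinear(out_size=128, num_style_feat=512).eval()

# Two style codes exercise the style-mixing branch of forward().
styles = [torch.randn(2, 512), torch.randn(2, 512)]
with torch.no_grad():
    img, _ = gen(styles, input_is_latent=False, inject_index=4)

print(img.shape)  # expected: torch.Size([2, 3, 128, 128])
```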
comfy_extras/chainner_models/architecture/face/stylegan2_clean_arch.py ADDED
@@ -0,0 +1,453 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.nn import init
9
+ from torch.nn.modules.batchnorm import _BatchNorm
10
+
11
+
12
+ @torch.no_grad()
13
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
14
+ """Initialize network weights.
15
+ Args:
16
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
17
+ scale (float): Scale initialized weights, especially for residual
18
+ blocks. Default: 1.
19
+ bias_fill (float): The value to fill bias. Default: 0
20
+ kwargs (dict): Other arguments for initialization function.
21
+ """
22
+ if not isinstance(module_list, list):
23
+ module_list = [module_list]
24
+ for module in module_list:
25
+ for m in module.modules():
26
+ if isinstance(m, nn.Conv2d):
27
+ init.kaiming_normal_(m.weight, **kwargs)
28
+ m.weight.data *= scale
29
+ if m.bias is not None:
30
+ m.bias.data.fill_(bias_fill)
31
+ elif isinstance(m, nn.Linear):
32
+ init.kaiming_normal_(m.weight, **kwargs)
33
+ m.weight.data *= scale
34
+ if m.bias is not None:
35
+ m.bias.data.fill_(bias_fill)
36
+ elif isinstance(m, _BatchNorm):
37
+ init.constant_(m.weight, 1)
38
+ if m.bias is not None:
39
+ m.bias.data.fill_(bias_fill)
40
+
41
+
42
+ class NormStyleCode(nn.Module):
43
+ def forward(self, x):
44
+ """Normalize the style codes.
45
+ Args:
46
+ x (Tensor): Style codes with shape (b, c).
47
+ Returns:
48
+ Tensor: Normalized tensor.
49
+ """
50
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
51
+
52
+
53
+ class ModulatedConv2d(nn.Module):
54
+ """Modulated Conv2d used in StyleGAN2.
55
+ There is no bias in ModulatedConv2d.
56
+ Args:
57
+ in_channels (int): Channel number of the input.
58
+ out_channels (int): Channel number of the output.
59
+ kernel_size (int): Size of the convolving kernel.
60
+ num_style_feat (int): Channel number of style features.
61
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
62
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
63
+ eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ in_channels,
69
+ out_channels,
70
+ kernel_size,
71
+ num_style_feat,
72
+ demodulate=True,
73
+ sample_mode=None,
74
+ eps=1e-8,
75
+ ):
76
+ super(ModulatedConv2d, self).__init__()
77
+ self.in_channels = in_channels
78
+ self.out_channels = out_channels
79
+ self.kernel_size = kernel_size
80
+ self.demodulate = demodulate
81
+ self.sample_mode = sample_mode
82
+ self.eps = eps
83
+
84
+ # modulation inside each modulated conv
85
+ self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
86
+ # initialization
87
+ default_init_weights(
88
+ self.modulation,
89
+ scale=1,
90
+ bias_fill=1,
91
+ a=0,
92
+ mode="fan_in",
93
+ nonlinearity="linear",
94
+ )
95
+
96
+ self.weight = nn.Parameter(
97
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
98
+ / math.sqrt(in_channels * kernel_size**2)
99
+ )
100
+ self.padding = kernel_size // 2
101
+
102
+ def forward(self, x, style):
103
+ """Forward function.
104
+ Args:
105
+ x (Tensor): Tensor with shape (b, c, h, w).
106
+ style (Tensor): Tensor with shape (b, num_style_feat).
107
+ Returns:
108
+ Tensor: Modulated tensor after convolution.
109
+ """
110
+ b, c, h, w = x.shape # c = c_in
111
+ # weight modulation
112
+ style = self.modulation(style).view(b, 1, c, 1, 1)
113
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
114
+ weight = self.weight * style # (b, c_out, c_in, k, k)
115
+
116
+ if self.demodulate:
117
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
118
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
119
+
120
+ weight = weight.view(
121
+ b * self.out_channels, c, self.kernel_size, self.kernel_size
122
+ )
123
+
124
+ # upsample or downsample if necessary
125
+ if self.sample_mode == "upsample":
126
+ x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
127
+ elif self.sample_mode == "downsample":
128
+ x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
129
+
130
+ b, c, h, w = x.shape
131
+ x = x.view(1, b * c, h, w)
132
+ # weight: (b*c_out, c_in, k, k), groups=b
133
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
134
+ out = out.view(b, self.out_channels, *out.shape[2:4])
135
+
136
+ return out
137
+
138
+ def __repr__(self):
139
+ return (
140
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
141
+ f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
142
+ )
143
+
144
+
145
+ class StyleConv(nn.Module):
146
+ """Style conv used in StyleGAN2.
147
+ Args:
148
+ in_channels (int): Channel number of the input.
149
+ out_channels (int): Channel number of the output.
150
+ kernel_size (int): Size of the convolving kernel.
151
+ num_style_feat (int): Channel number of style features.
152
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
153
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ in_channels,
159
+ out_channels,
160
+ kernel_size,
161
+ num_style_feat,
162
+ demodulate=True,
163
+ sample_mode=None,
164
+ ):
165
+ super(StyleConv, self).__init__()
166
+ self.modulated_conv = ModulatedConv2d(
167
+ in_channels,
168
+ out_channels,
169
+ kernel_size,
170
+ num_style_feat,
171
+ demodulate=demodulate,
172
+ sample_mode=sample_mode,
173
+ )
174
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
175
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
176
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
177
+
178
+ def forward(self, x, style, noise=None):
179
+ # modulate
180
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
181
+ # noise injection
182
+ if noise is None:
183
+ b, _, h, w = out.shape
184
+ noise = out.new_empty(b, 1, h, w).normal_()
185
+ out = out + self.weight * noise
186
+ # add bias
187
+ out = out + self.bias
188
+ # activation
189
+ out = self.activate(out)
190
+ return out
191
+
192
+
193
+ class ToRGB(nn.Module):
194
+ """To RGB (image space) from features.
195
+ Args:
196
+ in_channels (int): Channel number of input.
197
+ num_style_feat (int): Channel number of style features.
198
+ upsample (bool): Whether to upsample. Default: True.
199
+ """
200
+
201
+ def __init__(self, in_channels, num_style_feat, upsample=True):
202
+ super(ToRGB, self).__init__()
203
+ self.upsample = upsample
204
+ self.modulated_conv = ModulatedConv2d(
205
+ in_channels,
206
+ 3,
207
+ kernel_size=1,
208
+ num_style_feat=num_style_feat,
209
+ demodulate=False,
210
+ sample_mode=None,
211
+ )
212
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
213
+
214
+ def forward(self, x, style, skip=None):
215
+ """Forward function.
216
+ Args:
217
+ x (Tensor): Feature tensor with shape (b, c, h, w).
218
+ style (Tensor): Tensor with shape (b, num_style_feat).
219
+ skip (Tensor): Base/skip tensor. Default: None.
220
+ Returns:
221
+ Tensor: RGB images.
222
+ """
223
+ out = self.modulated_conv(x, style)
224
+ out = out + self.bias
225
+ if skip is not None:
226
+ if self.upsample:
227
+ skip = F.interpolate(
228
+ skip, scale_factor=2, mode="bilinear", align_corners=False
229
+ )
230
+ out = out + skip
231
+ return out
232
+
233
+
234
+ class ConstantInput(nn.Module):
235
+ """Constant input.
236
+ Args:
237
+ num_channel (int): Channel number of constant input.
238
+ size (int): Spatial size of constant input.
239
+ """
240
+
241
+ def __init__(self, num_channel, size):
242
+ super(ConstantInput, self).__init__()
243
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
244
+
245
+ def forward(self, batch):
246
+ out = self.weight.repeat(batch, 1, 1, 1)
247
+ return out
248
+
249
+
250
+ class StyleGAN2GeneratorClean(nn.Module):
251
+ """Clean version of StyleGAN2 Generator.
252
+ Args:
253
+ out_size (int): The spatial size of outputs.
254
+ num_style_feat (int): Channel number of style features. Default: 512.
255
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
256
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
257
+ narrow (float): Narrow ratio for channels. Default: 1.0.
258
+ """
259
+
260
+ def __init__(
261
+ self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
262
+ ):
263
+ super(StyleGAN2GeneratorClean, self).__init__()
264
+ # Style MLP layers
265
+ self.num_style_feat = num_style_feat
266
+ style_mlp_layers = [NormStyleCode()]
267
+ for i in range(num_mlp):
268
+ style_mlp_layers.extend(
269
+ [
270
+ nn.Linear(num_style_feat, num_style_feat, bias=True),
271
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
272
+ ]
273
+ )
274
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
275
+ # initialization
276
+ default_init_weights(
277
+ self.style_mlp,
278
+ scale=1,
279
+ bias_fill=0,
280
+ a=0.2,
281
+ mode="fan_in",
282
+ nonlinearity="leaky_relu",
283
+ )
284
+
285
+ # channel list
286
+ channels = {
287
+ "4": int(512 * narrow),
288
+ "8": int(512 * narrow),
289
+ "16": int(512 * narrow),
290
+ "32": int(512 * narrow),
291
+ "64": int(256 * channel_multiplier * narrow),
292
+ "128": int(128 * channel_multiplier * narrow),
293
+ "256": int(64 * channel_multiplier * narrow),
294
+ "512": int(32 * channel_multiplier * narrow),
295
+ "1024": int(16 * channel_multiplier * narrow),
296
+ }
297
+ self.channels = channels
298
+
299
+ self.constant_input = ConstantInput(channels["4"], size=4)
300
+ self.style_conv1 = StyleConv(
301
+ channels["4"],
302
+ channels["4"],
303
+ kernel_size=3,
304
+ num_style_feat=num_style_feat,
305
+ demodulate=True,
306
+ sample_mode=None,
307
+ )
308
+ self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)
309
+
310
+ self.log_size = int(math.log(out_size, 2))
311
+ self.num_layers = (self.log_size - 2) * 2 + 1
312
+ self.num_latent = self.log_size * 2 - 2
313
+
314
+ self.style_convs = nn.ModuleList()
315
+ self.to_rgbs = nn.ModuleList()
316
+ self.noises = nn.Module()
317
+
318
+ in_channels = channels["4"]
319
+ # noise
320
+ for layer_idx in range(self.num_layers):
321
+ resolution = 2 ** ((layer_idx + 5) // 2)
322
+ shape = [1, 1, resolution, resolution]
323
+ self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
324
+ # style convs and to_rgbs
325
+ for i in range(3, self.log_size + 1):
326
+ out_channels = channels[f"{2**i}"]
327
+ self.style_convs.append(
328
+ StyleConv(
329
+ in_channels,
330
+ out_channels,
331
+ kernel_size=3,
332
+ num_style_feat=num_style_feat,
333
+ demodulate=True,
334
+ sample_mode="upsample",
335
+ )
336
+ )
337
+ self.style_convs.append(
338
+ StyleConv(
339
+ out_channels,
340
+ out_channels,
341
+ kernel_size=3,
342
+ num_style_feat=num_style_feat,
343
+ demodulate=True,
344
+ sample_mode=None,
345
+ )
346
+ )
347
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
348
+ in_channels = out_channels
349
+
350
+ def make_noise(self):
351
+ """Make noise for noise injection."""
352
+ device = self.constant_input.weight.device
353
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
354
+
355
+ for i in range(3, self.log_size + 1):
356
+ for _ in range(2):
357
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
358
+
359
+ return noises
360
+
361
+ def get_latent(self, x):
362
+ return self.style_mlp(x)
363
+
364
+ def mean_latent(self, num_latent):
365
+ latent_in = torch.randn(
366
+ num_latent, self.num_style_feat, device=self.constant_input.weight.device
367
+ )
368
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
369
+ return latent
370
+
371
+ def forward(
372
+ self,
373
+ styles,
374
+ input_is_latent=False,
375
+ noise=None,
376
+ randomize_noise=True,
377
+ truncation=1,
378
+ truncation_latent=None,
379
+ inject_index=None,
380
+ return_latents=False,
381
+ ):
382
+ """Forward function for StyleGAN2GeneratorClean.
383
+ Args:
384
+ styles (list[Tensor]): Sample codes of styles.
385
+ input_is_latent (bool): Whether input is latent style. Default: False.
386
+ noise (Tensor | None): Input noise or None. Default: None.
387
+ randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
388
+ truncation (float): The truncation ratio. Default: 1.
389
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
390
+ inject_index (int | None): The injection index for mixing noise. Default: None.
391
+ return_latents (bool): Whether to return style latents. Default: False.
392
+ """
393
+ # style codes -> latents with Style MLP layer
394
+ if not input_is_latent:
395
+ styles = [self.style_mlp(s) for s in styles]
396
+ # noises
397
+ if noise is None:
398
+ if randomize_noise:
399
+ noise = [None] * self.num_layers # for each style conv layer
400
+ else: # use the stored noise
401
+ noise = [
402
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
403
+ ]
404
+ # style truncation
405
+ if truncation < 1:
406
+ style_truncation = []
407
+ for style in styles:
408
+ style_truncation.append(
409
+ truncation_latent + truncation * (style - truncation_latent)
410
+ )
411
+ styles = style_truncation
412
+ # get style latents with injection
413
+ if len(styles) == 1:
414
+ inject_index = self.num_latent
415
+
416
+ if styles[0].ndim < 3:
417
+ # repeat latent code for all the layers
418
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
419
+ else: # used for encoder with different latent code for each layer
420
+ latent = styles[0]
421
+ elif len(styles) == 2: # mixing noises
422
+ if inject_index is None:
423
+ inject_index = random.randint(1, self.num_latent - 1)
424
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
425
+ latent2 = (
426
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
427
+ )
428
+ latent = torch.cat([latent1, latent2], 1)
429
+
430
+ # main generation
431
+ out = self.constant_input(latent.shape[0])
432
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
433
+ skip = self.to_rgb1(out, latent[:, 1])
434
+
435
+ i = 1
436
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
437
+ self.style_convs[::2],
438
+ self.style_convs[1::2],
439
+ noise[1::2],
440
+ noise[2::2],
441
+ self.to_rgbs,
442
+ ):
443
+ out = conv1(out, latent[:, i], noise=noise1)
444
+ out = conv2(out, latent[:, i + 1], noise=noise2)
445
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
446
+ i += 2
447
+
448
+ image = skip
449
+
450
+ if return_latents:
451
+ return image, latent
452
+ else:
453
+ return image, None
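For reference, a minimal usage sketch of the clean generator above (editorial, not part of the uploaded file). The import path is an assumption, and the helpers the class relies on (NormStyleCode, ModulatedConv2d, default_init_weights) must be defined earlier in the same module:

import torch
# assumed import path for this upload
from comfy_extras.chainner_models.architecture.face.stylegan2_clean_arch import (
    StyleGAN2GeneratorClean,
)

gen = StyleGAN2GeneratorClean(out_size=256, num_style_feat=512)
z = torch.randn(2, 512)                  # one style code per sample
img, _ = gen([z])                        # `styles` is a list of codes
print(img.shape)                         # expected: torch.Size([2, 3, 256, 256])

# style mixing with two codes; the second is injected from a random index onward
img_mix, latent = gen(
    [torch.randn(2, 512), torch.randn(2, 512)], return_latents=True
)
print(latent.shape)                      # (2, gen.num_latent, 512)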
comfy_extras/chainner_models/architecture/face/upfirdn2d.py ADDED
@@ -0,0 +1,194 @@
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
4
+
5
+ import os
6
+
7
+ import torch
8
+ from torch.autograd import Function
9
+ from torch.nn import functional as F
10
+
11
+ upfirdn2d_ext = None
12
+
13
+
14
+ class UpFirDn2dBackward(Function):
15
+ @staticmethod
16
+ def forward(
17
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
18
+ ):
19
+ up_x, up_y = up
20
+ down_x, down_y = down
21
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
22
+
23
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
24
+
25
+ grad_input = upfirdn2d_ext.upfirdn2d(
26
+ grad_output,
27
+ grad_kernel,
28
+ down_x,
29
+ down_y,
30
+ up_x,
31
+ up_y,
32
+ g_pad_x0,
33
+ g_pad_x1,
34
+ g_pad_y0,
35
+ g_pad_y1,
36
+ )
37
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
38
+
39
+ ctx.save_for_backward(kernel)
40
+
41
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
42
+
43
+ ctx.up_x = up_x
44
+ ctx.up_y = up_y
45
+ ctx.down_x = down_x
46
+ ctx.down_y = down_y
47
+ ctx.pad_x0 = pad_x0
48
+ ctx.pad_x1 = pad_x1
49
+ ctx.pad_y0 = pad_y0
50
+ ctx.pad_y1 = pad_y1
51
+ ctx.in_size = in_size
52
+ ctx.out_size = out_size
53
+
54
+ return grad_input
55
+
56
+ @staticmethod
57
+ def backward(ctx, gradgrad_input):
58
+ (kernel,) = ctx.saved_tensors
59
+
60
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
61
+
62
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
63
+ gradgrad_input,
64
+ kernel,
65
+ ctx.up_x,
66
+ ctx.up_y,
67
+ ctx.down_x,
68
+ ctx.down_y,
69
+ ctx.pad_x0,
70
+ ctx.pad_x1,
71
+ ctx.pad_y0,
72
+ ctx.pad_y1,
73
+ )
74
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
75
+ # ctx.out_size[1], ctx.in_size[3])
76
+ gradgrad_out = gradgrad_out.view(
77
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
78
+ )
79
+
80
+ return gradgrad_out, None, None, None, None, None, None, None, None
81
+
82
+
83
+ class UpFirDn2d(Function):
84
+ @staticmethod
85
+ def forward(ctx, input, kernel, up, down, pad):
86
+ up_x, up_y = up
87
+ down_x, down_y = down
88
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
89
+
90
+ kernel_h, kernel_w = kernel.shape
91
+ _, channel, in_h, in_w = input.shape
92
+ ctx.in_size = input.shape
93
+
94
+ input = input.reshape(-1, in_h, in_w, 1)
95
+
96
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
97
+
98
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
99
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
100
+ ctx.out_size = (out_h, out_w)
101
+
102
+ ctx.up = (up_x, up_y)
103
+ ctx.down = (down_x, down_y)
104
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
105
+
106
+ g_pad_x0 = kernel_w - pad_x0 - 1
107
+ g_pad_y0 = kernel_h - pad_y0 - 1
108
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
109
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
110
+
111
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
112
+
113
+ out = upfirdn2d_ext.upfirdn2d(
114
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
115
+ )
116
+ # out = out.view(major, out_h, out_w, minor)
117
+ out = out.view(-1, channel, out_h, out_w)
118
+
119
+ return out
120
+
121
+ @staticmethod
122
+ def backward(ctx, grad_output):
123
+ kernel, grad_kernel = ctx.saved_tensors
124
+
125
+ grad_input = UpFirDn2dBackward.apply(
126
+ grad_output,
127
+ kernel,
128
+ grad_kernel,
129
+ ctx.up,
130
+ ctx.down,
131
+ ctx.pad,
132
+ ctx.g_pad,
133
+ ctx.in_size,
134
+ ctx.out_size,
135
+ )
136
+
137
+ return grad_input, None, None, None, None
138
+
139
+
140
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
141
+ if input.device.type == "cpu":
142
+ out = upfirdn2d_native(
143
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
144
+ )
145
+ else:
146
+ out = UpFirDn2d.apply(
147
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
148
+ )
149
+
150
+ return out
151
+
152
+
153
+ def upfirdn2d_native(
154
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
155
+ ):
156
+ _, channel, in_h, in_w = input.shape
157
+ input = input.reshape(-1, in_h, in_w, 1)
158
+
159
+ _, in_h, in_w, minor = input.shape
160
+ kernel_h, kernel_w = kernel.shape
161
+
162
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
163
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
164
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
165
+
166
+ out = F.pad(
167
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
168
+ )
169
+ out = out[
170
+ :,
171
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
172
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
173
+ :,
174
+ ]
175
+
176
+ out = out.permute(0, 3, 1, 2)
177
+ out = out.reshape(
178
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
179
+ )
180
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
181
+ out = F.conv2d(out, w)
182
+ out = out.reshape(
183
+ -1,
184
+ minor,
185
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
186
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
187
+ )
188
+ out = out.permute(0, 2, 3, 1)
189
+ out = out[:, ::down_y, ::down_x, :]
190
+
191
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
192
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
193
+
194
+ return out.view(-1, channel, out_h, out_w)
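An editorial CPU-side sketch of the `upfirdn2d` wrapper above (not part of the file). Since `upfirdn2d_ext` is left as `None` in this port, only the native PyTorch path is usable; the CUDA branch would need the compiled extension. Example: a normalized 4x4 binomial blur that preserves resolution:

import torch

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[:, None] * k[None, :]
kernel = kernel / kernel.sum()          # 4x4 binomial low-pass filter

x = torch.randn(1, 3, 32, 32)           # NCHW input on CPU
y = upfirdn2d(x, kernel, up=1, down=1, pad=(1, 2))
print(y.shape)                          # torch.Size([1, 3, 32, 32]) with pad (1, 2)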
comfy_extras/chainner_models/architecture/mat/utils.py ADDED
@@ -0,0 +1,698 @@
1
+ """Code used for this implementation of the MAT helper utils is modified from
2
+ lama-cleaner, copyright of Sanster: https://github.com/fenglinglwb/MAT"""
3
+
4
+ import collections
5
+ from itertools import repeat
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch import conv2d, conv_transpose2d
11
+
12
+
13
+ def normalize_2nd_moment(x, dim=1, eps=1e-8):
14
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
15
+
16
+
17
+ class EasyDict(dict):
18
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
19
+
20
+ def __getattr__(self, name: str) -> Any:
21
+ try:
22
+ return self[name]
23
+ except KeyError:
24
+ raise AttributeError(name)
25
+
26
+ def __setattr__(self, name: str, value: Any) -> None:
27
+ self[name] = value
28
+
29
+ def __delattr__(self, name: str) -> None:
30
+ del self[name]
31
+
32
+
33
+ activation_funcs = {
34
+ "linear": EasyDict(
35
+ func=lambda x, **_: x,
36
+ def_alpha=0,
37
+ def_gain=1,
38
+ cuda_idx=1,
39
+ ref="",
40
+ has_2nd_grad=False,
41
+ ),
42
+ "relu": EasyDict(
43
+ func=lambda x, **_: torch.nn.functional.relu(x),
44
+ def_alpha=0,
45
+ def_gain=np.sqrt(2),
46
+ cuda_idx=2,
47
+ ref="y",
48
+ has_2nd_grad=False,
49
+ ),
50
+ "lrelu": EasyDict(
51
+ func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
52
+ def_alpha=0.2,
53
+ def_gain=np.sqrt(2),
54
+ cuda_idx=3,
55
+ ref="y",
56
+ has_2nd_grad=False,
57
+ ),
58
+ "tanh": EasyDict(
59
+ func=lambda x, **_: torch.tanh(x),
60
+ def_alpha=0,
61
+ def_gain=1,
62
+ cuda_idx=4,
63
+ ref="y",
64
+ has_2nd_grad=True,
65
+ ),
66
+ "sigmoid": EasyDict(
67
+ func=lambda x, **_: torch.sigmoid(x),
68
+ def_alpha=0,
69
+ def_gain=1,
70
+ cuda_idx=5,
71
+ ref="y",
72
+ has_2nd_grad=True,
73
+ ),
74
+ "elu": EasyDict(
75
+ func=lambda x, **_: torch.nn.functional.elu(x),
76
+ def_alpha=0,
77
+ def_gain=1,
78
+ cuda_idx=6,
79
+ ref="y",
80
+ has_2nd_grad=True,
81
+ ),
82
+ "selu": EasyDict(
83
+ func=lambda x, **_: torch.nn.functional.selu(x),
84
+ def_alpha=0,
85
+ def_gain=1,
86
+ cuda_idx=7,
87
+ ref="y",
88
+ has_2nd_grad=True,
89
+ ),
90
+ "softplus": EasyDict(
91
+ func=lambda x, **_: torch.nn.functional.softplus(x),
92
+ def_alpha=0,
93
+ def_gain=1,
94
+ cuda_idx=8,
95
+ ref="y",
96
+ has_2nd_grad=True,
97
+ ),
98
+ "swish": EasyDict(
99
+ func=lambda x, **_: torch.sigmoid(x) * x,
100
+ def_alpha=0,
101
+ def_gain=np.sqrt(2),
102
+ cuda_idx=9,
103
+ ref="x",
104
+ has_2nd_grad=True,
105
+ ),
106
+ }
107
+
108
+
109
+ def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
110
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
111
+ assert isinstance(x, torch.Tensor)
112
+ assert clamp is None or clamp >= 0
113
+ spec = activation_funcs[act]
114
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
115
+ gain = float(gain if gain is not None else spec.def_gain)
116
+ clamp = float(clamp if clamp is not None else -1)
117
+
118
+ # Add bias.
119
+ if b is not None:
120
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
121
+ assert 0 <= dim < x.ndim
122
+ assert b.shape[0] == x.shape[dim]
123
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]).to(x.device)
124
+
125
+ # Evaluate activation function.
126
+ alpha = float(alpha)
127
+ x = spec.func(x, alpha=alpha)
128
+
129
+ # Scale by gain.
130
+ gain = float(gain)
131
+ if gain != 1:
132
+ x = x * gain
133
+
134
+ # Clamp.
135
+ if clamp >= 0:
136
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
137
+ return x
138
+
139
+
140
+ def bias_act(
141
+ x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
142
+ ):
143
+ r"""Fused bias and activation function.
144
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
145
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
146
+ the fused op is considerably more efficient than performing the same calculation
147
+ using standard PyTorch ops. It supports first and second order gradients,
148
+ but not third order gradients.
149
+ Args:
150
+ x: Input activation tensor. Can be of any shape.
151
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
152
+ as `x`. The shape must be known, and it must match the dimension of `x`
153
+ corresponding to `dim`.
154
+ dim: The dimension in `x` corresponding to the elements of `b`.
155
+ The value of `dim` is ignored if `b` is not specified.
156
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
157
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
158
+ See `activation_funcs` for a full list. `None` is not allowed.
159
+ alpha: Shape parameter for the activation function, or `None` to use the default.
160
+ gain: Scaling factor for the output tensor, or `None` to use default.
161
+ See `activation_funcs` for the default scaling of each activation function.
162
+ If unsure, consider specifying 1.
163
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
164
+ the clamping (default).
165
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default here: `"ref"`; this port always runs the reference implementation).
166
+ Returns:
167
+ Tensor of the same shape and datatype as `x`.
168
+ """
169
+ assert isinstance(x, torch.Tensor)
170
+ assert impl in ["ref", "cuda"]
171
+ return _bias_act_ref(
172
+ x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
173
+ )
174
+
175
+
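A short editorial sketch of how `bias_act` is typically called; note that `impl` is accepted for API compatibility but the reference path is always executed in this port:

import torch

x = torch.randn(4, 8, 16, 16)
b = torch.zeros(8)
# "lrelu" picks alpha=0.2 and gain=sqrt(2) from activation_funcs
y = bias_act(x, b, dim=1, act="lrelu", clamp=256)
assert y.shape == x.shape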
176
+ def setup_filter(
177
+ f,
178
+ device=torch.device("cpu"),
179
+ normalize=True,
180
+ flip_filter=False,
181
+ gain=1,
182
+ separable=None,
183
+ ):
184
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
185
+ Args:
186
+ f: Torch tensor, numpy array, or python list of the shape
187
+ `[filter_height, filter_width]` (non-separable),
188
+ `[filter_taps]` (separable),
189
+ `[]` (impulse), or
190
+ `None` (identity).
191
+ device: Result device (default: cpu).
192
+ normalize: Normalize the filter so that it retains the magnitude
193
+ for constant input signal (DC)? (default: True).
194
+ flip_filter: Flip the filter? (default: False).
195
+ gain: Overall scaling factor for signal magnitude (default: 1).
196
+ separable: Return a separable filter? (default: select automatically).
197
+ Returns:
198
+ Float32 tensor of the shape
199
+ `[filter_height, filter_width]` (non-separable) or
200
+ `[filter_taps]` (separable).
201
+ """
202
+ # Validate.
203
+ if f is None:
204
+ f = 1
205
+ f = torch.as_tensor(f, dtype=torch.float32)
206
+ assert f.ndim in [0, 1, 2]
207
+ assert f.numel() > 0
208
+ if f.ndim == 0:
209
+ f = f[np.newaxis]
210
+
211
+ # Separable?
212
+ if separable is None:
213
+ separable = f.ndim == 1 and f.numel() >= 8
214
+ if f.ndim == 1 and not separable:
215
+ f = f.ger(f)
216
+ assert f.ndim == (1 if separable else 2)
217
+
218
+ # Apply normalize, flip, gain, and device.
219
+ if normalize:
220
+ f /= f.sum()
221
+ if flip_filter:
222
+ f = f.flip(list(range(f.ndim)))
223
+ f = f * (gain ** (f.ndim / 2))
224
+ f = f.to(device=device)
225
+ return f
226
+
227
+
228
+ def _get_filter_size(f):
229
+ if f is None:
230
+ return 1, 1
231
+
232
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
233
+ fw = f.shape[-1]
234
+ fh = f.shape[0]
235
+
236
+ fw = int(fw)
237
+ fh = int(fh)
238
+ assert fw >= 1 and fh >= 1
239
+ return fw, fh
240
+
241
+
242
+ def _get_weight_shape(w):
243
+ shape = [int(sz) for sz in w.shape]
244
+ return shape
245
+
246
+
247
+ def _parse_scaling(scaling):
248
+ if isinstance(scaling, int):
249
+ scaling = [scaling, scaling]
250
+ assert isinstance(scaling, (list, tuple))
251
+ assert all(isinstance(x, int) for x in scaling)
252
+ sx, sy = scaling
253
+ assert sx >= 1 and sy >= 1
254
+ return sx, sy
255
+
256
+
257
+ def _parse_padding(padding):
258
+ if isinstance(padding, int):
259
+ padding = [padding, padding]
260
+ assert isinstance(padding, (list, tuple))
261
+ assert all(isinstance(x, int) for x in padding)
262
+ if len(padding) == 2:
263
+ padx, pady = padding
264
+ padding = [padx, padx, pady, pady]
265
+ padx0, padx1, pady0, pady1 = padding
266
+ return padx0, padx1, pady0, pady1
267
+
268
+
269
+ def _ntuple(n):
270
+ def parse(x):
271
+ if isinstance(x, collections.abc.Iterable):
272
+ return x
273
+ return tuple(repeat(x, n))
274
+
275
+ return parse
276
+
277
+
278
+ to_2tuple = _ntuple(2)
279
+
280
+
281
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
282
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
283
+ # Validate arguments.
284
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
285
+ if f is None:
286
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
287
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
288
+ assert f.dtype == torch.float32 and not f.requires_grad
289
+ batch_size, num_channels, in_height, in_width = x.shape
290
+ # upx, upy = _parse_scaling(up)
291
+ # downx, downy = _parse_scaling(down)
292
+
293
+ upx, upy = up, up
294
+ downx, downy = down, down
295
+
296
+ # padx0, padx1, pady0, pady1 = _parse_padding(padding)
297
+ padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]
298
+
299
+ # Upsample by inserting zeros.
300
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
301
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
302
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
303
+
304
+ # Pad or crop.
305
+ x = torch.nn.functional.pad(
306
+ x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
307
+ )
308
+ x = x[
309
+ :,
310
+ :,
311
+ max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
312
+ max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
313
+ ]
314
+
315
+ # Setup filter.
316
+ f = f * (gain ** (f.ndim / 2))
317
+ f = f.to(x.dtype)
318
+ if not flip_filter:
319
+ f = f.flip(list(range(f.ndim)))
320
+
321
+ # Convolve with the filter.
322
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
323
+ if f.ndim == 4:
324
+ x = conv2d(input=x, weight=f, groups=num_channels)
325
+ else:
326
+ x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
327
+ x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
328
+
329
+ # Downsample by throwing away pixels.
330
+ x = x[:, :, ::downy, ::downx]
331
+ return x
332
+
333
+
334
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
335
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
336
+ Performs the following sequence of operations for each channel:
337
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
338
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
339
+ Negative padding corresponds to cropping the image.
340
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
341
+ so that the footprint of all output pixels lies within the input image.
342
+ 4. Downsample the image by keeping every Nth pixel (`down`).
343
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
344
+ The fused op is considerably more efficient than performing the same calculation
345
+ using standard PyTorch ops. It supports gradients of arbitrary order.
346
+ Args:
347
+ x: Float32/float64/float16 input tensor of the shape
348
+ `[batch_size, num_channels, in_height, in_width]`.
349
+ f: Float32 FIR filter of the shape
350
+ `[filter_height, filter_width]` (non-separable),
351
+ `[filter_taps]` (separable), or
352
+ `None` (identity).
353
+ up: Integer upsampling factor. Can be a single int or a list/tuple
354
+ `[x, y]` (default: 1).
355
+ down: Integer downsampling factor. Can be a single int or a list/tuple
356
+ `[x, y]` (default: 1).
357
+ padding: Padding with respect to the upsampled image. Can be a single number
358
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
359
+ (default: 0).
360
+ flip_filter: False = convolution, True = correlation (default: False).
361
+ gain: Overall scaling factor for signal magnitude (default: 1).
362
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
363
+ Returns:
364
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
365
+ """
366
+ # assert isinstance(x, torch.Tensor)
367
+ # assert impl in ['ref', 'cuda']
368
+ return _upfirdn2d_ref(
369
+ x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
370
+ )
371
+
372
+
373
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
374
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
375
+ By default, the result is padded so that its shape is a multiple of the input.
376
+ User-specified padding is applied on top of that, with negative values
377
+ indicating cropping. Pixels outside the image are assumed to be zero.
378
+ Args:
379
+ x: Float32/float64/float16 input tensor of the shape
380
+ `[batch_size, num_channels, in_height, in_width]`.
381
+ f: Float32 FIR filter of the shape
382
+ `[filter_height, filter_width]` (non-separable),
383
+ `[filter_taps]` (separable), or
384
+ `None` (identity).
385
+ up: Integer upsampling factor. Can be a single int or a list/tuple
386
+ `[x, y]` (default: 1).
387
+ padding: Padding with respect to the output. Can be a single number or a
388
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
389
+ (default: 0).
390
+ flip_filter: False = convolution, True = correlation (default: False).
391
+ gain: Overall scaling factor for signal magnitude (default: 1).
392
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
393
+ Returns:
394
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
395
+ """
396
+ upx, upy = _parse_scaling(up)
397
+ # upx, upy = up, up
398
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
399
+ # padx0, padx1, pady0, pady1 = padding, padding, padding, padding
400
+ fw, fh = _get_filter_size(f)
401
+ p = [
402
+ padx0 + (fw + upx - 1) // 2,
403
+ padx1 + (fw - upx) // 2,
404
+ pady0 + (fh + upy - 1) // 2,
405
+ pady1 + (fh - upy) // 2,
406
+ ]
407
+ return upfirdn2d(
408
+ x,
409
+ f,
410
+ up=up,
411
+ padding=p,
412
+ flip_filter=flip_filter,
413
+ gain=gain * upx * upy,
414
+ impl=impl,
415
+ )
416
+
417
+
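An editorial sketch combining `setup_filter` and `upsample2d` for 2x antialiased upsampling; the filter is prepared with `setup_filter` so it is normalized and on the right device:

import torch

f = setup_filter([1, 3, 3, 1])          # normalized 4x4 FIR filter
x = torch.randn(1, 16, 32, 32)
y = upsample2d(x, f, up=2)
print(y.shape)                          # expected: torch.Size([1, 16, 64, 64])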
418
+ class FullyConnectedLayer(torch.nn.Module):
419
+ def __init__(
420
+ self,
421
+ in_features, # Number of input features.
422
+ out_features, # Number of output features.
423
+ bias=True, # Apply additive bias before the activation function?
424
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
425
+ lr_multiplier=1, # Learning rate multiplier.
426
+ bias_init=0, # Initial value for the additive bias.
427
+ ):
428
+ super().__init__()
429
+ self.weight = torch.nn.Parameter(
430
+ torch.randn([out_features, in_features]) / lr_multiplier
431
+ )
432
+ self.bias = (
433
+ torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
434
+ if bias
435
+ else None
436
+ )
437
+ self.activation = activation
438
+
439
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
440
+ self.bias_gain = lr_multiplier
441
+
442
+ def forward(self, x):
443
+ w = self.weight * self.weight_gain
444
+ b = self.bias
445
+ if b is not None and self.bias_gain != 1:
446
+ b = b * self.bias_gain
447
+
448
+ if self.activation == "linear" and b is not None:
449
+ # out = torch.addmm(b.unsqueeze(0), x, w.t())
450
+ x = x.matmul(w.t().to(x.device))
451
+ out = x + b.reshape(
452
+ [-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)]
453
+ ).to(x.device)
454
+ else:
455
+ x = x.matmul(w.t().to(x.device))
456
+ out = bias_act(x, b, act=self.activation, dim=x.ndim - 1).to(x.device)
457
+ return out
458
+
459
+
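Editorial sketch of the equalized-learning-rate linear layer above: the weight is rescaled at run time by `lr_multiplier / sqrt(in_features)` rather than at initialization:

import torch

fc = FullyConnectedLayer(in_features=512, out_features=256, activation="lrelu")
x = torch.randn(8, 512)
y = fc(x)
print(y.shape)                          # torch.Size([8, 256])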
460
+ def _conv2d_wrapper(
461
+ x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
462
+ ):
463
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
464
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
465
+
466
+ # Flip weight if requested.
467
+ if (
468
+ not flip_weight
469
+ ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
470
+ w = w.flip([2, 3])
471
+
472
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
473
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
474
+ if (
475
+ kw == 1
476
+ and kh == 1
477
+ and stride == 1
478
+ and padding in [0, [0, 0], (0, 0)]
479
+ and not transpose
480
+ ):
481
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
482
+ if out_channels <= 4 and groups == 1:
483
+ in_shape = x.shape
484
+ x = w.squeeze(3).squeeze(2) @ x.reshape(
485
+ [in_shape[0], in_channels_per_group, -1]
486
+ )
487
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
488
+ else:
489
+ x = x.to(memory_format=torch.contiguous_format)
490
+ w = w.to(memory_format=torch.contiguous_format)
491
+ x = conv2d(x, w, groups=groups)
492
+ return x.to(memory_format=torch.channels_last)
493
+
494
+ # Otherwise => execute using conv2d_gradfix.
495
+ op = conv_transpose2d if transpose else conv2d
496
+ return op(x, w, stride=stride, padding=padding, groups=groups)
497
+
498
+
499
+ def conv2d_resample(
500
+ x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
501
+ ):
502
+ r"""2D convolution with optional up/downsampling.
503
+ Padding is performed only once at the beginning, not between the operations.
504
+ Args:
505
+ x: Input tensor of shape
506
+ `[batch_size, in_channels, in_height, in_width]`.
507
+ w: Weight tensor of shape
508
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
509
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
510
+ calling setup_filter(). None = identity (default).
511
+ up: Integer upsampling factor (default: 1).
512
+ down: Integer downsampling factor (default: 1).
513
+ padding: Padding with respect to the upsampled image. Can be a single number
514
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
515
+ (default: 0).
516
+ groups: Split input channels into N groups (default: 1).
517
+ flip_weight: False = convolution, True = correlation (default: True).
518
+ flip_filter: False = convolution, True = correlation (default: False).
519
+ Returns:
520
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
521
+ """
522
+ # Validate arguments.
523
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
524
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
525
+ assert f is None or (
526
+ isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32
527
+ )
528
+ assert isinstance(up, int) and (up >= 1)
529
+ assert isinstance(down, int) and (down >= 1)
530
+ # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
531
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
532
+ fw, fh = _get_filter_size(f)
533
+ # px0, px1, py0, py1 = _parse_padding(padding)
534
+ px0, px1, py0, py1 = padding, padding, padding, padding
535
+
536
+ # Adjust padding to account for up/downsampling.
537
+ if up > 1:
538
+ px0 += (fw + up - 1) // 2
539
+ px1 += (fw - up) // 2
540
+ py0 += (fh + up - 1) // 2
541
+ py1 += (fh - up) // 2
542
+ if down > 1:
543
+ px0 += (fw - down + 1) // 2
544
+ px1 += (fw - down) // 2
545
+ py0 += (fh - down + 1) // 2
546
+ py1 += (fh - down) // 2
547
+
548
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
549
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
550
+ x = upfirdn2d(
551
+ x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
552
+ )
553
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
554
+ return x
555
+
556
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
557
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
558
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
559
+ x = upfirdn2d(
560
+ x=x,
561
+ f=f,
562
+ up=up,
563
+ padding=[px0, px1, py0, py1],
564
+ gain=up**2,
565
+ flip_filter=flip_filter,
566
+ )
567
+ return x
568
+
569
+ # Fast path: downsampling only => use strided convolution.
570
+ if down > 1 and up == 1:
571
+ x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
572
+ x = _conv2d_wrapper(
573
+ x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
574
+ )
575
+ return x
576
+
577
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
578
+ if up > 1:
579
+ if groups == 1:
580
+ w = w.transpose(0, 1)
581
+ else:
582
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
583
+ w = w.transpose(1, 2)
584
+ w = w.reshape(
585
+ groups * in_channels_per_group, out_channels // groups, kh, kw
586
+ )
587
+ px0 -= kw - 1
588
+ px1 -= kw - up
589
+ py0 -= kh - 1
590
+ py1 -= kh - up
591
+ pxt = max(min(-px0, -px1), 0)
592
+ pyt = max(min(-py0, -py1), 0)
593
+ x = _conv2d_wrapper(
594
+ x=x,
595
+ w=w,
596
+ stride=up,
597
+ padding=[pyt, pxt],
598
+ groups=groups,
599
+ transpose=True,
600
+ flip_weight=(not flip_weight),
601
+ )
602
+ x = upfirdn2d(
603
+ x=x,
604
+ f=f,
605
+ padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
606
+ gain=up**2,
607
+ flip_filter=flip_filter,
608
+ )
609
+ if down > 1:
610
+ x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
611
+ return x
612
+
613
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
614
+ if up == 1 and down == 1:
615
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
616
+ return _conv2d_wrapper(
617
+ x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
618
+ )
619
+
620
+ # Fallback: Generic reference implementation.
621
+ x = upfirdn2d(
622
+ x=x,
623
+ f=(f if up > 1 else None),
624
+ up=up,
625
+ padding=[px0, px1, py0, py1],
626
+ gain=up**2,
627
+ flip_filter=flip_filter,
628
+ )
629
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
630
+ if down > 1:
631
+ x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
632
+ return x
633
+
634
+
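Editorial sketch of `conv2d_resample` doing a 3x3 convolution with 2x antialiased downsampling. Note that in this port `padding` is unpacked as a single int applied to all four sides:

import torch

f = setup_filter([1, 3, 3, 1])
x = torch.randn(1, 8, 64, 64)
w = torch.randn(16, 8, 3, 3)            # [out_channels, in_channels, kh, kw]
y = conv2d_resample(x, w, f=f, down=2, padding=1)
print(y.shape)                          # expected: torch.Size([1, 16, 32, 32])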
635
+ class Conv2dLayer(torch.nn.Module):
636
+ def __init__(
637
+ self,
638
+ in_channels, # Number of input channels.
639
+ out_channels, # Number of output channels.
640
+ kernel_size, # Width and height of the convolution kernel.
641
+ bias=True, # Apply additive bias before the activation function?
642
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
643
+ up=1, # Integer upsampling factor.
644
+ down=1, # Integer downsampling factor.
645
+ resample_filter=[
646
+ 1,
647
+ 3,
648
+ 3,
649
+ 1,
650
+ ], # Low-pass filter to apply when resampling activations.
651
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
652
+ channels_last=False, # Expect the input to have memory_format=channels_last?
653
+ trainable=True, # Update the weights of this layer during training?
654
+ ):
655
+ super().__init__()
656
+ self.activation = activation
657
+ self.up = up
658
+ self.down = down
659
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
660
+ self.conv_clamp = conv_clamp
661
+ self.padding = kernel_size // 2
662
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
663
+ self.act_gain = activation_funcs[activation].def_gain
664
+
665
+ memory_format = (
666
+ torch.channels_last if channels_last else torch.contiguous_format
667
+ )
668
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
669
+ memory_format=memory_format
670
+ )
671
+ bias = torch.zeros([out_channels]) if bias else None
672
+ if trainable:
673
+ self.weight = torch.nn.Parameter(weight)
674
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
675
+ else:
676
+ self.register_buffer("weight", weight)
677
+ if bias is not None:
678
+ self.register_buffer("bias", bias)
679
+ else:
680
+ self.bias = None
681
+
682
+ def forward(self, x, gain=1):
683
+ w = self.weight * self.weight_gain
684
+ x = conv2d_resample(
685
+ x=x,
686
+ w=w,
687
+ f=self.resample_filter,
688
+ up=self.up,
689
+ down=self.down,
690
+ padding=self.padding,
691
+ )
692
+
693
+ act_gain = self.act_gain * gain
694
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
695
+ out = bias_act(
696
+ x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
697
+ )
698
+ return out
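Editorial sketch of a downsampling `Conv2dLayer`, which wires `conv2d_resample` and `bias_act` together with a registered resampling filter:

import torch

conv = Conv2dLayer(
    in_channels=3, out_channels=64, kernel_size=3, activation="lrelu", down=2
)
x = torch.randn(1, 3, 128, 128)
y = conv(x)
print(y.shape)                          # expected: torch.Size([1, 64, 64, 64])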
comfy_extras/chainner_models/architecture/timm/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2019 Ross Wightman
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
comfy_extras/chainner_models/architecture/timm/drop.py ADDED
@@ -0,0 +1,223 @@
1
+ """ DropBlock, DropPath
2
+
3
+ PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
4
+
5
+ Papers:
6
+ DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
7
+
8
+ Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
9
+
10
+ Code:
11
+ DropBlock impl inspired by two Tensorflow impl that I liked:
12
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
13
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
14
+
15
+ Hacked together by / Copyright 2020 Ross Wightman
16
+ """
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+
22
+ def drop_block_2d(
23
+ x,
24
+ drop_prob: float = 0.1,
25
+ block_size: int = 7,
26
+ gamma_scale: float = 1.0,
27
+ with_noise: bool = False,
28
+ inplace: bool = False,
29
+ batchwise: bool = False,
30
+ ):
31
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
32
+
33
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
34
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
35
+ """
36
+ _, C, H, W = x.shape
37
+ total_size = W * H
38
+ clipped_block_size = min(block_size, min(W, H))
39
+ # seed_drop_rate, the gamma parameter
40
+ gamma = (
41
+ gamma_scale
42
+ * drop_prob
43
+ * total_size
44
+ / clipped_block_size**2
45
+ / ((W - block_size + 1) * (H - block_size + 1))
46
+ )
47
+
48
+ # Forces the block to be inside the feature map.
49
+ w_i, h_i = torch.meshgrid(
50
+ torch.arange(W).to(x.device), torch.arange(H).to(x.device)
51
+ )
52
+ valid_block = (
53
+ (w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)
54
+ ) & ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
55
+ valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
56
+
57
+ if batchwise:
58
+ # one mask for whole batch, quite a bit faster
59
+ uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
60
+ else:
61
+ uniform_noise = torch.rand_like(x)
62
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
63
+ block_mask = -F.max_pool2d(
64
+ -block_mask,
65
+ kernel_size=clipped_block_size, # block_size,
66
+ stride=1,
67
+ padding=clipped_block_size // 2,
68
+ )
69
+
70
+ if with_noise:
71
+ normal_noise = (
72
+ torch.randn((1, C, H, W), dtype=x.dtype, device=x.device)
73
+ if batchwise
74
+ else torch.randn_like(x)
75
+ )
76
+ if inplace:
77
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
78
+ else:
79
+ x = x * block_mask + normal_noise * (1 - block_mask)
80
+ else:
81
+ normalize_scale = (
82
+ block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)
83
+ ).to(x.dtype)
84
+ if inplace:
85
+ x.mul_(block_mask * normalize_scale)
86
+ else:
87
+ x = x * block_mask * normalize_scale
88
+ return x
89
+
90
+
91
+ def drop_block_fast_2d(
92
+ x: torch.Tensor,
93
+ drop_prob: float = 0.1,
94
+ block_size: int = 7,
95
+ gamma_scale: float = 1.0,
96
+ with_noise: bool = False,
97
+ inplace: bool = False,
98
+ ):
99
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
100
+
101
+ DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
102
+ block mask at edges.
103
+ """
104
+ _, _, H, W = x.shape
105
+ total_size = W * H
106
+ clipped_block_size = min(block_size, min(W, H))
107
+ gamma = (
108
+ gamma_scale
109
+ * drop_prob
110
+ * total_size
111
+ / clipped_block_size**2
112
+ / ((W - block_size + 1) * (H - block_size + 1))
113
+ )
114
+
115
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
116
+ block_mask = F.max_pool2d(
117
+ block_mask.to(x.dtype),
118
+ kernel_size=clipped_block_size,
119
+ stride=1,
120
+ padding=clipped_block_size // 2,
121
+ )
122
+
123
+ if with_noise:
124
+ normal_noise = torch.empty_like(x).normal_()
125
+ if inplace:
126
+ x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
127
+ else:
128
+ x = x * (1.0 - block_mask) + normal_noise * block_mask
129
+ else:
130
+ block_mask = 1 - block_mask
131
+ normalize_scale = (
132
+ block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)
133
+ ).to(dtype=x.dtype)
134
+ if inplace:
135
+ x.mul_(block_mask * normalize_scale)
136
+ else:
137
+ x = x * block_mask * normalize_scale
138
+ return x
139
+
140
+
141
+ class DropBlock2d(nn.Module):
142
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
143
+
144
+ def __init__(
145
+ self,
146
+ drop_prob: float = 0.1,
147
+ block_size: int = 7,
148
+ gamma_scale: float = 1.0,
149
+ with_noise: bool = False,
150
+ inplace: bool = False,
151
+ batchwise: bool = False,
152
+ fast: bool = True,
153
+ ):
154
+ super(DropBlock2d, self).__init__()
155
+ self.drop_prob = drop_prob
156
+ self.gamma_scale = gamma_scale
157
+ self.block_size = block_size
158
+ self.with_noise = with_noise
159
+ self.inplace = inplace
160
+ self.batchwise = batchwise
161
+ self.fast = fast # FIXME finish comparisons of fast vs not
162
+
163
+ def forward(self, x):
164
+ if not self.training or not self.drop_prob:
165
+ return x
166
+ if self.fast:
167
+ return drop_block_fast_2d(
168
+ x,
169
+ self.drop_prob,
170
+ self.block_size,
171
+ self.gamma_scale,
172
+ self.with_noise,
173
+ self.inplace,
174
+ )
175
+ else:
176
+ return drop_block_2d(
177
+ x,
178
+ self.drop_prob,
179
+ self.block_size,
180
+ self.gamma_scale,
181
+ self.with_noise,
182
+ self.inplace,
183
+ self.batchwise,
184
+ )
185
+
186
+
187
+ def drop_path(
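Editorial sketch: `DropBlock2d` zeroes contiguous block_size x block_size regions during training and rescales the surviving activations; in eval mode (or with drop_prob=0) it is an identity:

import torch

drop = DropBlock2d(drop_prob=0.1, block_size=5)
drop.train()
y_train = drop(torch.randn(2, 16, 28, 28))   # blocks dropped, survivors rescaled
drop.eval()
y_eval = drop(torch.randn(2, 16, 28, 28))    # returned unchanged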
188
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
189
+ ):
190
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
191
+
192
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
193
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
194
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
195
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
196
+ 'survival rate' as the argument.
197
+
198
+ """
199
+ if drop_prob == 0.0 or not training:
200
+ return x
201
+ keep_prob = 1 - drop_prob
202
+ shape = (x.shape[0],) + (1,) * (
203
+ x.ndim - 1
204
+ ) # work with diff dim tensors, not just 2D ConvNets
205
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
206
+ if keep_prob > 0.0 and scale_by_keep:
207
+ random_tensor.div_(keep_prob)
208
+ return x * random_tensor
209
+
210
+
211
+ class DropPath(nn.Module):
212
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
213
+
214
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
215
+ super(DropPath, self).__init__()
216
+ self.drop_prob = drop_prob
217
+ self.scale_by_keep = scale_by_keep
218
+
219
+ def forward(self, x):
220
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
221
+
222
+ def extra_repr(self):
223
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
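For orientation, the sketch below shows how a DropPath layer like the one defined above is typically wired into a residual branch; the TinyResidualBlock wrapper, its sizes, and the drop rate are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn

# Import path as added by this upload; adjust if the package lives elsewhere.
from comfy_extras.chainner_models.architecture.timm.drop import DropPath


class TinyResidualBlock(nn.Module):
    """Illustrative residual block: y = x + drop_path(mlp(x))."""

    def __init__(self, dim: int, drop_path_rate: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        # Identity when the rate is 0, stochastic depth otherwise.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def forward(self, x):
        return x + self.drop_path(self.mlp(x))


block = TinyResidualBlock(dim=64, drop_path_rate=0.2).train()
out = block(torch.randn(8, 64))  # roughly 20% of samples skip the MLP branch per step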
comfy_extras/chainner_models/architecture/timm/helpers.py ADDED
@@ -0,0 +1,31 @@
+""" Layer/Module Helpers
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import collections.abc
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
+    min_value = min_value or divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < round_limit * v:
+        new_v += divisor
+    return new_v
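A quick sketch of what these helpers do for the architectures in this upload; the import path matches the file added above and the values are arbitrary.

# Illustrative usage; assumes the package added in this commit is importable.
from comfy_extras.chainner_models.architecture.timm.helpers import make_divisible, to_2tuple

# Scalars are broadcast to tuples; iterables (other than strings) pass through unchanged.
assert to_2tuple(7) == (7, 7)
assert to_2tuple((3, 5)) == (3, 5)

# Channel counts are rounded to a multiple of `divisor`, never shrinking
# the requested value by more than ~10%.
assert make_divisible(30, divisor=8) == 32
assert make_divisible(20, divisor=8) == 24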
comfy_extras/chainner_models/architecture/timm/weight_init.py ADDED
@@ -0,0 +1,128 @@
+import math
+import warnings
+
+import torch
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+            "The distribution of values may be incorrect.",
+            stacklevel=2,
+        )
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.0))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(
+    tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
+) -> torch.Tensor:
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    NOTE: this impl is similar to the PyTorch trunc_normal_; the bounds [a, b] are
+    applied while sampling the normal with the mean/std applied, so the a, b args
+    should be adjusted to match the range of the mean, std args.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def trunc_normal_tf_(
+    tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
+) -> torch.Tensor:
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    NOTE: this 'tf' variant behaves closer to the Tensorflow / JAX impl, where the
+    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+    and the result is subsequently scaled and shifted by the mean and std args.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
+    with torch.no_grad():
+        tensor.mul_(std).add_(mean)
+    return tensor
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == "fan_in":
+        denom = fan_in
+    elif mode == "fan_out":
+        denom = fan_out
+    elif mode == "fan_avg":
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom  # type: ignore
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+    elif distribution == "normal":
+        tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        # pylint: disable=invalid-unary-operand-type
+        tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
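As a rough usage sketch (import path as added above, layer size arbitrary), this is how the truncated-normal initializer is typically applied when building transformer layers:

import torch.nn as nn

# Assumes the package layout added in this commit.
from comfy_extras.chainner_models.architecture.timm.weight_init import trunc_normal_

linear = nn.Linear(96, 96)
# Weights ~ N(0, 0.02^2), truncated to the default bounds [a, b] = [-2, 2].
trunc_normal_(linear.weight, std=0.02)
nn.init.zeros_(linear.bias)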
comfy_extras/chainner_models/model_loading.py ADDED
@@ -0,0 +1,89 @@
+import logging as logger
+
+from .architecture.face.codeformer import CodeFormer
+from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
+from .architecture.face.restoreformer_arch import RestoreFormer
+from .architecture.HAT import HAT
+from .architecture.LaMa import LaMa
+from .architecture.MAT import MAT
+from .architecture.RRDB import RRDBNet as ESRGAN
+from .architecture.SPSR import SPSRNet as SPSR
+from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
+from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
+from .architecture.Swin2SR import Swin2SR
+from .architecture.SwinIR import SwinIR
+from .types import PyTorchModel
+
+
+class UnsupportedModel(Exception):
+    pass
+
+
+def load_state_dict(state_dict) -> PyTorchModel:
+    logger.debug("Loading state dict into pytorch model arch")
+
+    state_dict_keys = list(state_dict.keys())
+
+    if "params_ema" in state_dict_keys:
+        state_dict = state_dict["params_ema"]
+    elif "params-ema" in state_dict_keys:
+        state_dict = state_dict["params-ema"]
+    elif "params" in state_dict_keys:
+        state_dict = state_dict["params"]
+
+    state_dict_keys = list(state_dict.keys())
+    # SRVGGNet Real-ESRGAN (v2)
+    if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
+        model = RealESRGANv2(state_dict)
+    # SPSR (ESRGAN with lots of extra layers)
+    elif "f_HR_conv1.0.weight" in state_dict:
+        model = SPSR(state_dict)
+    # Swift-SRGAN
+    elif (
+        "model" in state_dict_keys
+        and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
+    ):
+        model = SwiftSRGAN(state_dict)
+    # HAT -- be sure it is above SwinIR
+    elif "layers.0.residual_group.blocks.0.conv_block.cab.0.weight" in state_dict_keys:
+        model = HAT(state_dict)
+    # SwinIR
+    elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
+        if "patch_embed.proj.weight" in state_dict_keys:
+            model = Swin2SR(state_dict)
+        else:
+            model = SwinIR(state_dict)
+    # GFPGAN
+    elif (
+        "toRGB.0.weight" in state_dict_keys
+        and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
+    ):
+        model = GFPGANv1Clean(state_dict)
+    # RestoreFormer
+    elif (
+        "encoder.conv_in.weight" in state_dict_keys
+        and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
+    ):
+        model = RestoreFormer(state_dict)
+    elif (
+        "encoder.blocks.0.weight" in state_dict_keys
+        and "quantize.embedding.weight" in state_dict_keys
+    ):
+        model = CodeFormer(state_dict)
+    # LaMa
+    elif (
+        "model.model.1.bn_l.running_mean" in state_dict_keys
+        or "generator.model.1.bn_l.running_mean" in state_dict_keys
+    ):
+        model = LaMa(state_dict)
+    # MAT
+    elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
+        model = MAT(state_dict)
+    # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
+    else:
+        try:
+            model = ESRGAN(state_dict)
+        except:
+            # pylint: disable=raise-missing-from
+            raise UnsupportedModel
+    return model
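A hedged usage sketch for the detection logic above: the checkpoint file name is hypothetical, but the call chain mirrors how the loader is intended to be driven (raw state dict in, ready-to-use architecture out).

import torch

from comfy_extras.chainner_models import model_loading

# Hypothetical checkpoint path; any supported ESRGAN/SwinIR/HAT/... state dict works.
state_dict = torch.load("4x_example_model.pth", map_location="cpu")
model = model_loading.load_state_dict(state_dict).eval()

# The wrapped architectures expose metadata such as the upscale factor,
# which nodes_upscale_model.py relies on below.
print(type(model).__name__, getattr(model, "scale", None))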
comfy_extras/chainner_models/types.py ADDED
@@ -0,0 +1,53 @@
+from typing import Union
+
+from .architecture.face.codeformer import CodeFormer
+from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
+from .architecture.face.restoreformer_arch import RestoreFormer
+from .architecture.HAT import HAT
+from .architecture.LaMa import LaMa
+from .architecture.MAT import MAT
+from .architecture.RRDB import RRDBNet as ESRGAN
+from .architecture.SPSR import SPSRNet as SPSR
+from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
+from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
+from .architecture.Swin2SR import Swin2SR
+from .architecture.SwinIR import SwinIR
+
+PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
+PyTorchSRModel = Union[
+    RealESRGANv2,
+    SPSR,
+    SwiftSRGAN,
+    ESRGAN,
+    SwinIR,
+    Swin2SR,
+    HAT,
+]
+
+
+def is_pytorch_sr_model(model: object):
+    return isinstance(model, PyTorchSRModels)
+
+
+PyTorchFaceModels = (GFPGANv1Clean, RestoreFormer, CodeFormer)
+PyTorchFaceModel = Union[GFPGANv1Clean, RestoreFormer, CodeFormer]
+
+
+def is_pytorch_face_model(model: object):
+    return isinstance(model, PyTorchFaceModels)
+
+
+PyTorchInpaintModels = (LaMa, MAT)
+PyTorchInpaintModel = Union[LaMa, MAT]
+
+
+def is_pytorch_inpaint_model(model: object):
+    return isinstance(model, PyTorchInpaintModels)
+
+
+PyTorchModels = (*PyTorchSRModels, *PyTorchFaceModels, *PyTorchInpaintModels)
+PyTorchModel = Union[PyTorchSRModel, PyTorchFaceModel, PyTorchInpaintModel]
+
+
+def is_pytorch_model(model: object):
+    return isinstance(model, PyTorchModels)
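These tuples double as runtime type guards; below is a small sketch of how a caller might branch on them (the model object is assumed to come from model_loading.load_state_dict above).

from comfy_extras.chainner_models import types


def describe(model: types.PyTorchModel) -> str:
    # Branch on the runtime guards defined above.
    if types.is_pytorch_sr_model(model):
        return "super-resolution model"
    if types.is_pytorch_face_model(model):
        return "face-restoration model"
    if types.is_pytorch_inpaint_model(model):
        return "inpainting model"
    return "unsupported model"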
comfy_extras/nodes_upscale_model.py ADDED
@@ -0,0 +1,48 @@
+import os
+from comfy_extras.chainner_models import model_loading
+import model_management
+import torch
+import comfy.utils
+import folder_paths
+
+class UpscaleModelLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
+                             }}
+    RETURN_TYPES = ("UPSCALE_MODEL",)
+    FUNCTION = "load_model"
+
+    CATEGORY = "loaders"
+
+    def load_model(self, model_name):
+        model_path = folder_paths.get_full_path("upscale_models", model_name)
+        sd = comfy.utils.load_torch_file(model_path)
+        out = model_loading.load_state_dict(sd).eval()
+        return (out, )
+
+
+class ImageUpscaleWithModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "upscale_model": ("UPSCALE_MODEL",),
+                              "image": ("IMAGE",),
+                              }}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "upscale"
+
+    CATEGORY = "image/upscaling"
+
+    def upscale(self, upscale_model, image):
+        device = model_management.get_torch_device()
+        upscale_model.to(device)
+        in_img = image.movedim(-1,-3).to(device)
+        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale)
+        upscale_model.cpu()
+        s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
+        return (s,)
+
+NODE_CLASS_MAPPINGS = {
+    "UpscaleModelLoader": UpscaleModelLoader,
+    "ImageUpscaleWithModel": ImageUpscaleWithModel
+}
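Finally, a rough sketch of driving these two nodes outside the graph UI; the model file name is hypothetical, and the image tensor uses ComfyUI's usual [batch, height, width, channel] float layout in [0, 1]. This assumes a running ComfyUI environment where folder_paths and model_management are importable.

import torch

from comfy_extras.nodes_upscale_model import ImageUpscaleWithModel, UpscaleModelLoader

# Hypothetical file placed in the configured upscale_models folder.
(upscale_model,) = UpscaleModelLoader().load_model("4x_example_model.pth")

image = torch.rand(1, 512, 512, 3)  # [B, H, W, C] in [0, 1]
(upscaled,) = ImageUpscaleWithModel().upscale(upscale_model, image)
print(upscaled.shape)  # e.g. [1, 2048, 2048, 3] for a 4x model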