Ray-1026 commited on
Commit
a856109
·
1 Parent(s): ef36a49
.gitattributes CHANGED
@@ -35,3 +35,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  assets/* filter=lfs diff=lfs merge=lfs -text
37
  assets/exp.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  assets/* filter=lfs diff=lfs merge=lfs -text
37
  assets/exp.png filter=lfs diff=lfs merge=lfs -text
38
+ weights/light_outpaint_lora filter=lfs diff=lfs merge=lfs -text
39
+ weights/light_regress filter=lfs diff=lfs merge=lfs -text
40
+ weights/net_g_last.pth filter=lfs diff=lfs merge=lfs -text
41
+ weights/light_outpaint_lora/pytorch_lora_weights.safetensors filter=lfs diff=lfs merge=lfs -text
42
+ weights/light_regress/model.pth filter=lfs diff=lfs merge=lfs -text
SIFR_models/flare7kpp/__pycache__/model.cpython-39.pyc ADDED
Binary file (52.1 kB). View file
 
SIFR_models/flare7kpp/model.py ADDED
The diff for this file is too large to render. See raw diff
 
SIFR_models/mfdnet/backbone.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class LayerNormFunction(torch.autograd.Function):
7
+
8
+ @staticmethod
9
+ def forward(ctx, x, weight, bias, eps):
10
+ ctx.eps = eps
11
+ N, C, H, W = x.size()
12
+ mu = x.mean(1, keepdim=True)
13
+ var = (x - mu).pow(2).mean(1, keepdim=True)
14
+ y = (x - mu) / (var + eps).sqrt()
15
+ ctx.save_for_backward(y, var, weight)
16
+ y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
17
+ return y
18
+
19
+ @staticmethod
20
+ def backward(ctx, grad_output):
21
+ eps = ctx.eps
22
+
23
+ N, C, H, W = grad_output.size()
24
+ y, var, weight = ctx.saved_variables
25
+ g = grad_output * weight.view(1, C, 1, 1)
26
+ mean_g = g.mean(dim=1, keepdim=True)
27
+
28
+ mean_gy = (g * y).mean(dim=1, keepdim=True)
29
+ gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
30
+ return (
31
+ gx,
32
+ (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
33
+ grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
34
+ None,
35
+ )
36
+
37
+
38
+ class LayerNorm2d(nn.Module):
39
+
40
+ def __init__(self, channels, eps=1e-6):
41
+ super(LayerNorm2d, self).__init__()
42
+ self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
43
+ self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
44
+ self.eps = eps
45
+
46
+ def forward(self, x):
47
+ return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
48
+
49
+
50
+ class SimpleGate(nn.Module):
51
+ def forward(self, x):
52
+ x1, x2 = x.chunk(2, dim=1)
53
+ return x1 * x2
54
+
55
+
56
+ class NAFBlock(nn.Module):
57
+ def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0):
58
+ super().__init__()
59
+ dw_channel = c * DW_Expand
60
+ self.conv1 = nn.Conv2d(
61
+ in_channels=c,
62
+ out_channels=dw_channel,
63
+ kernel_size=1,
64
+ padding=0,
65
+ stride=1,
66
+ groups=1,
67
+ bias=True,
68
+ )
69
+ self.conv2 = nn.Conv2d(
70
+ in_channels=dw_channel,
71
+ out_channels=dw_channel,
72
+ kernel_size=3,
73
+ padding=1,
74
+ stride=1,
75
+ groups=dw_channel,
76
+ bias=True,
77
+ )
78
+ self.conv3 = nn.Conv2d(
79
+ in_channels=dw_channel // 2,
80
+ out_channels=c,
81
+ kernel_size=1,
82
+ padding=0,
83
+ stride=1,
84
+ groups=1,
85
+ bias=True,
86
+ )
87
+
88
+ # Simplified Channel Attention
89
+ self.sca = nn.Sequential(
90
+ nn.AdaptiveAvgPool2d(1),
91
+ nn.Conv2d(
92
+ in_channels=dw_channel // 2,
93
+ out_channels=dw_channel // 2,
94
+ kernel_size=1,
95
+ padding=0,
96
+ stride=1,
97
+ groups=1,
98
+ bias=True,
99
+ ),
100
+ )
101
+
102
+ # SimpleGate
103
+ self.sg = SimpleGate()
104
+
105
+ ffn_channel = FFN_Expand * c
106
+ self.conv4 = nn.Conv2d(
107
+ in_channels=c,
108
+ out_channels=ffn_channel,
109
+ kernel_size=1,
110
+ padding=0,
111
+ stride=1,
112
+ groups=1,
113
+ bias=True,
114
+ )
115
+ self.conv5 = nn.Conv2d(
116
+ in_channels=ffn_channel // 2,
117
+ out_channels=c,
118
+ kernel_size=1,
119
+ padding=0,
120
+ stride=1,
121
+ groups=1,
122
+ bias=True,
123
+ )
124
+
125
+ self.norm1 = LayerNorm2d(c)
126
+ self.norm2 = LayerNorm2d(c)
127
+
128
+ self.dropout1 = (
129
+ nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
130
+ )
131
+ self.dropout2 = (
132
+ nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
133
+ )
134
+
135
+ self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
136
+ self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
137
+
138
+ def forward(self, inp):
139
+ x = inp
140
+
141
+ x = self.norm1(x)
142
+
143
+ x = self.conv1(x)
144
+ x = self.conv2(x)
145
+ x = self.sg(x)
146
+ x = x * self.sca(x)
147
+ x = self.conv3(x)
148
+
149
+ x = self.dropout1(x)
150
+
151
+ y = inp + x * self.beta
152
+
153
+ x = self.conv4(self.norm2(y))
154
+ x = self.sg(x)
155
+ x = self.conv5(x)
156
+
157
+ x = self.dropout2(x)
158
+
159
+ return y + x * self.gamma
160
+
161
+
162
+ class NAFNet(nn.Module):
163
+
164
+ def __init__(
165
+ self,
166
+ img_channel=3,
167
+ width=32,
168
+ middle_blk_num=12,
169
+ enc_blk_nums=[2, 2, 4, 8],
170
+ dec_blk_nums=[2, 2, 2, 2],
171
+ ):
172
+ super().__init__()
173
+
174
+ self.intro = nn.Conv2d(
175
+ in_channels=img_channel,
176
+ out_channels=width,
177
+ kernel_size=3,
178
+ padding=1,
179
+ stride=1,
180
+ groups=1,
181
+ bias=True,
182
+ )
183
+ self.ending = nn.Conv2d(
184
+ in_channels=width,
185
+ out_channels=img_channel,
186
+ kernel_size=3,
187
+ padding=1,
188
+ stride=1,
189
+ groups=1,
190
+ bias=True,
191
+ )
192
+
193
+ self.encoders = nn.ModuleList()
194
+ self.decoders = nn.ModuleList()
195
+ self.middle_blks = nn.ModuleList()
196
+ self.ups = nn.ModuleList()
197
+ self.downs = nn.ModuleList()
198
+
199
+ chan = width
200
+ for num in enc_blk_nums:
201
+ self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
202
+ self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
203
+ chan = chan * 2
204
+
205
+ self.middle_blks = nn.Sequential(
206
+ *[NAFBlock(chan) for _ in range(middle_blk_num)]
207
+ )
208
+
209
+ for num in dec_blk_nums:
210
+ self.ups.append(
211
+ nn.Sequential(
212
+ nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2)
213
+ )
214
+ )
215
+ chan = chan // 2
216
+ self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
217
+
218
+ self.padder_size = 2 ** len(self.encoders)
219
+
220
+ def forward(self, inp):
221
+ B, C, H, W = inp.shape
222
+ inp = self.check_image_size(inp)
223
+
224
+ x = self.intro(inp)
225
+
226
+ encs = []
227
+
228
+ for encoder, down in zip(self.encoders, self.downs):
229
+ x = encoder(x)
230
+ encs.append(x)
231
+ x = down(x)
232
+
233
+ x = self.middle_blks(x)
234
+
235
+ for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
236
+ x = up(x)
237
+ x = x + enc_skip
238
+ x = decoder(x)
239
+
240
+ x = self.ending(x)
241
+ x = x + inp
242
+
243
+ return x[:, :, :H, :W]
244
+
245
+ def check_image_size(self, x):
246
+ _, _, h, w = x.size()
247
+ mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
248
+ mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
249
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
250
+ return x
251
+
252
+
253
+ if __name__ == "__main__":
254
+ img_channel = 3
255
+ width = 32
256
+
257
+ enc_blks = [2, 2, 4, 8]
258
+ middle_blk_num = 12
259
+ dec_blks = [2, 2, 2, 2]
260
+
261
+ print(
262
+ "enc blks",
263
+ enc_blks,
264
+ "middle blk num",
265
+ middle_blk_num,
266
+ "dec blks",
267
+ dec_blks,
268
+ "width",
269
+ width,
270
+ )
271
+
272
+ # using('start . ')
273
+ model = NAFNet(
274
+ img_channel=img_channel,
275
+ width=width,
276
+ middle_blk_num=middle_blk_num,
277
+ enc_blk_nums=enc_blks,
278
+ dec_blk_nums=dec_blks,
279
+ ).cuda()
280
+
281
+ model.eval()
282
+ input = torch.randn(1, 3, 15, 22).cuda()
283
+ # input = torch.randn(1, 3, 32, 32)
284
+ y = model(input)
285
+ print(y.size())
SIFR_models/mfdnet/blocks.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ConvLayer(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_channels,
10
+ out_channels,
11
+ kernel_size,
12
+ stride,
13
+ dilation=1,
14
+ bias=True,
15
+ groups=1,
16
+ norm="in",
17
+ nonlinear="relu",
18
+ ):
19
+ super(ConvLayer, self).__init__()
20
+ reflection_padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2
21
+ self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
22
+ self.conv2d = nn.Conv2d(
23
+ in_channels,
24
+ out_channels,
25
+ kernel_size,
26
+ stride,
27
+ groups=groups,
28
+ bias=bias,
29
+ dilation=dilation,
30
+ )
31
+ self.norm = norm
32
+ self.nonlinear = nonlinear
33
+
34
+ if norm == "bn":
35
+ self.normalization = nn.BatchNorm2d(out_channels)
36
+ elif norm == "in":
37
+ self.normalization = nn.InstanceNorm2d(out_channels, affine=False)
38
+ else:
39
+ self.normalization = None
40
+
41
+ if nonlinear == "relu":
42
+ self.activation = nn.ReLU(inplace=True)
43
+ elif nonlinear == "leakyrelu":
44
+ self.activation = nn.LeakyReLU(0.2)
45
+ elif nonlinear == "PReLU":
46
+ self.activation = nn.PReLU()
47
+ else:
48
+ self.activation = None
49
+
50
+ def forward(self, x):
51
+ out = self.conv2d(self.reflection_pad(x))
52
+ if self.normalization is not None:
53
+ out = self.normalization(out)
54
+ if self.activation is not None:
55
+ out = self.activation(out)
56
+
57
+ return out
58
+
59
+
60
+ class Aggreation(nn.Module):
61
+ def __init__(self, in_channels, out_channels, kernel_size=3):
62
+ super(Aggreation, self).__init__()
63
+ self.attention = SelfAttention(in_channels, k=8, nonlinear="relu")
64
+ self.conv = ConvLayer(
65
+ in_channels,
66
+ out_channels,
67
+ kernel_size=kernel_size,
68
+ stride=1,
69
+ dilation=1,
70
+ nonlinear="leakyrelu",
71
+ norm=None,
72
+ )
73
+
74
+ def forward(self, x):
75
+ return self.conv(self.attention(x))
76
+
77
+
78
+ class SelfAttention(nn.Module):
79
+ def __init__(self, channels, k, nonlinear="relu"):
80
+ super(SelfAttention, self).__init__()
81
+ self.channels = channels
82
+ self.k = k
83
+ self.nonlinear = nonlinear
84
+
85
+ self.linear1 = nn.Linear(channels, channels // k)
86
+ self.linear2 = nn.Linear(channels // k, channels)
87
+ self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
88
+
89
+ if nonlinear == "relu":
90
+ self.activation = nn.ReLU(inplace=True)
91
+ elif nonlinear == "leakyrelu":
92
+ self.activation = nn.LeakyReLU(0.2)
93
+ elif nonlinear == "PReLU":
94
+ self.activation = nn.PReLU()
95
+ else:
96
+ raise ValueError
97
+
98
+ def attention(self, x):
99
+ N, C, H, W = x.size()
100
+ out = torch.flatten(self.global_pooling(x), 1)
101
+ out = self.activation(self.linear1(out))
102
+ out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1)
103
+
104
+ return out.mul(x)
105
+
106
+ def forward(self, x):
107
+ return self.attention(x)
108
+
109
+
110
+ class SPP(nn.Module):
111
+ def __init__(
112
+ self, in_channels, out_channels, num_layers=4, interpolation_type="bilinear"
113
+ ):
114
+ super(SPP, self).__init__()
115
+ self.conv = nn.ModuleList()
116
+ self.num_layers = num_layers
117
+ self.interpolation_type = interpolation_type
118
+
119
+ for _ in range(self.num_layers):
120
+ self.conv.append(
121
+ ConvLayer(
122
+ in_channels,
123
+ in_channels,
124
+ kernel_size=1,
125
+ stride=1,
126
+ dilation=1,
127
+ nonlinear="leakyrelu",
128
+ norm=None,
129
+ )
130
+ )
131
+
132
+ self.fusion = ConvLayer(
133
+ (in_channels * (self.num_layers + 1)),
134
+ out_channels,
135
+ kernel_size=3,
136
+ stride=1,
137
+ norm="False",
138
+ nonlinear="leakyrelu",
139
+ )
140
+
141
+ def forward(self, x):
142
+
143
+ N, C, H, W = x.size()
144
+ out = []
145
+
146
+ for level in range(self.num_layers):
147
+ out.append(
148
+ F.interpolate(
149
+ self.conv[level](
150
+ F.avg_pool2d(
151
+ x,
152
+ kernel_size=2 * 2 ** (level + 1),
153
+ stride=2 * 2 ** (level + 1),
154
+ padding=2 * 2 ** (level + 1) % 2,
155
+ )
156
+ ),
157
+ size=(H, W),
158
+ mode=self.interpolation_type,
159
+ )
160
+ )
161
+
162
+ out.append(x)
163
+
164
+ return self.fusion(torch.cat(out, dim=1))
SIFR_models/mfdnet/model.py ADDED
@@ -0,0 +1,786 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+
3
+ import einops
4
+ from einops import rearrange
5
+
6
+ from .backbone import *
7
+ from .blocks import *
8
+
9
+
10
+ class ResidualBlock(nn.Module):
11
+ def __init__(self, in_features):
12
+ super(ResidualBlock, self).__init__()
13
+
14
+ self.block = nn.Sequential(
15
+ nn.Conv2d(in_features, in_features, 3, padding=1),
16
+ nn.LeakyReLU(),
17
+ nn.Conv2d(in_features, in_features, 3, padding=1),
18
+ )
19
+
20
+ def forward(self, x):
21
+ return x + self.block(x)
22
+
23
+
24
+ def gauss_kernel(channels=3):
25
+ kernel = torch.tensor(
26
+ [
27
+ [1.0, 4.0, 6.0, 4.0, 1],
28
+ [4.0, 16.0, 24.0, 16.0, 4.0],
29
+ [6.0, 24.0, 36.0, 24.0, 6.0],
30
+ [4.0, 16.0, 24.0, 16.0, 4.0],
31
+ [1.0, 4.0, 6.0, 4.0, 1.0],
32
+ ]
33
+ )
34
+ kernel /= 256.0
35
+ kernel = kernel.repeat(channels, 1, 1, 1)
36
+ return kernel
37
+
38
+
39
+ class LapPyramidConv(nn.Module):
40
+ def __init__(self, num_high=4):
41
+ super(LapPyramidConv, self).__init__()
42
+
43
+ self.num_high = num_high
44
+ self.kernel = gauss_kernel()
45
+
46
+ def downsample(self, x):
47
+ return x[:, :, ::2, ::2]
48
+
49
+ def upsample(self, x):
50
+ cc = torch.cat(
51
+ [
52
+ x,
53
+ torch.zeros(
54
+ x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device
55
+ ),
56
+ ],
57
+ dim=3,
58
+ )
59
+ cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
60
+ cc = cc.permute(0, 1, 3, 2)
61
+ cc = torch.cat(
62
+ [
63
+ cc,
64
+ torch.zeros(
65
+ x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device
66
+ ),
67
+ ],
68
+ dim=3,
69
+ )
70
+ cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
71
+ x_up = cc.permute(0, 1, 3, 2)
72
+ return self.conv_gauss(x_up, 4 * self.kernel)
73
+
74
+ def conv_gauss(self, img, kernel):
75
+ # 对最后两个维度进行填充,(左右上下)
76
+ img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect")
77
+ # 分组卷积
78
+ out = torch.nn.functional.conv2d(
79
+ img, kernel.to(img.device), groups=img.shape[1]
80
+ )
81
+ return out
82
+
83
+ def pyramid_decom(self, img):
84
+ current = img
85
+ pyr = []
86
+ for _ in range(self.num_high):
87
+ filtered = self.conv_gauss(current, self.kernel)
88
+ down = self.downsample(filtered)
89
+ up = self.upsample(down)
90
+ if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
91
+ up = nn.functional.interpolate(
92
+ up, size=(current.shape[2], current.shape[3])
93
+ )
94
+ diff = current - up
95
+ pyr.append(diff)
96
+ current = down
97
+ pyr.append(current)
98
+ return pyr
99
+
100
+ def pyramid_recons(self, pyr):
101
+ image = pyr[-1]
102
+ for level in reversed(pyr[:-1]):
103
+ up = self.upsample(image)
104
+ if up.shape[2] != level.shape[2] or up.shape[3] != level.shape[3]:
105
+ up = nn.functional.interpolate(
106
+ up, size=(level.shape[2], level.shape[3])
107
+ )
108
+ image = up + level
109
+ return image
110
+
111
+
112
+ class TransHigh(nn.Module):
113
+ def __init__(self, num_residual_blocks, num_high=3):
114
+ super(TransHigh, self).__init__()
115
+
116
+ self.num_high = num_high
117
+
118
+ blocks = [nn.Conv2d(9, 64, 3, padding=1), nn.LeakyReLU()]
119
+
120
+ for _ in range(num_residual_blocks):
121
+ blocks += [ResidualBlock(64)]
122
+
123
+ blocks += [nn.Conv2d(64, 3, 3, padding=1)]
124
+
125
+ self.model = nn.Sequential(*blocks)
126
+
127
+ channels = 3
128
+ # Stage1
129
+ self.block1_1 = ConvLayer(
130
+ in_channels=channels,
131
+ out_channels=channels,
132
+ kernel_size=3,
133
+ stride=1,
134
+ dilation=2,
135
+ norm=None,
136
+ nonlinear="leakyrelu",
137
+ )
138
+ self.block1_2 = ConvLayer(
139
+ in_channels=channels,
140
+ out_channels=channels,
141
+ kernel_size=3,
142
+ stride=1,
143
+ dilation=4,
144
+ norm=None,
145
+ nonlinear="leakyrelu",
146
+ )
147
+ self.aggreation1_rgb = Aggreation(
148
+ in_channels=channels * 3, out_channels=channels
149
+ )
150
+ # Stage2
151
+ self.block2_1 = ConvLayer(
152
+ in_channels=channels,
153
+ out_channels=channels,
154
+ kernel_size=3,
155
+ stride=1,
156
+ dilation=8,
157
+ norm=None,
158
+ nonlinear="leakyrelu",
159
+ )
160
+ self.block2_2 = ConvLayer(
161
+ in_channels=channels,
162
+ out_channels=channels,
163
+ kernel_size=3,
164
+ stride=1,
165
+ dilation=16,
166
+ norm=None,
167
+ nonlinear="leakyrelu",
168
+ )
169
+ self.aggreation2_rgb = Aggreation(
170
+ in_channels=channels * 3, out_channels=channels
171
+ )
172
+ # Stage3
173
+ self.block3_1 = ConvLayer(
174
+ in_channels=channels,
175
+ out_channels=channels,
176
+ kernel_size=3,
177
+ stride=1,
178
+ dilation=32,
179
+ norm=None,
180
+ nonlinear="leakyrelu",
181
+ )
182
+ self.block3_2 = ConvLayer(
183
+ in_channels=channels,
184
+ out_channels=channels,
185
+ kernel_size=3,
186
+ stride=1,
187
+ dilation=64,
188
+ norm=None,
189
+ nonlinear="leakyrelu",
190
+ )
191
+ self.aggreation3_rgb = Aggreation(
192
+ in_channels=channels * 3, out_channels=channels
193
+ )
194
+ # self.block_3 = NAFNet(middle_blk_num=2, enc_blk_nums=[
195
+ # 1,1], dec_blk_nums=[1,1])
196
+ self.trans_mask_block_1 = nn.Sequential(
197
+ nn.Conv2d(3, 16, 1), nn.LeakyReLU(), nn.Conv2d(16, 3, 1)
198
+ )
199
+ self.trans_mask_block_2 = nn.Sequential(
200
+ nn.Conv2d(3, 16, 1), nn.LeakyReLU(), nn.Conv2d(16, 3, 1)
201
+ )
202
+
203
+ # self.trans_mask_block = NAFNet(
204
+ # middle_blk_num=1, enc_blk_nums=[1], dec_blk_nums=[1])
205
+ # Stage3
206
+ self.spp_img = SPP(
207
+ in_channels=channels,
208
+ out_channels=channels,
209
+ num_layers=4,
210
+ interpolation_type="bicubic",
211
+ )
212
+ self.block4_1 = nn.Conv2d(
213
+ in_channels=channels, out_channels=3, kernel_size=1, stride=1
214
+ )
215
+
216
+ def forward(self, x, pyr_original, fake_low):
217
+ pyr_result = [fake_low]
218
+ mask = self.model(x)
219
+
220
+ mask = nn.functional.interpolate(
221
+ mask, size=(pyr_original[-2].shape[2], pyr_original[-2].shape[3])
222
+ )
223
+ mask = self.trans_mask_block_1(mask)
224
+ result_highfreq = torch.mul(pyr_original[-2], mask) + pyr_original[-2]
225
+
226
+ # result_highfreq = self.block_3(result_highfreq)
227
+ out1_1 = self.block1_1(result_highfreq)
228
+ out1_2 = self.block1_2(out1_1)
229
+ agg1_rgb = self.aggreation1_rgb(
230
+ torch.cat((result_highfreq, out1_1, out1_2), dim=1)
231
+ )
232
+ pyr_result.append(agg1_rgb)
233
+
234
+ mask = nn.functional.interpolate(
235
+ mask, size=(pyr_original[-3].shape[2], pyr_original[-3].shape[3])
236
+ )
237
+ mask = self.trans_mask_block_2(mask)
238
+ result_highfreq = torch.mul(pyr_original[-3], mask) + pyr_original[-3]
239
+
240
+ # result_highfreq = self.block_3(result_highfreq)
241
+ out2_1 = self.block2_1(result_highfreq)
242
+ out2_2 = self.block2_2(out2_1)
243
+ agg2_rgb = self.aggreation2_rgb(
244
+ torch.cat((result_highfreq, out2_1, out2_2), dim=1)
245
+ )
246
+
247
+ out3_1 = self.block3_1(agg2_rgb)
248
+ out3_2 = self.block3_2(out3_1)
249
+ agg3_rgb = self.aggreation3_rgb(torch.cat((agg2_rgb, out3_1, out3_2), dim=1))
250
+
251
+ spp_rgb = self.spp_img(agg3_rgb)
252
+ out_rgb = self.block4_1(spp_rgb)
253
+
254
+ pyr_result.append(out_rgb)
255
+ pyr_result.reverse()
256
+
257
+ return pyr_result
258
+
259
+
260
+ # Layer Norm
261
+
262
+
263
+ def to_3d(x):
264
+ return rearrange(x, "b c h w -> b (h w) c")
265
+
266
+
267
+ def to_4d(x, h, w):
268
+ return rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
269
+
270
+
271
+ class BiasFree_LayerNorm(nn.Module):
272
+ def __init__(self, normalized_shape):
273
+ super(BiasFree_LayerNorm, self).__init__()
274
+ if isinstance(normalized_shape, numbers.Integral):
275
+ normalized_shape = (normalized_shape,)
276
+ normalized_shape = torch.Size(normalized_shape)
277
+
278
+ assert len(normalized_shape) == 1
279
+
280
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
281
+ self.normalized_shape = normalized_shape
282
+
283
+ def forward(self, x):
284
+ sigma = x.var(-1, keepdim=True, unbiased=False)
285
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
286
+
287
+
288
+ class WithBias_LayerNorm(nn.Module):
289
+ def __init__(self, normalized_shape):
290
+ super(WithBias_LayerNorm, self).__init__()
291
+ if isinstance(normalized_shape, numbers.Integral):
292
+ normalized_shape = (normalized_shape,)
293
+ normalized_shape = torch.Size(normalized_shape)
294
+
295
+ assert len(normalized_shape) == 1
296
+
297
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
298
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
299
+ self.normalized_shape = normalized_shape
300
+
301
+ def forward(self, x):
302
+ mu = x.mean(-1, keepdim=True)
303
+ sigma = x.var(-1, keepdim=True, unbiased=False)
304
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
305
+
306
+
307
+ class LayerNorm(nn.Module):
308
+ def __init__(self, dim, LayerNorm_type):
309
+ super(LayerNorm, self).__init__()
310
+ if LayerNorm_type == "BiasFree":
311
+ self.body = BiasFree_LayerNorm(dim)
312
+ else:
313
+ self.body = WithBias_LayerNorm(dim)
314
+
315
+ def forward(self, x):
316
+ h, w = x.shape[-2:]
317
+ return to_4d(self.body(to_3d(x)), h, w)
318
+
319
+
320
+ # Axis-based Multi-head Self-Attention
321
+
322
+
323
+ class NextAttentionImplZ(nn.Module):
324
+ def __init__(self, num_dims, num_heads, bias) -> None:
325
+ super().__init__()
326
+ self.num_dims = num_dims
327
+ self.num_heads = num_heads
328
+ self.q1 = nn.Conv2d(num_dims, num_dims * 3, kernel_size=1, bias=bias)
329
+ self.q2 = nn.Conv2d(
330
+ num_dims * 3,
331
+ num_dims * 3,
332
+ kernel_size=3,
333
+ padding=1,
334
+ groups=num_dims * 3,
335
+ bias=bias,
336
+ )
337
+ self.q3 = nn.Conv2d(
338
+ num_dims * 3,
339
+ num_dims * 3,
340
+ kernel_size=3,
341
+ padding=1,
342
+ groups=num_dims * 3,
343
+ bias=bias,
344
+ )
345
+
346
+ self.fac = nn.Parameter(torch.ones(1))
347
+ self.fin = nn.Conv2d(num_dims, num_dims, kernel_size=1, bias=bias)
348
+ return
349
+
350
+ def forward(self, x):
351
+ # x: [n, c, h, w]
352
+ n, c, h, w = x.size()
353
+ n_heads, dim_head = self.num_heads, c // self.num_heads
354
+
355
+ def reshape(x):
356
+ return einops.rearrange(
357
+ x, "n (nh dh) h w -> (n nh h) w dh", nh=n_heads, dh=dim_head
358
+ )
359
+
360
+ qkv = self.q3(self.q2(self.q1(x)))
361
+ q, k, v = map(reshape, qkv.chunk(3, dim=1))
362
+ q = F.normalize(q, dim=-1)
363
+ k = F.normalize(k, dim=-1)
364
+
365
+ # fac = dim_head ** -0.5
366
+ res = k.transpose(-2, -1)
367
+ res = torch.matmul(q, res) * self.fac
368
+ res = torch.softmax(res, dim=-1)
369
+
370
+ res = torch.matmul(res, v)
371
+ res = einops.rearrange(
372
+ res, "(n nh h) w dh -> n (nh dh) h w", nh=n_heads, dh=dim_head, n=n, h=h
373
+ )
374
+ res = self.fin(res)
375
+
376
+ return res
377
+
378
+
379
+ # Axis-based Multi-head Self-Attention (row and col attention)
380
+ class NextAttentionZ(nn.Module):
381
+ def __init__(self, num_dims, num_heads=1, bias=True) -> None:
382
+ super().__init__()
383
+ assert num_dims % num_heads == 0
384
+ self.num_dims = num_dims
385
+ self.num_heads = num_heads
386
+ self.row_att = NextAttentionImplZ(num_dims, num_heads, bias)
387
+ self.col_att = NextAttentionImplZ(num_dims, num_heads, bias)
388
+ return
389
+
390
+ def forward(self, x: torch.Tensor):
391
+ assert len(x.size()) == 4
392
+
393
+ x = self.row_att(x)
394
+ x = x.transpose(-2, -1)
395
+ x = self.col_att(x)
396
+ x = x.transpose(-2, -1)
397
+
398
+ return x
399
+
400
+
401
+ # Dual Gated Feed-Forward Networ
402
+ class FeedForward(nn.Module):
403
+ def __init__(self, dim, ffn_expansion_factor, bias):
404
+ super(FeedForward, self).__init__()
405
+
406
+ hidden_features = int(dim * ffn_expansion_factor)
407
+
408
+ self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
409
+
410
+ self.dwconv = nn.Conv2d(
411
+ hidden_features * 2,
412
+ hidden_features * 2,
413
+ kernel_size=3,
414
+ stride=1,
415
+ padding=1,
416
+ groups=hidden_features * 2,
417
+ bias=bias,
418
+ )
419
+
420
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
421
+
422
+ def forward(self, x):
423
+ x = self.project_in(x)
424
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
425
+ x = F.gelu(x2) * x1 + F.gelu(x1) * x2
426
+ x = self.project_out(x)
427
+ return x
428
+
429
+
430
+ # Axis-based Transformer Block
431
+ class TransformerBlock(nn.Module):
432
+ def __init__(
433
+ self,
434
+ dim,
435
+ num_heads=1,
436
+ ffn_expansion_factor=2.66,
437
+ bias=True,
438
+ LayerNorm_type="WithBias",
439
+ ):
440
+ super(TransformerBlock, self).__init__()
441
+
442
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
443
+ self.attn = NextAttentionZ(dim, num_heads)
444
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
445
+ self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
446
+
447
+ def forward(self, x):
448
+ x = x + self.attn(self.norm1(x))
449
+ x = x + self.ffn(self.norm2(x))
450
+ return x
451
+
452
+
453
+ ##########################################################################
454
+ # Overlapped image patch embedding with 3x3 Conv
455
+ class OverlapPatchEmbed(nn.Module):
456
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
457
+ super(OverlapPatchEmbed, self).__init__()
458
+
459
+ self.proj = nn.Conv2d(
460
+ in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias
461
+ )
462
+
463
+ def forward(self, x):
464
+ x = self.proj(x)
465
+
466
+ return x
467
+
468
+
469
+ ##########################################################################
470
+ # Resizing modules
471
+ class Downsample(nn.Module):
472
+ def __init__(self, n_feat):
473
+ super(Downsample, self).__init__()
474
+
475
+ self.body = nn.Sequential(
476
+ nn.Conv2d(
477
+ n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False
478
+ ),
479
+ nn.PixelUnshuffle(2),
480
+ )
481
+
482
+ def forward(self, x):
483
+ return self.body(x)
484
+
485
+
486
+ class Upsample(nn.Module):
487
+ def __init__(self, n_feat):
488
+ super(Upsample, self).__init__()
489
+
490
+ self.body = nn.Sequential(
491
+ nn.Conv2d(
492
+ n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False
493
+ ),
494
+ nn.PixelShuffle(2),
495
+ )
496
+
497
+ def forward(self, x):
498
+ return self.body(x)
499
+
500
+
501
+ # Cross-layer Attention Fusion Block
502
+ class LAM_Module_v2(nn.Module):
503
+ """Layer attention module"""
504
+
505
+ def __init__(self, in_dim, bias=True):
506
+ super(LAM_Module_v2, self).__init__()
507
+ self.chanel_in = in_dim
508
+
509
+ self.temperature = nn.Parameter(torch.ones(1))
510
+
511
+ self.qkv = nn.Conv2d(
512
+ self.chanel_in, self.chanel_in * 3, kernel_size=1, bias=bias
513
+ )
514
+ self.qkv_dwconv = nn.Conv2d(
515
+ self.chanel_in * 3,
516
+ self.chanel_in * 3,
517
+ kernel_size=3,
518
+ stride=1,
519
+ padding=1,
520
+ groups=self.chanel_in * 3,
521
+ bias=bias,
522
+ )
523
+ self.project_out = nn.Conv2d(
524
+ self.chanel_in, self.chanel_in, kernel_size=1, bias=bias
525
+ )
526
+
527
+ def forward(self, x):
528
+ """
529
+ inputs :
530
+ x : input feature maps( B X N X C X H X W)
531
+ returns :
532
+ out : attention value + input feature
533
+ attention: B X N X N
534
+ """
535
+ m_batchsize, N, C, height, width = x.size()
536
+
537
+ x_input = x.view(m_batchsize, N * C, height, width)
538
+ qkv = self.qkv_dwconv(self.qkv(x_input))
539
+ q, k, v = qkv.chunk(3, dim=1)
540
+ q = q.view(m_batchsize, N, -1)
541
+ k = k.view(m_batchsize, N, -1)
542
+ v = v.view(m_batchsize, N, -1)
543
+
544
+ q = torch.nn.functional.normalize(q, dim=-1)
545
+ k = torch.nn.functional.normalize(k, dim=-1)
546
+
547
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
548
+ attn = attn.softmax(dim=-1)
549
+
550
+ out_1 = attn @ v
551
+ out_1 = out_1.view(m_batchsize, -1, height, width)
552
+
553
+ out_1 = self.project_out(out_1)
554
+ out_1 = out_1.view(m_batchsize, N, C, height, width)
555
+
556
+ out = out_1 + x
557
+ out = out.view(m_batchsize, -1, height, width)
558
+ return out
559
+
560
+
561
+ ##########################################################################
562
+ # ---------- LLFormer -----------------------
563
+ class Backbone(nn.Module):
564
+ def __init__(
565
+ self,
566
+ inp_channels=3,
567
+ out_channels=3,
568
+ dim=3,
569
+ num_blocks=[1, 2, 4, 8],
570
+ num_refinement_blocks=1,
571
+ heads=[1, 2, 4, 8],
572
+ ffn_expansion_factor=2.66,
573
+ bias=False,
574
+ LayerNorm_type="WithBias",
575
+ attention=True,
576
+ ):
577
+ super(Backbone, self).__init__()
578
+
579
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
580
+
581
+ self.encoder_1 = nn.Sequential(
582
+ *[
583
+ TransformerBlock(
584
+ dim=dim,
585
+ num_heads=heads[0],
586
+ ffn_expansion_factor=ffn_expansion_factor,
587
+ bias=bias,
588
+ LayerNorm_type=LayerNorm_type,
589
+ )
590
+ for _ in range(num_blocks[0])
591
+ ]
592
+ )
593
+
594
+ self.encoder_2 = nn.Sequential(
595
+ *[
596
+ TransformerBlock(
597
+ dim=int(dim),
598
+ num_heads=heads[0],
599
+ ffn_expansion_factor=ffn_expansion_factor,
600
+ bias=bias,
601
+ LayerNorm_type=LayerNorm_type,
602
+ )
603
+ for _ in range(num_blocks[0])
604
+ ]
605
+ )
606
+
607
+ self.encoder_3 = nn.Sequential(
608
+ *[
609
+ TransformerBlock(
610
+ dim=int(dim),
611
+ num_heads=heads[0],
612
+ ffn_expansion_factor=ffn_expansion_factor,
613
+ bias=bias,
614
+ LayerNorm_type=LayerNorm_type,
615
+ )
616
+ for _ in range(num_blocks[0])
617
+ ]
618
+ )
619
+
620
+ self.layer_fussion = LAM_Module_v2(in_dim=int(dim * 3))
621
+ self.conv_fuss = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias)
622
+
623
+ # self.latent = nn.Sequential(*[
624
+ # TransformerBlock(dim=int(dim), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
625
+ # LayerNorm_type=LayerNorm_type) for _ in range(num_blocks[0])])
626
+
627
+ # self.trans_low = NAFNet()
628
+
629
+ # self.coefficient_1_0 = nn.Parameter(torch.ones(
630
+ # (2, int(int(dim)))), requires_grad=attention)
631
+
632
+ self.latent_1 = nn.Sequential(
633
+ *[
634
+ TransformerBlock(
635
+ dim=int(dim),
636
+ num_heads=heads[0],
637
+ ffn_expansion_factor=ffn_expansion_factor,
638
+ bias=bias,
639
+ LayerNorm_type=LayerNorm_type,
640
+ )
641
+ for _ in range(num_blocks[0])
642
+ ]
643
+ )
644
+ """
645
+ self.latent_2 = nn.Sequential(*[
646
+ TransformerBlock(dim=int(dim), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
647
+ LayerNorm_type=LayerNorm_type) for _ in range(num_blocks[0])])
648
+ """
649
+ self.trans_low_1 = NAFNet(
650
+ middle_blk_num=10, enc_blk_nums=[1, 2, 4], dec_blk_nums=[4, 2, 1]
651
+ )
652
+ # self.trans_low_2 = NAFNet()
653
+
654
+ self.coefficient_1_0 = nn.Parameter(
655
+ torch.ones((2, int(int(dim)))), requires_grad=attention
656
+ )
657
+
658
+ # self.coefficient_2_0 = nn.Parameter(torch.ones(
659
+ # (2, int(int(dim)))), requires_grad=attention)
660
+
661
+ self.refinement_1 = nn.Sequential(
662
+ *[
663
+ TransformerBlock(
664
+ dim=int(dim),
665
+ num_heads=heads[0],
666
+ ffn_expansion_factor=ffn_expansion_factor,
667
+ bias=bias,
668
+ LayerNorm_type=LayerNorm_type,
669
+ )
670
+ for _ in range(num_refinement_blocks)
671
+ ]
672
+ )
673
+ self.refinement_2 = nn.Sequential(
674
+ *[
675
+ TransformerBlock(
676
+ dim=int(dim),
677
+ num_heads=heads[0],
678
+ ffn_expansion_factor=ffn_expansion_factor,
679
+ bias=bias,
680
+ LayerNorm_type=LayerNorm_type,
681
+ )
682
+ for _ in range(num_refinement_blocks)
683
+ ]
684
+ )
685
+ self.refinement_3 = nn.Sequential(
686
+ *[
687
+ TransformerBlock(
688
+ dim=int(dim),
689
+ num_heads=heads[0],
690
+ ffn_expansion_factor=ffn_expansion_factor,
691
+ bias=bias,
692
+ LayerNorm_type=LayerNorm_type,
693
+ )
694
+ for _ in range(num_refinement_blocks)
695
+ ]
696
+ )
697
+
698
+ self.layer_fussion_2 = LAM_Module_v2(in_dim=int(dim * 3))
699
+ self.conv_fuss_2 = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias)
700
+
701
+ self.output = nn.Conv2d(
702
+ int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias
703
+ )
704
+
705
+ def forward(self, inp):
706
+ inp_enc_encoder1 = self.patch_embed(inp)
707
+ out_enc_encoder1 = self.encoder_1(inp_enc_encoder1)
708
+ out_enc_encoder2 = self.encoder_2(out_enc_encoder1)
709
+ out_enc_encoder3 = self.encoder_3(out_enc_encoder2)
710
+
711
+ inp_fusion_123 = torch.cat(
712
+ [
713
+ out_enc_encoder1.unsqueeze(1),
714
+ out_enc_encoder2.unsqueeze(1),
715
+ out_enc_encoder3.unsqueeze(1),
716
+ ],
717
+ dim=1,
718
+ )
719
+
720
+ out_fusion_123 = self.layer_fussion(inp_fusion_123)
721
+ out_fusion_123 = self.conv_fuss(out_fusion_123)
722
+
723
+ # out_enc = self.trans_low(out_fusion_123)
724
+
725
+ # out_fusion_123 = self.latent(out_fusion_123)
726
+
727
+ # out = self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123 + self.coefficient_1_0[1, :][None, :,None, None] * out_enc
728
+
729
+ out_enc_1 = self.trans_low_1(out_fusion_123)
730
+
731
+ out_fusion_123_1 = self.latent_1(out_fusion_123)
732
+
733
+ out = (
734
+ self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123_1
735
+ + self.coefficient_1_0[1, :][None, :, None, None] * out_enc_1
736
+ )
737
+ # out_enc_2 = self.trans_low_2(out)
738
+
739
+ # out_fusion_123_2 = self.latent_2(out)
740
+
741
+ # out = self.coefficient_2_0[0, :][None, :, None, None] * out_fusion_123_2 + self.coefficient_2_0[1, :][None, :,None, None] * out_enc_2
742
+ out_1 = self.refinement_1(out)
743
+ out_2 = self.refinement_2(out_1)
744
+ out_3 = self.refinement_3(out_2)
745
+
746
+ inp_fusion = torch.cat(
747
+ [out_1.unsqueeze(1), out_2.unsqueeze(1), out_3.unsqueeze(1)], dim=1
748
+ )
749
+ out_fusion_123 = self.layer_fussion_2(inp_fusion)
750
+ out = self.conv_fuss_2(out_fusion_123)
751
+ result = self.output(out)
752
+
753
+ return result
754
+
755
+
756
+ class Model(nn.Module):
757
+ def __init__(self, depth=2):
758
+ super(Model, self).__init__()
759
+ self.backbone = Backbone()
760
+ self.lap_pyramid = LapPyramidConv(depth)
761
+ self.trans_high = TransHigh(3, num_high=depth)
762
+
763
+ def forward(self, inp):
764
+ pyr_inp = self.lap_pyramid.pyramid_decom(img=inp)
765
+ out_low = self.backbone(pyr_inp[-1])
766
+
767
+ inp_up = nn.functional.interpolate(
768
+ pyr_inp[-1], size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3])
769
+ )
770
+ out_up = nn.functional.interpolate(
771
+ out_low, size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3])
772
+ )
773
+ high_with_low = torch.cat([pyr_inp[-2], inp_up, out_up], 1)
774
+
775
+ pyr_inp_trans = self.trans_high(high_with_low, pyr_inp, out_low)
776
+
777
+ result = self.lap_pyramid.pyramid_recons(pyr_inp_trans)
778
+
779
+ return result
780
+
781
+
782
+ if __name__ == "__main__":
783
+ tensor = torch.randn(1, 3, 1024, 1024).cuda()
784
+ model = Model().cuda()
785
+ output = model(tensor)
786
+ print(output.shape)
app.py CHANGED
@@ -2,13 +2,30 @@ import numpy as np
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
-
6
- # import torch
7
  import spaces
8
  import os
9
  import base64
10
  import json
 
11
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  intro = """
14
  <div style="text-align:center">
@@ -142,19 +159,44 @@ def encode_image(pil_image):
142
  # raise Exception(f"Failed to post: {response}")
143
 
144
 
145
- # # --- Model Loading ---
146
- # dtype = torch.bfloat16
147
- device = "cpu"
148
-
149
- # # Load the model pipeline
150
- # pipe = QwenImageEditPipeline.from_pretrained(
151
- # "Qwen/Qwen-Image-Edit", torch_dtype=dtype
152
- # ).to(device)
153
- # pipe.transformer.__class__ = QwenImageTransformer2DModel
154
- # pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
155
-
156
- # # --- Ahead-of-time compilation ---
157
- # optimize_pipeline_(pipe, image=Image.new("RGB", (1024, 1024)), prompt="prompt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  # --- UI Constants and Helpers ---
160
  MAX_SEED = np.iinfo(np.int32).max
@@ -164,47 +206,137 @@ MAX_SEED = np.iinfo(np.int32).max
164
  @spaces.GPU(duration=120)
165
  def infer(
166
  image,
167
- seed=120,
168
- true_guidance_scale=4.0,
169
  num_inference_steps=50,
 
 
 
 
170
  progress=gr.Progress(track_tqdm=True),
171
  ):
172
  """
173
- Generates an image using the local Qwen-Image diffusers pipeline.
174
  """
175
- # Hardcode the negative prompt as requested
176
- negative_prompt = " "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- # if randomize_seed:
179
- # seed = 42
180
- # seed = random.randint(0, MAX_SEED)
181
 
182
- # Set up the generator for reproducibility
183
- # generator = torch.Generator(device=device).manual_seed(seed)
184
 
185
- # print(f"Calling pipeline with prompt: '{prompt}'")
186
- print(f"Negative Prompt: '{negative_prompt}'")
187
- print(
188
- f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  )
190
- # if rewrite_prompt:
191
- # # prompt = polish_prompt(prompt, image)
192
- # print(f"Rewritten Prompt: {prompt}")
193
 
194
- # Generate the image
195
- # images = pipe(
196
- # image,
197
- # prompt=prompt,
198
- # negative_prompt=negative_prompt,
199
- # num_inference_steps=num_inference_steps,
200
- # generator=generator,
201
- # true_cfg_scale=true_guidance_scale,
202
- # num_images_per_prompt=1,
203
- # ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- images = [Image.open("exp.png")]
206
 
207
- return images[0], images[0], seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
 
210
  # --- Examples and UI Layout ---
@@ -243,6 +375,20 @@ with gr.Blocks(css=css) as demo:
243
  value=42,
244
  )
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
247
 
248
  with gr.Row():
@@ -306,8 +452,12 @@ with gr.Blocks(css=css) as demo:
306
  seed,
307
  true_guidance_scale,
308
  num_inference_steps,
 
 
 
 
309
  ],
310
- outputs=[outpainted_result, flarefree_result, seed],
311
  )
312
 
313
  if __name__ == "__main__":
 
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
+ import torch
 
6
  import spaces
7
  import os
8
  import base64
9
  import json
10
+ import torchvision
11
  from PIL import Image
12
+ from diffusers import ControlNetModel, DPMSolverMultistepScheduler
13
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
14
+
15
+ from src.pipelines.pipeline_stable_diffusion_outpaint import OutpaintPipeline
16
+ from src.pipelines.pipeline_controlnet_outpaint import ControlNetOutpaintPipeline
17
+ from src.schedulers.scheduling_pndm import CustomScheduler
18
+ from src.models.unet import U_Net
19
+ from src.models.light_source_regressor import LightSourceRegressor
20
+ from utils.dataset import HFCustomImageLoader
21
+ from utils.utils import (
22
+ blend_with_alpha,
23
+ load_mfdnet_checkpoint,
24
+ predict_flare_from_6_channel,
25
+ predict_flare_from_3_channel,
26
+ blend_light_source,
27
+ )
28
+ from SIFR_models.flare7kpp.model import Uformer
29
 
30
  intro = """
31
  <div style="text-align:center">
 
159
  # raise Exception(f"Failed to post: {response}")
160
 
161
 
162
+ ## --- Model Loading --- ##
163
+ device = "cuda" if torch.cuda.is_available() else "cpu"
164
+ dtype = torch.float16 if device == "cuda" else torch.float32
165
+ print(f"Using device: {device}")
166
+
167
+ # controlnet
168
+ controlnet = ControlNetModel.from_pretrained(
169
+ "RayTsai-030/LightsOut-controlnet", torch_dtype=dtype
170
+ )
171
+
172
+ # outpainter
173
+ pipe = ControlNetOutpaintPipeline.from_pretrained(
174
+ "stabilityai/stable-diffusion-2-inpainting", controlnet=controlnet, torch_dtype=dtype
175
+ ).to(device)
176
+ pipe.scheduler = CustomScheduler.from_config(pipe.scheduler.config)
177
+ pipe.unet.load_attn_procs("./weights/light_outpaint_lora", use_safetensors=True)
178
+
179
+ # blip
180
+ processor = Blip2Processor.from_pretrained(
181
+ "Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24"
182
+ )
183
+ blip2 = Blip2ForConditionalGeneration.from_pretrained(
184
+ "Salesforce/blip2-opt-2.7b",
185
+ torch_dtype=dtype,
186
+ revision="51572668da0eb669e01a189dc22abe6088589a24",
187
+ )
188
+ blip2 = blip2.to(device)
189
+
190
+ # light regressor
191
+ lsr_module = LightSourceRegressor()
192
+ ckpt = torch.load("./weights/light_regress/model.pth")
193
+ lsr_module.load_state_dict(ckpt["model"])
194
+ lsr_module.to(device)
195
+ lsr_module.eval()
196
+
197
+ # SIFR model
198
+ sifr_model = Uformer(img_size=512, img_ch=3, output_ch=6).to(device)
199
+ sifr_model.load_state_dict(torch.load("./weights/net_g_last.pth"))
200
 
201
  # --- UI Constants and Helpers ---
202
  MAX_SEED = np.iinfo(np.int32).max
 
206
  @spaces.GPU(duration=120)
207
  def infer(
208
  image,
209
+ seed=42,
210
+ cfg=7.5,
211
  num_inference_steps=50,
212
+ left_outpaint=64,
213
+ right_outpaint=64,
214
+ up_outpaint=64,
215
+ down_outpaint=64,
216
  progress=gr.Progress(track_tqdm=True),
217
  ):
218
  """
219
+ Generates an image
220
  """
221
+ # dataset
222
+ dataset = HFCustomImageLoader(image, left_outpaint, right_outpaint, up_outpaint, down_outpaint)
223
+ data = dataset[0]
224
+
225
+ # generator
226
+ generator = torch.Generator(device=device).manual_seed(seed)
227
+
228
+ # transformation
229
+ transform = torchvision.transforms.Compose(
230
+ [
231
+ torchvision.transforms.ToTensor(),
232
+ torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
233
+ ]
234
+ )
235
+ sifr_transform = torchvision.transforms.Compose(
236
+ [
237
+ torchvision.transforms.ToTensor(),
238
+ torchvision.transforms.Resize((512, 512)),
239
+ ]
240
+ )
241
+
242
+ threshold = 0.5
243
+
244
+ with torch.no_grad():
245
+ input_img = data["input_img"]
246
+
247
+ input_img = transform(input_img).unsqueeze(0).to(device)
248
 
249
+ pred_mask = lsr_module.forward_render(input_img)
 
 
250
 
251
+ pred_mask = (pred_mask > threshold).float()
 
252
 
253
+ if pred_mask.device != "cpu":
254
+ pred_mask = pred_mask.cpu()
255
+ pred_mask = pred_mask.numpy()
256
+
257
+ data["control_img"] = Image.fromarray(
258
+ (pred_mask[0, 0] * 255).astype(np.uint8)
259
+ )
260
+
261
+ # print("Finish light source detection...")
262
+
263
+ # prepare text prompt
264
+ inputs = processor(data["blip_img"], return_tensors="pt").to(
265
+ device=device, dtype=dtype
266
+ )
267
+ generate_id = blip2.generate(**inputs, max_new_tokens=20)
268
+ generated_text = processor.batch_decode(generate_id, skip_special_tokens=True)[
269
+ 0
270
+ ].strip()
271
+
272
+ generated_text += (
273
+ ", dynamic lighting, intense light source, prominent lens flare, best quality, high resolution, masterpiece, intricate details"
274
+ # ", full light sources with lens flare, best quality, high resolution"
275
  )
 
 
 
276
 
277
+ # print(f"Generated text prompt: {generated_text}")
278
+
279
+ # Blur mask
280
+ # data["mask_img"] = data["mask_img"].filter(ImageFilter.GaussianBlur(15))
281
+
282
+ # denoise
283
+ outpaint_result = pipe(
284
+ prompt=generated_text,
285
+ negative_prompt="NSFW, (word:1.5), watermark, blurry, missing body, amputation, mutilation",
286
+ image=data["input_img"],
287
+ mask_image=data["mask_img"],
288
+ control_image=data["control_img"],
289
+ num_inference_steps=num_inference_steps,
290
+ guidance_scale=cfg,
291
+ generator=generator,
292
+ repeat_time=4,
293
+ ).images[0]
294
+
295
+ # save result
296
+ outpaint_result = np.array(outpaint_result)
297
+ input_img = np.array(data["input_img"])
298
+ box = data["box"]
299
+
300
+ input_img2 = outpaint_result.copy()
301
+ input_img2[box[2] : box[3] + 1, box[0] : box[1] + 1] = input_img[
302
+ box[2] : box[3] + 1, box[0] : box[1] + 1
303
+ ]
304
 
305
+ outpaint_result = blend_with_alpha(outpaint_result, input_img2, box, blur_size=31)
306
 
307
+ outpaint_result = Image.fromarray(outpaint_result.astype(np.uint8))
308
+
309
+ # print("Finish outpainting...")
310
+
311
+ # flare removal
312
+ img = sifr_transform(outpaint_result).unsqueeze(0).cuda()
313
+
314
+ with torch.no_grad():
315
+ output_img = sifr_model(img)
316
+
317
+ gamma = torch.Tensor([2.2])
318
+
319
+ # flare7k++
320
+ deflare_result, _, _ = predict_flare_from_6_channel(output_img, gamma)
321
+
322
+ # # mfdnet
323
+ # flare_mask = torch.zeros_like(img)
324
+ # deflare_img, _ = predict_flare_from_3_channel(
325
+ # output_img, flare_mask, output_img, img, img, gamma
326
+ # )
327
+ # deflare_img = blend_light_source(img, deflare_img, 0.999)
328
+
329
+ if deflare_result.device != "cpu":
330
+ deflare_result = deflare_result.cpu()
331
+ deflare_result = deflare_result.squeeze(0).permute(1, 2, 0).numpy()
332
+ deflare_result = np.clip(deflare_result, 0.0, 1.0)
333
+ deflare_result = (deflare_result * 255).astype(np.uint8)
334
+ deflare_result = deflare_result[box[2] : box[3] + 1, box[0] : box[1] + 1, :]
335
+ deflare_result = Image.fromarray(deflare_result).resize((512, 512), Image.LANCZOS)
336
+
337
+ # print("Finish flare removal...")
338
+
339
+ return outpaint_result, deflare_result
340
 
341
 
342
  # --- Examples and UI Layout ---
 
375
  value=42,
376
  )
377
 
378
+ with gr.Column():
379
+ left_outpaint = gr.Slider(
380
+ label="Left outpaint (px)", minimum=0, maximum=128, step=1, value=64
381
+ )
382
+ right_outpaint = gr.Slider(
383
+ label="Right outpaint (px)", minimum=0, maximum=128, step=1, value=64
384
+ )
385
+ up_outpaint = gr.Slider(
386
+ label="Up outpaint (px)", minimum=0, maximum=128, step=1, value=64
387
+ )
388
+ down_outpaint = gr.Slider(
389
+ label="Down outpaint (px)", minimum=0, maximum=128, step=1, value=64
390
+ )
391
+
392
  # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
393
 
394
  with gr.Row():
 
452
  seed,
453
  true_guidance_scale,
454
  num_inference_steps,
455
+ left_outpaint,
456
+ right_outpaint,
457
+ up_outpaint,
458
+ down_outpaint,
459
  ],
460
+ outputs=[outpainted_result, flarefree_result],
461
  )
462
 
463
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,2 +1,15 @@
1
- gradio
2
- pydantic==2.10.6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.44.1
2
+ pydantic==2.10.6
3
+ accelerate==0.21.0
4
+ diffusers==0.23.0
5
+ einops==0.8.0
6
+ huggingface-hub==0.25.2
7
+ imageio==2.36.0
8
+ numpy==1.24.1
9
+ opencv-python==4.10.0.84
10
+ scikit-image==0.24.0
11
+ timm==1.0.11
12
+ transformers==4.36.0
13
+ xformers==0.0.20
14
+ spaces
15
+ pillow
src/models/__pycache__/light_source_regressor.cpython-39.pyc ADDED
Binary file (3.41 kB). View file
 
src/models/__pycache__/unet.cpython-39.pyc ADDED
Binary file (3.75 kB). View file
 
src/models/light_source_regressor.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import init
4
+ from torchvision.models import resnet34, resnet50
5
+ import torchvision.models.vision_transformer as vit
6
+
7
+
8
+ class LightSourceRegressor(nn.Module):
9
+ def __init__(self, num_lights=4, alpha=2.0, beta=8.0, **kwargs):
10
+ super(LightSourceRegressor, self).__init__()
11
+
12
+ self.num_lights = num_lights
13
+ self.alpha = alpha
14
+ self.beta = beta
15
+
16
+ self.model = resnet34(pretrained=True)
17
+ # self.model = resnet50(pretrained=True)
18
+ # self.model = vit.vit_b_16(pretrained=True)
19
+ self.init_resnet()
20
+ # self.init_vit()
21
+
22
+ self.xyr_mlp = nn.Sequential(
23
+ nn.Linear(self.last_dim, 3 * self.num_lights),
24
+ )
25
+ self.p_mlp = nn.Sequential(
26
+ nn.Linear(self.last_dim, self.num_lights),
27
+ nn.Sigmoid(), # ensure p is in [0, 1]
28
+ )
29
+
30
+ def init_resnet(self):
31
+ self.last_dim = self.model.fc.in_features
32
+ self.model.fc = nn.Identity()
33
+
34
+ def init_vit(self):
35
+ self.model.image_size = 512
36
+ old_pos_embed = self.model.encoder.pos_embedding
37
+ num_patches_old = (224 // 16) ** 2
38
+ num_patches_new = (512 // 16) ** 2
39
+
40
+ if num_patches_new != num_patches_old:
41
+ old_pos_embed = old_pos_embed[:, 1:]
42
+ old_pos_embed = nn.functional.interpolate(
43
+ old_pos_embed.permute(0, 2, 1), size=(num_patches_new,), mode="linear"
44
+ )
45
+ old_pos_embed = old_pos_embed.permute(0, 2, 1)
46
+
47
+ # new positional embedding
48
+ self.model.encoder.pos_embedding = nn.Parameter(
49
+ torch.cat(
50
+ [self.model.encoder.pos_embedding[:, :1], old_pos_embed], dim=1
51
+ )
52
+ )
53
+
54
+ # num_classes = 4 * self.num_lights # x, y, r, p
55
+ # self.model.heads.head = nn.Linear(self.model.hidden_dim, num_classes)
56
+
57
+ # remove the head
58
+ self.last_dim = self.model.hidden_dim
59
+ self.model.heads.head = nn.Identity()
60
+
61
+ def forward(self, x, height=512, width=512, smoothness=0.1, merge=False):
62
+ _x = self.model(x) # [B, last_dim]
63
+
64
+ _xyr = self.xyr_mlp(_x)
65
+ _xyr = _xyr.view(-1, self.num_lights, 3)
66
+
67
+ _p = self.p_mlp(_x)
68
+ _p = _p.view(-1, self.num_lights)
69
+
70
+ output = torch.cat([_xyr, _p.unsqueeze(-1)], dim=-1)
71
+
72
+ return output
73
+
74
+ def forward_render(self, x, height=512, width=512, smoothness=0.1, merge=False):
75
+ _x = self.forward(x)
76
+
77
+ _xy = _x[:, :, :2]
78
+ _r = _x[:, :, 2]
79
+ _p = _x[:, :, 3]
80
+
81
+ masks = None
82
+ masks_merge = None
83
+ for b in range(_x.size(0)):
84
+ x, y, r = _xy[b, :, 0] * width, _xy[b, :, 1] * width, _r[b] * width / 2
85
+ p = _p[b]
86
+
87
+ mask_list = []
88
+ for i in range(self.num_lights):
89
+ if r[i] < 0 or r[i] > width or p[i] < 0.5:
90
+ continue
91
+
92
+ y_coords, x_coords = torch.meshgrid(
93
+ torch.arange(height, device=x.device),
94
+ torch.arange(width, device=x.device),
95
+ indexing="ij",
96
+ )
97
+
98
+ distances = torch.sqrt((x_coords - x[i]) ** 2 + (y_coords - y[i]) ** 2)
99
+ mask_i = torch.sigmoid(smoothness * (r[i] - distances))
100
+ mask_list.append(mask_i)
101
+
102
+ if len(mask_list) == 0:
103
+ _mask_merge = torch.zeros(1, 1, height, width, device=x.device)
104
+ else:
105
+ _mask_merge = torch.stack(mask_list, dim=0).sum(dim=0).unsqueeze(0)
106
+ _mask_merge = _mask_merge.unsqueeze(0)
107
+
108
+ masks_merge = (
109
+ _mask_merge
110
+ if masks_merge is None
111
+ else torch.cat([masks_merge, _mask_merge], dim=0)
112
+ )
113
+
114
+ masks_merge = torch.clamp(masks_merge, 0, 1)
115
+
116
+ return masks_merge # [B, 1, H, W]
117
+
118
+
119
+ if __name__ == "__main__":
120
+ # pydiffvg.set_use_gpu(torch.cuda.is_available())
121
+ model = LightSourceRegressor(num_lights=4).cuda()
122
+ x = torch.randn(8, 3, 512, 512, device="cuda")
123
+ y = model.forward_render(x)
124
+ print(y.shape)
src/models/unet.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torch.nn import init
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class conv_block(nn.Module):
8
+ def __init__(self, ch_in, ch_out):
9
+ super(conv_block, self).__init__()
10
+ self.conv = nn.Sequential(
11
+ nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
12
+ nn.BatchNorm2d(ch_out),
13
+ nn.ReLU(inplace=True),
14
+ nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
15
+ nn.BatchNorm2d(ch_out),
16
+ nn.ReLU(inplace=True),
17
+ )
18
+
19
+ def forward(self, x):
20
+ x = self.conv(x)
21
+ return x
22
+
23
+
24
+ class up_conv(nn.Module):
25
+ def __init__(self, ch_in, ch_out):
26
+ super(up_conv, self).__init__()
27
+ self.up = nn.Sequential(
28
+ nn.Upsample(scale_factor=2),
29
+ nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
30
+ nn.BatchNorm2d(ch_out),
31
+ nn.ReLU(inplace=True),
32
+ )
33
+
34
+ def forward(self, x):
35
+ x = self.up(x)
36
+ return x
37
+
38
+
39
+ class U_Net(nn.Module):
40
+ def __init__(self, img_ch=3, output_ch=1, multi_stage=False):
41
+ super(U_Net, self).__init__()
42
+
43
+ self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
44
+
45
+ self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
46
+ self.Conv2 = conv_block(ch_in=64, ch_out=128)
47
+ self.Conv3 = conv_block(ch_in=128, ch_out=256)
48
+ self.Conv4 = conv_block(ch_in=256, ch_out=512)
49
+ self.Conv5 = conv_block(ch_in=512, ch_out=1024)
50
+
51
+ self.Up5 = up_conv(ch_in=1024, ch_out=512)
52
+ self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
53
+
54
+ self.Up4 = up_conv(ch_in=512, ch_out=256)
55
+ self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
56
+
57
+ self.Up3 = up_conv(ch_in=256, ch_out=128)
58
+ self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
59
+
60
+ self.Up2 = up_conv(ch_in=128, ch_out=64)
61
+ self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
62
+
63
+ self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
64
+ self.activation = nn.Sequential(nn.Sigmoid())
65
+ # init_weights(self)
66
+ self.apply(self._init_weights)
67
+
68
+ def _init_weights(self, m):
69
+ init_type = "normal"
70
+ gain = 0.02
71
+ classname = m.__class__.__name__
72
+ if hasattr(m, "weight") and (
73
+ classname.find("Conv") != -1 or classname.find("Linear") != -1
74
+ ):
75
+ if init_type == "normal":
76
+ init.normal_(m.weight.data, 0.0, gain)
77
+ elif init_type == "xavier":
78
+ init.xavier_normal_(m.weight.data, gain=gain)
79
+ elif init_type == "kaiming":
80
+ init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
81
+ elif init_type == "orthogonal":
82
+ init.orthogonal_(m.weight.data, gain=gain)
83
+ else:
84
+ raise NotImplementedError(
85
+ "initialization method [%s] is not implemented" % init_type
86
+ )
87
+ if hasattr(m, "bias") and m.bias is not None:
88
+ init.constant_(m.bias.data, 0.0)
89
+ elif classname.find("BatchNorm2d") != -1:
90
+ init.normal_(m.weight.data, 1.0, gain)
91
+ init.constant_(m.bias.data, 0.0)
92
+
93
+ def forward(self, x):
94
+ # encoding path
95
+ x1 = self.Conv1(x)
96
+
97
+ x2 = self.Maxpool(x1)
98
+ x2 = self.Conv2(x2)
99
+
100
+ x3 = self.Maxpool(x2)
101
+ x3 = self.Conv3(x3)
102
+
103
+ x4 = self.Maxpool(x3)
104
+ x4 = self.Conv4(x4)
105
+
106
+ x5 = self.Maxpool(x4)
107
+ x5 = self.Conv5(x5)
108
+
109
+ # decoding + concat path
110
+ d5 = self.Up5(x5)
111
+ d5 = torch.cat((x4, d5), dim=1)
112
+
113
+ d5 = self.Up_conv5(d5)
114
+
115
+ d4 = self.Up4(d5)
116
+ d4 = torch.cat((x3, d4), dim=1)
117
+ d4 = self.Up_conv4(d4)
118
+
119
+ d3 = self.Up3(d4)
120
+ d3 = torch.cat((x2, d3), dim=1)
121
+ d3 = self.Up_conv3(d3)
122
+
123
+ d2 = self.Up2(d3)
124
+ d2 = torch.cat((x1, d2), dim=1)
125
+ d2 = self.Up_conv2(d2)
126
+
127
+ d1 = self.Conv_1x1(d2)
128
+ d1 = self.activation(d1)
129
+ return d1
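Note: a minimal smoke test for the U_Net defined above (a sketch, not part of the commit; it assumes the class is importable from this module). The four max-pool stages mean the input height and width should be divisible by 16, and the 1x1 head plus sigmoid yields a single-channel map in [0, 1].
```py
import torch

# Hedged example: `U_Net` is the class defined in the file above.
model = U_Net(img_ch=3, output_ch=1)
model.eval()
with torch.no_grad():
    x = torch.rand(1, 3, 256, 256)  # dummy RGB batch; H and W divisible by 16
    y = model(x)                    # sigmoid output, shape (1, 1, 256, 256)
print(y.shape)
```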
src/pipelines/__pycache__/pipeline_controlnet_outpaint.cpython-39.pyc ADDED
Binary file (7.49 kB). View file
 
src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc ADDED
Binary file (16 kB). View file
 
src/pipelines/pipeline_controlnet_outpaint.py ADDED
@@ -0,0 +1,448 @@
1
+ import torch
2
+
3
+ from typing import List, Union, Dict, Any, Callable, Optional, Tuple
4
+ from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel
5
+ from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
6
+ from diffusers.models import ControlNetModel
7
+ from diffusers.pipelines.controlnet import MultiControlNetModel
8
+ from diffusers.image_processor import PipelineImageInput
9
+ from diffusers.pipelines.stable_diffusion.pipeline_output import (
10
+ StableDiffusionPipelineOutput,
11
+ )
12
+
13
+
14
+ class ControlNetOutpaintPipeline(StableDiffusionControlNetInpaintPipeline):
15
+ @torch.no_grad()
16
+ def __call__(
17
+ self,
18
+ prompt: Union[str, List[str]] = None,
19
+ image: PipelineImageInput = None,
20
+ mask_image: PipelineImageInput = None,
21
+ control_image: PipelineImageInput = None,
22
+ height: Optional[int] = None,
23
+ width: Optional[int] = None,
24
+ strength: float = 1.0,
25
+ num_inference_steps: int = 50,
26
+ guidance_scale: float = 7.5,
27
+ negative_prompt: Optional[Union[str, List[str]]] = None,
28
+ num_images_per_prompt: Optional[int] = 1,
29
+ eta: float = 0.0,
30
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
31
+ latents: Optional[torch.FloatTensor] = None,
32
+ prompt_embeds: Optional[torch.FloatTensor] = None,
33
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
34
+ output_type: Optional[str] = "pil",
35
+ return_dict: bool = True,
36
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
37
+ callback_steps: int = 1,
38
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
39
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
40
+ guess_mode: bool = False,
41
+ control_guidance_start: Union[float, List[float]] = 0.0,
42
+ control_guidance_end: Union[float, List[float]] = 1.0,
43
+ clip_skip: Optional[int] = None,
44
+ ## add
45
+ repeat_time: int = 4,
46
+ ##
47
+ **kwargs: Any,
48
+ ):
49
+ r""" """
50
+ controlnet = (
51
+ self.controlnet._orig_mod
52
+ if is_compiled_module(self.controlnet)
53
+ else self.controlnet
54
+ )
55
+
56
+ # self.init_filter()
57
+
58
+ # align format for control guidance
59
+ if not isinstance(control_guidance_start, list) and isinstance(
60
+ control_guidance_end, list
61
+ ):
62
+ control_guidance_start = len(control_guidance_end) * [
63
+ control_guidance_start
64
+ ]
65
+ elif not isinstance(control_guidance_end, list) and isinstance(
66
+ control_guidance_start, list
67
+ ):
68
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
69
+ elif not isinstance(control_guidance_start, list) and not isinstance(
70
+ control_guidance_end, list
71
+ ):
72
+ mult = (
73
+ len(controlnet.nets)
74
+ if isinstance(controlnet, MultiControlNetModel)
75
+ else 1
76
+ )
77
+ control_guidance_start, control_guidance_end = mult * [
78
+ control_guidance_start
79
+ ], mult * [control_guidance_end]
80
+
81
+ # 1. Check inputs. Raise error if not correct
82
+ self.check_inputs(
83
+ prompt,
84
+ control_image,
85
+ height,
86
+ width,
87
+ callback_steps,
88
+ negative_prompt,
89
+ prompt_embeds,
90
+ negative_prompt_embeds,
91
+ controlnet_conditioning_scale,
92
+ control_guidance_start,
93
+ control_guidance_end,
94
+ )
95
+
96
+ # 2. Define call parameters
97
+ if prompt is not None and isinstance(prompt, str):
98
+ batch_size = 1
99
+ elif prompt is not None and isinstance(prompt, list):
100
+ batch_size = len(prompt)
101
+ else:
102
+ batch_size = prompt_embeds.shape[0]
103
+
104
+ device = self._execution_device
105
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
106
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
107
+ # corresponds to doing no classifier free guidance.
108
+ do_classifier_free_guidance = guidance_scale > 1.0
109
+
110
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
111
+ controlnet_conditioning_scale, float
112
+ ):
113
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
114
+ controlnet.nets
115
+ )
116
+
117
+ global_pool_conditions = (
118
+ controlnet.config.global_pool_conditions
119
+ if isinstance(controlnet, ControlNetModel)
120
+ else controlnet.nets[0].config.global_pool_conditions
121
+ )
122
+ guess_mode = guess_mode or global_pool_conditions
123
+
124
+ # 3. Encode input prompt
125
+ text_encoder_lora_scale = (
126
+ cross_attention_kwargs.get("scale", None)
127
+ if cross_attention_kwargs is not None
128
+ else None
129
+ )
130
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
131
+ prompt,
132
+ device,
133
+ num_images_per_prompt,
134
+ do_classifier_free_guidance,
135
+ negative_prompt,
136
+ prompt_embeds=prompt_embeds,
137
+ negative_prompt_embeds=negative_prompt_embeds,
138
+ lora_scale=text_encoder_lora_scale,
139
+ clip_skip=clip_skip,
140
+ )
141
+ # For classifier free guidance, we need to do two forward passes.
142
+ # Here we concatenate the unconditional and text embeddings into a single batch
143
+ # to avoid doing two forward passes
144
+ if do_classifier_free_guidance:
145
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
146
+
147
+ # 4. Prepare image
148
+ if isinstance(controlnet, ControlNetModel):
149
+ control_image = self.prepare_control_image(
150
+ image=control_image,
151
+ width=width,
152
+ height=height,
153
+ batch_size=batch_size * num_images_per_prompt,
154
+ num_images_per_prompt=num_images_per_prompt,
155
+ device=device,
156
+ dtype=controlnet.dtype,
157
+ do_classifier_free_guidance=do_classifier_free_guidance,
158
+ guess_mode=guess_mode,
159
+ )
160
+ elif isinstance(controlnet, MultiControlNetModel):
161
+ control_images = []
162
+
163
+ for control_image_ in control_image:
164
+ control_image_ = self.prepare_control_image(
165
+ image=control_image_,
166
+ width=width,
167
+ height=height,
168
+ batch_size=batch_size * num_images_per_prompt,
169
+ num_images_per_prompt=num_images_per_prompt,
170
+ device=device,
171
+ dtype=controlnet.dtype,
172
+ do_classifier_free_guidance=do_classifier_free_guidance,
173
+ guess_mode=guess_mode,
174
+ )
175
+
176
+ control_images.append(control_image_)
177
+
178
+ control_image = control_images
179
+ else:
180
+ assert False
181
+
182
+ # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
183
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
184
+ init_image = init_image.to(dtype=torch.float32)
185
+
186
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
187
+
188
+ masked_image = init_image * (mask < 0.5)
189
+ _, _, height, width = init_image.shape
190
+
191
+ # 5. Prepare timesteps
192
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
193
+ timesteps, num_inference_steps = self.get_timesteps(
194
+ num_inference_steps=num_inference_steps, strength=strength, device=device
195
+ )
196
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
197
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
198
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
199
+ is_strength_max = strength == 1.0
200
+
201
+ # 6. Prepare latent variables
202
+ num_channels_latents = self.vae.config.latent_channels
203
+ num_channels_unet = self.unet.config.in_channels
204
+ return_image_latents = True
205
+
206
+ latents_outputs = self.prepare_latents(
207
+ batch_size * num_images_per_prompt,
208
+ num_channels_latents,
209
+ height,
210
+ width,
211
+ prompt_embeds.dtype,
212
+ device,
213
+ generator,
214
+ latents,
215
+ image=init_image,
216
+ timestep=latent_timestep,
217
+ is_strength_max=is_strength_max,
218
+ return_noise=True,
219
+ return_image_latents=return_image_latents,
220
+ )
221
+
222
+ if return_image_latents:
223
+ latents, noise, image_latents = latents_outputs
224
+ else:
225
+ latents, noise = latents_outputs
226
+
227
+ # 7. Prepare mask latent variables
228
+ mask, masked_image_latents = self.prepare_mask_latents(
229
+ mask,
230
+ masked_image,
231
+ batch_size * num_images_per_prompt,
232
+ height,
233
+ width,
234
+ prompt_embeds.dtype,
235
+ device,
236
+ generator,
237
+ do_classifier_free_guidance,
238
+ )
239
+
240
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
241
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
242
+
243
+ # 7.1 Create tensor stating which controlnets to keep
244
+ controlnet_keep = []
245
+ for i in range(len(timesteps)):
246
+ keeps = [
247
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
248
+ for s, e in zip(control_guidance_start, control_guidance_end)
249
+ ]
250
+ controlnet_keep.append(
251
+ keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
252
+ )
253
+
254
+ # 8. Denoising loop
255
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
256
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
257
+ # for i, t in enumerate(timesteps):
258
+
259
+ ## modify
260
+ i = 0
261
+ reinject = repeat_time
262
+ while i < len(timesteps):
263
+ # expand the latents if we are doing classifier free guidance
264
+ t = timesteps[i]
265
+ latent_model_input = (
266
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
267
+ )
268
+ latent_model_input = self.scheduler.scale_model_input(
269
+ latent_model_input, t
270
+ )
271
+
272
+ # controlnet(s) inference
273
+ if guess_mode and do_classifier_free_guidance:
274
+ # Infer ControlNet only for the conditional batch.
275
+ control_model_input = latents
276
+ control_model_input = self.scheduler.scale_model_input(
277
+ control_model_input, t
278
+ )
279
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
280
+ else:
281
+ control_model_input = latent_model_input
282
+ controlnet_prompt_embeds = prompt_embeds
283
+
284
+ if isinstance(controlnet_keep[i], list):
285
+ cond_scale = [
286
+ c * s
287
+ for c, s in zip(
288
+ controlnet_conditioning_scale, controlnet_keep[i]
289
+ )
290
+ ]
291
+ else:
292
+ controlnet_cond_scale = controlnet_conditioning_scale
293
+ if isinstance(controlnet_cond_scale, list):
294
+ controlnet_cond_scale = controlnet_cond_scale[0]
295
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
296
+
297
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
298
+ control_model_input,
299
+ t,
300
+ encoder_hidden_states=controlnet_prompt_embeds,
301
+ controlnet_cond=control_image,
302
+ conditioning_scale=cond_scale,
303
+ guess_mode=guess_mode,
304
+ return_dict=False,
305
+ )
306
+
307
+ if guess_mode and do_classifier_free_guidance:
308
+ # ControlNet was inferred only for the conditional batch.
309
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
310
+ # add 0 to the unconditional batch to keep it unchanged.
311
+ down_block_res_samples = [
312
+ torch.cat([torch.zeros_like(d), d])
313
+ for d in down_block_res_samples
314
+ ]
315
+ mid_block_res_sample = torch.cat(
316
+ [
317
+ torch.zeros_like(mid_block_res_sample),
318
+ mid_block_res_sample,
319
+ ]
320
+ )
321
+
322
+ # predict the noise residual
323
+ if num_channels_unet == 9:
324
+ latent_model_input = torch.cat(
325
+ [latent_model_input, mask, masked_image_latents], dim=1
326
+ )
327
+
328
+ noise_pred = self.unet(
329
+ latent_model_input,
330
+ t,
331
+ encoder_hidden_states=prompt_embeds,
332
+ cross_attention_kwargs=cross_attention_kwargs,
333
+ down_block_additional_residuals=down_block_res_samples,
334
+ mid_block_additional_residual=mid_block_res_sample,
335
+ return_dict=False,
336
+ )[0]
337
+
338
+ # perform guidance
339
+ if do_classifier_free_guidance:
340
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
341
+ noise_pred = noise_pred_uncond + guidance_scale * (
342
+ noise_pred_text - noise_pred_uncond
343
+ )
344
+
345
+ # compute the previous noisy sample x_t -> x_t-1
346
+ latents = self.scheduler.step(
347
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
348
+ )[0]
349
+
350
+ if num_channels_unet == 4:
351
+ init_latents_proper = image_latents
352
+ if do_classifier_free_guidance:
353
+ init_mask, _ = mask.chunk(2)
354
+ else:
355
+ init_mask = mask
356
+
357
+ if i < len(timesteps) - 1:
358
+ noise_timestep = timesteps[i + 1]
359
+ init_latents_proper = self.scheduler.add_noise(
360
+ init_latents_proper,
361
+ noise,
362
+ torch.tensor([noise_timestep]),
363
+ )
364
+
365
+ latents = (
366
+ 1 - init_mask
367
+ ) * init_latents_proper + init_mask * latents
368
+
369
+ i += 1
370
+
371
+ ## noise reinjection
372
+ if i > 0 and i < int(len(timesteps) - 1) and reinject > 0:
373
+ current_timestep = timesteps[i]
374
+ target_timestep = timesteps[i - 1]
375
+ new_noise = torch.randn_like(latents)
376
+
377
+ # step back x_t-1 -> x_t
378
+ latents = self.scheduler.step_back(
379
+ latents,
380
+ new_noise,
381
+ torch.tensor([current_timestep]),
382
+ torch.tensor([target_timestep]),
383
+ )
384
+ i -= 1
385
+ reinject -= 1
386
+ else:
387
+ # reinject = repeat_time
388
+
389
+ # schedule
390
+ if i >= int(len(timesteps) * 0.8):
391
+ reinject = 0
392
+ elif i >= int(len(timesteps) * 0.6):
393
+ reinject = max(0, repeat_time - 3)
394
+ elif i >= int(len(timesteps) * 0.4):
395
+ reinject = max(0, repeat_time - 2)
396
+ elif i >= int(len(timesteps) * 0.2):
397
+ reinject = max(0, repeat_time - 1)
398
+ else:
399
+ reinject = repeat_time
400
+
401
+ # call the callback, if provided
402
+ if i == len(timesteps) - 1 or (
403
+ (i + 1) > num_warmup_steps
404
+ and (i + 1) % self.scheduler.order == 0
405
+ ):
406
+ progress_bar.update()
407
+ if callback is not None and i % callback_steps == 0:
408
+ step_idx = i // getattr(self.scheduler, "order", 1)
409
+ callback(step_idx, t, latents)
410
+
411
+ # If we do sequential model offloading, let's offload unet and controlnet
412
+ # manually for max memory savings
413
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
414
+ self.unet.to("cpu")
415
+ self.controlnet.to("cpu")
416
+ torch.cuda.empty_cache()
417
+
418
+ if not output_type == "latent":
419
+ image = self.vae.decode(
420
+ latents / self.vae.config.scaling_factor,
421
+ return_dict=False,
422
+ generator=generator,
423
+ )[0]
424
+ image, has_nsfw_concept = self.run_safety_checker(
425
+ image, device, prompt_embeds.dtype
426
+ )
427
+ else:
428
+ image = latents
429
+ has_nsfw_concept = None
430
+
431
+ if has_nsfw_concept is None:
432
+ do_denormalize = [True] * image.shape[0]
433
+ else:
434
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
435
+
436
+ image = self.image_processor.postprocess(
437
+ image, output_type=output_type, do_denormalize=do_denormalize
438
+ )
439
+
440
+ # Offload all models
441
+ self.maybe_free_model_hooks()
442
+
443
+ if not return_dict:
444
+ return (image, has_nsfw_concept)
445
+
446
+ return StableDiffusionPipelineOutput(
447
+ images=image, nsfw_content_detected=has_nsfw_concept
448
+ )
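Note: a hedged usage sketch for `ControlNetOutpaintPipeline` (checkpoint names are placeholders and the import paths assume the repository layout; none of this is taken from the commit itself). Because the re-injection branch calls `self.scheduler.step_back`, which stock diffusers schedulers do not provide, the `CustomScheduler` from `src/schedulers/scheduling_pndm.py` below is assumed to be swapped in.
```py
import torch
from PIL import Image
from diffusers import ControlNetModel

from src.pipelines.pipeline_controlnet_outpaint import ControlNetOutpaintPipeline
from src.schedulers.scheduling_pndm import CustomScheduler

# Placeholder checkpoints; an inpainting base model with a 9-channel UNet is assumed.
controlnet = ControlNetModel.from_pretrained("path/to/controlnet", torch_dtype=torch.float16)
pipe = ControlNetOutpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")
pipe.scheduler = CustomScheduler.from_config(pipe.scheduler.config)  # provides step_back

init_image = Image.new("RGB", (512, 512), "gray")   # image to outpaint
mask_image = Image.new("L", (512, 512), 255)        # white = region to generate
cond_image = Image.new("RGB", (512, 512), "black")  # light-source condition map

result = pipe(
    prompt="a night street scene with street lamps",
    image=init_image,
    mask_image=mask_image,
    control_image=cond_image,
    num_inference_steps=50,
    repeat_time=4,  # per-step noise re-injection budget
).images[0]
```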
src/pipelines/pipeline_stable_diffusion_outpaint.py ADDED
@@ -0,0 +1,517 @@
1
+ import torch
2
+
3
+ from typing import List, Union, Dict, Any, Callable, Optional, Tuple
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+ from diffusers.utils import make_image_grid, load_image, deprecate
6
+ from diffusers.models import AsymmetricAutoencoderKL
7
+ from diffusers.image_processor import PipelineImageInput
8
+ from diffusers.pipelines.stable_diffusion.pipeline_output import (
9
+ StableDiffusionPipelineOutput,
10
+ )
11
+
12
+
13
+ class OutpaintPipeline(StableDiffusionInpaintPipeline):
14
+ @torch.no_grad()
15
+ def __call__(
16
+ self,
17
+ prompt: Union[str, List[str]] = None,
18
+ image: PipelineImageInput = None,
19
+ mask_image: PipelineImageInput = None,
20
+ control_image: PipelineImageInput = None,
21
+ masked_image_latents: torch.FloatTensor = None,
22
+ height: Optional[int] = None,
23
+ width: Optional[int] = None,
24
+ strength: float = 1.0,
25
+ num_inference_steps: int = 50,
26
+ guidance_scale: float = 7.5,
27
+ negative_prompt: Optional[Union[str, List[str]]] = None,
28
+ num_images_per_prompt: Optional[int] = 1,
29
+ eta: float = 0.0,
30
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
31
+ latents: Optional[torch.FloatTensor] = None,
32
+ prompt_embeds: Optional[torch.FloatTensor] = None,
33
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
34
+ output_type: Optional[str] = "pil",
35
+ return_dict: bool = True,
36
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
37
+ clip_skip: int = None,
38
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
39
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
40
+ ## add
41
+ repeat_time: int = 4,
42
+ ##
43
+ **kwargs,
44
+ ):
45
+ r"""
46
+ The call function to the pipeline for generation.
47
+
48
+ Args:
49
+ prompt (`str` or `List[str]`, *optional*):
50
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
51
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
52
+ `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
53
+ be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
54
+ tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
55
+ expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
56
+ expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
57
+ if passing latents directly it is not encoded again.
58
+ mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
59
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
60
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
61
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
62
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
63
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
64
+ 1)`, or `(H, W)`.
65
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
66
+ The height in pixels of the generated image.
67
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
68
+ The width in pixels of the generated image.
69
+ strength (`float`, *optional*, defaults to 1.0):
70
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
71
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
72
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
73
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
74
+ essentially ignores `image`.
75
+ num_inference_steps (`int`, *optional*, defaults to 50):
76
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
77
+ expense of slower inference. This parameter is modulated by `strength`.
78
+ guidance_scale (`float`, *optional*, defaults to 7.5):
79
+ A higher guidance scale value encourages the model to generate images closely linked to the text
80
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
81
+ negative_prompt (`str` or `List[str]`, *optional*):
82
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
83
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
84
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
85
+ The number of images to generate per prompt.
86
+ eta (`float`, *optional*, defaults to 0.0):
87
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
88
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
89
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
90
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
91
+ generation deterministic.
92
+ latents (`torch.FloatTensor`, *optional*):
93
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
94
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
95
+ tensor is generated by sampling using the supplied random `generator`.
96
+ prompt_embeds (`torch.FloatTensor`, *optional*):
97
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
98
+ provided, text embeddings are generated from the `prompt` input argument.
99
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
100
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
101
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
102
+ output_type (`str`, *optional*, defaults to `"pil"`):
103
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
104
+ return_dict (`bool`, *optional*, defaults to `True`):
105
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
106
+ plain tuple.
107
+ cross_attention_kwargs (`dict`, *optional*):
108
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
109
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
110
+ clip_skip (`int`, *optional*):
111
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
112
+ the output of the pre-final layer will be used for computing the prompt embeddings.
113
+ callback_on_step_end (`Callable`, *optional*):
114
+ A function that calls at the end of each denoising steps during the inference. The function is called
115
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
116
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
117
+ `callback_on_step_end_tensor_inputs`.
118
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
119
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
120
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
121
+ `._callback_tensor_inputs` attribute of your pipeline class.
122
+ Examples:
123
+
124
+ ```py
125
+ >>> import PIL
126
+ >>> import requests
127
+ >>> import torch
128
+ >>> from io import BytesIO
129
+
130
+ >>> from diffusers import StableDiffusionInpaintPipeline
131
+
132
+
133
+ >>> def download_image(url):
134
+ ... response = requests.get(url)
135
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
136
+
137
+
138
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
139
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
140
+
141
+ >>> init_image = download_image(img_url).resize((512, 512))
142
+ >>> mask_image = download_image(mask_url).resize((512, 512))
143
+
144
+ >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
145
+ ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
146
+ ... )
147
+ >>> pipe = pipe.to("cuda")
148
+
149
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
150
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
151
+ ```
152
+
153
+ Returns:
154
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
155
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
156
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
157
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
158
+ "not-safe-for-work" (nsfw) content.
159
+ """
160
+
161
+ callback = kwargs.pop("callback", None)
162
+ callback_steps = kwargs.pop("callback_steps", None)
163
+
164
+ if callback is not None:
165
+ deprecate(
166
+ "callback",
167
+ "1.0.0",
168
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
169
+ )
170
+ if callback_steps is not None:
171
+ deprecate(
172
+ "callback_steps",
173
+ "1.0.0",
174
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
175
+ )
176
+
177
+ # 0. Default height and width to unet
178
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
179
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
180
+
181
+ # 1. Check inputs
182
+ self.check_inputs(
183
+ prompt,
184
+ height,
185
+ width,
186
+ strength,
187
+ callback_steps,
188
+ negative_prompt,
189
+ prompt_embeds,
190
+ negative_prompt_embeds,
191
+ callback_on_step_end_tensor_inputs,
192
+ )
193
+
194
+ self._guidance_scale = guidance_scale
195
+ self._clip_skip = clip_skip
196
+ self._cross_attention_kwargs = cross_attention_kwargs
197
+
198
+ # 2. Define call parameters
199
+ if prompt is not None and isinstance(prompt, str):
200
+ batch_size = 1
201
+ elif prompt is not None and isinstance(prompt, list):
202
+ batch_size = len(prompt)
203
+ else:
204
+ batch_size = prompt_embeds.shape[0]
205
+
206
+ device = self._execution_device
207
+
208
+ # 3. Encode input prompt
209
+ text_encoder_lora_scale = (
210
+ cross_attention_kwargs.get("scale", None)
211
+ if cross_attention_kwargs is not None
212
+ else None
213
+ )
214
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
215
+ prompt,
216
+ device,
217
+ num_images_per_prompt,
218
+ self.do_classifier_free_guidance,
219
+ negative_prompt,
220
+ prompt_embeds=prompt_embeds,
221
+ negative_prompt_embeds=negative_prompt_embeds,
222
+ lora_scale=text_encoder_lora_scale,
223
+ clip_skip=self.clip_skip,
224
+ )
225
+ # For classifier free guidance, we need to do two forward passes.
226
+ # Here we concatenate the unconditional and text embeddings into a single batch
227
+ # to avoid doing two forward passes
228
+ if self.do_classifier_free_guidance:
229
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
230
+
231
+ # 4. set timesteps
232
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
233
+ timesteps, num_inference_steps = self.get_timesteps(
234
+ num_inference_steps=num_inference_steps, strength=strength, device=device
235
+ )
236
+ # check that number of inference steps is not < 1 - as this doesn't make sense
237
+ if num_inference_steps < 1:
238
+ raise ValueError(
239
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
240
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
241
+ )
242
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
243
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
244
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
245
+ is_strength_max = strength == 1.0
246
+
247
+ # 5. Preprocess mask and image
248
+
249
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
250
+ init_image = init_image.to(dtype=torch.float32)
251
+
252
+ # 6. Prepare latent variables
253
+ num_channels_latents = self.vae.config.latent_channels
254
+ num_channels_unet = self.unet.config.in_channels
255
+ return_image_latents = num_channels_unet == 4
256
+
257
+ latents_outputs = self.prepare_latents(
258
+ batch_size * num_images_per_prompt,
259
+ num_channels_latents,
260
+ height,
261
+ width,
262
+ prompt_embeds.dtype,
263
+ device,
264
+ generator,
265
+ latents,
266
+ image=init_image,
267
+ timestep=latent_timestep,
268
+ is_strength_max=is_strength_max,
269
+ return_noise=True,
270
+ return_image_latents=return_image_latents,
271
+ )
272
+
273
+ if return_image_latents:
274
+ latents, noise, image_latents = latents_outputs
275
+ else:
276
+ latents, noise = latents_outputs
277
+
278
+ # 7. Prepare mask latent variables
279
+ mask_condition = self.mask_processor.preprocess(
280
+ mask_image, height=height, width=width
281
+ )
282
+
283
+ if masked_image_latents is None:
284
+ masked_image = init_image * (mask_condition < 0.5)
285
+ else:
286
+ masked_image = masked_image_latents
287
+
288
+ mask, masked_image_latents = self.prepare_mask_latents(
289
+ mask_condition,
290
+ masked_image,
291
+ batch_size * num_images_per_prompt,
292
+ height,
293
+ width,
294
+ prompt_embeds.dtype,
295
+ device,
296
+ generator,
297
+ self.do_classifier_free_guidance,
298
+ )
299
+
300
+ # 8. Check that sizes of mask, masked image and latents match
301
+ if num_channels_unet == 9:
302
+ # default case for runwayml/stable-diffusion-inpainting
303
+ num_channels_mask = mask.shape[1]
304
+ num_channels_masked_image = masked_image_latents.shape[1]
305
+ if (
306
+ num_channels_latents + num_channels_mask + num_channels_masked_image
307
+ != self.unet.config.in_channels
308
+ ):
309
+ raise ValueError(
310
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
311
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
312
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
313
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
314
+ " `pipeline.unet` or your `mask_image` or `image` input."
315
+ )
316
+ elif num_channels_unet != 4:
317
+ raise ValueError(
318
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
319
+ )
320
+
321
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
322
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
323
+
324
+ # 9.5 Optionally get Guidance Scale Embedding
325
+ timestep_cond = None
326
+ if self.unet.config.time_cond_proj_dim is not None:
327
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
328
+ batch_size * num_images_per_prompt
329
+ )
330
+ timestep_cond = self.get_guidance_scale_embedding(
331
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
332
+ ).to(device=device, dtype=latents.dtype)
333
+
334
+ # 10. Denoising loop
335
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
336
+ self._num_timesteps = len(timesteps)
337
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
338
+ # for i in range(len(timesteps)):
339
+
340
+ ## modify
341
+ i = 0
342
+ reinject = repeat_time
343
+ while i < len(timesteps):
344
+ # expand the latents if we are doing classifier free guidance
345
+ latent_model_input = (
346
+ torch.cat([latents] * 2)
347
+ if self.do_classifier_free_guidance
348
+ else latents
349
+ )
350
+
351
+ # concat latents, mask, masked_image_latents in the channel dimension
352
+ latent_model_input = self.scheduler.scale_model_input(
353
+ latent_model_input, timesteps[i]
354
+ )
355
+
356
+ if num_channels_unet == 9:
357
+ latent_model_input = torch.cat(
358
+ [latent_model_input, mask, masked_image_latents], dim=1
359
+ )
360
+
361
+ # predict the noise residual
362
+ noise_pred = self.unet(
363
+ latent_model_input,
364
+ timesteps[i],
365
+ encoder_hidden_states=prompt_embeds,
366
+ timestep_cond=timestep_cond,
367
+ cross_attention_kwargs=self.cross_attention_kwargs,
368
+ return_dict=False,
369
+ )[0]
370
+
371
+ # perform guidance
372
+ if self.do_classifier_free_guidance:
373
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
374
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
375
+ noise_pred_text - noise_pred_uncond
376
+ )
377
+
378
+ # compute the previous noisy sample x_t -> x_t-1
379
+ latents = self.scheduler.step(
380
+ noise_pred,
381
+ timesteps[i],
382
+ latents,
383
+ **extra_step_kwargs,
384
+ return_dict=False,
385
+ )[0]
386
+ if num_channels_unet == 4:
387
+ init_latents_proper = image_latents
388
+ if self.do_classifier_free_guidance:
389
+ init_mask, _ = mask.chunk(2)
390
+ else:
391
+ init_mask = mask
392
+
393
+ if i < len(timesteps) - 1:
394
+ noise_timestep = timesteps[i + 1]
395
+ init_latents_proper = self.scheduler.add_noise(
396
+ init_latents_proper, noise, torch.tensor([noise_timestep])
397
+ )
398
+
399
+ latents = (
400
+ 1 - init_mask
401
+ ) * init_latents_proper + init_mask * latents
402
+
403
+ if callback_on_step_end is not None:
404
+ callback_kwargs = {}
405
+ for k in callback_on_step_end_tensor_inputs:
406
+ callback_kwargs[k] = locals()[k]
407
+ callback_outputs = callback_on_step_end(
408
+ self, i, timesteps[i], callback_kwargs
409
+ )
410
+
411
+ latents = callback_outputs.pop("latents", latents)
412
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
413
+ negative_prompt_embeds = callback_outputs.pop(
414
+ "negative_prompt_embeds", negative_prompt_embeds
415
+ )
416
+ mask = callback_outputs.pop("mask", mask)
417
+ masked_image_latents = callback_outputs.pop(
418
+ "masked_image_latents", masked_image_latents
419
+ )
420
+
421
+ # # call the callback, if provided
422
+ # if i == len(timesteps) - 1 or (
423
+ # (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
424
+ # ):
425
+ # progress_bar.update()
426
+ # if callback is not None and i % callback_steps == 0:
427
+ # step_idx = i // getattr(self.scheduler, "order", 1)
428
+ # callback(step_idx, timesteps[i], latents)
429
+
430
+ i += 1
431
+
432
+ ## noise reinjection
433
+ if i > 0 and i < int(len(timesteps) - 1) and reinject != 0:
434
+ current_timestep = timesteps[i]
435
+ target_timestep = timesteps[i - 1]
436
+ new_noise = torch.randn_like(latents)
437
+
438
+ # step back x_t-1 -> x_t
439
+ latents = self.scheduler.step_back(
440
+ latents,
441
+ new_noise,
442
+ torch.tensor([current_timestep]),
443
+ torch.tensor([target_timestep]),
444
+ )
445
+ i -= 1
446
+ reinject -= 1
447
+ else:
448
+ # reinject = repeat_time
449
+
450
+ # schedule
451
+ if i >= int(len(timesteps) * 0.85):
452
+ reinject = 0
453
+ elif i >= int(len(timesteps) * 0.8):
454
+ reinject = 1
455
+ elif i >= int(len(timesteps) * 0.7):
456
+ reinject = 2
457
+ elif i >= int(len(timesteps) * 0.5):
458
+ reinject = 3
459
+ else:
460
+ reinject = 4
461
+
462
+ # call the callback, if provided
463
+ if i == len(timesteps) - 1 or (
464
+ (i + 1) > num_warmup_steps
465
+ and (i + 1) % self.scheduler.order == 0
466
+ ):
467
+ progress_bar.update()
468
+ if callback is not None and i % callback_steps == 0:
469
+ step_idx = i // getattr(self.scheduler, "order", 1)
470
+ callback(step_idx, timesteps[i], latents)
471
+
472
+ if not output_type == "latent":
473
+ condition_kwargs = {}
474
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
475
+ init_image = init_image.to(
476
+ device=device, dtype=masked_image_latents.dtype
477
+ )
478
+ init_image_condition = init_image.clone()
479
+ init_image = self._encode_vae_image(init_image, generator=generator)
480
+ mask_condition = mask_condition.to(
481
+ device=device, dtype=masked_image_latents.dtype
482
+ )
483
+ condition_kwargs = {
484
+ "image": init_image_condition,
485
+ "mask": mask_condition,
486
+ }
487
+ image = self.vae.decode(
488
+ latents / self.vae.config.scaling_factor,
489
+ return_dict=False,
490
+ generator=generator,
491
+ **condition_kwargs,
492
+ )[0]
493
+ image, has_nsfw_concept = self.run_safety_checker(
494
+ image, device, prompt_embeds.dtype
495
+ )
496
+ else:
497
+ image = latents
498
+ has_nsfw_concept = None
499
+
500
+ if has_nsfw_concept is None:
501
+ do_denormalize = [True] * image.shape[0]
502
+ else:
503
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
504
+
505
+ image = self.image_processor.postprocess(
506
+ image, output_type=output_type, do_denormalize=do_denormalize
507
+ )
508
+
509
+ # Offload all models
510
+ self.maybe_free_model_hooks()
511
+
512
+ if not return_dict:
513
+ return (image, has_nsfw_concept)
514
+
515
+ return StableDiffusionPipelineOutput(
516
+ images=image, nsfw_content_detected=has_nsfw_concept
517
+ )
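Note: the inline schedule above caps how many re-noising passes are allowed as denoising progresses, tapering from 4 early on to 0 in roughly the last 15% of the steps (the ControlNet variant earlier in this commit instead uses `max(0, repeat_time - k)` at the 20/40/60/80% marks). Restated as a standalone helper for readability; a sketch only, since the pipeline keeps the thresholds inline.
```py
def reinjection_budget(i: int, num_steps: int) -> int:
    """Re-noising passes allowed after step i (mirrors the inline schedule above)."""
    if i >= int(num_steps * 0.85):
        return 0
    if i >= int(num_steps * 0.8):
        return 1
    if i >= int(num_steps * 0.7):
        return 2
    if i >= int(num_steps * 0.5):
        return 3
    return 4  # the inline code resets to the literal 4 here rather than repeat_time
```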
src/schedulers/__pycache__/scheduling_pndm.cpython-39.pyc ADDED
Binary file (3.73 kB). View file
 
src/schedulers/scheduling_pndm.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ from typing import List, Optional, Tuple, Union
3
+ from diffusers import PNDMScheduler
4
+ from diffusers.schedulers.scheduling_utils import SchedulerOutput
5
+
6
+
7
+ class CustomScheduler(PNDMScheduler):
8
+ def step_plms(
9
+ self,
10
+ model_output: torch.FloatTensor,
11
+ timestep: int,
12
+ sample: torch.FloatTensor,
13
+ return_dict: bool = True,
14
+ ) -> Union[SchedulerOutput, Tuple]:
15
+ """
16
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
17
+ the linear multistep method. It performs one forward pass multiple times to approximate the solution.
18
+
19
+ Args:
20
+ model_output (`torch.FloatTensor`):
21
+ The direct output from learned diffusion model.
22
+ timestep (`int`):
23
+ The current discrete timestep in the diffusion chain.
24
+ sample (`torch.FloatTensor`):
25
+ A current instance of a sample created by the diffusion process.
26
+ return_dict (`bool`):
27
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
28
+
29
+ Returns:
30
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
31
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
32
+ tuple is returned where the first element is the sample tensor.
33
+
34
+ """
35
+ if self.num_inference_steps is None:
36
+ raise ValueError(
37
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
38
+ )
39
+
40
+ if not self.config.skip_prk_steps and len(self.ets) < 3:
41
+ raise ValueError(
42
+ f"{self.__class__} can only be run AFTER scheduler has been run "
43
+ "in 'prk' mode for at least 12 iterations "
44
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
45
+ "for more information."
46
+ )
47
+
48
+ prev_timestep = (
49
+ timestep - self.config.num_train_timesteps // self.num_inference_steps
50
+ )
51
+
52
+ if self.counter != 1:
53
+ self.ets = self.ets[-3:]
54
+ self.ets.append(model_output)
55
+ else:
56
+ prev_timestep = timestep
57
+ timestep = (
58
+ timestep + self.config.num_train_timesteps // self.num_inference_steps
59
+ )
60
+
61
+ if len(self.ets) == 1 and self.counter == 0:
62
+ model_output = model_output
63
+ self.cur_sample = sample
64
+ elif len(self.ets) == 1 and self.counter == 1:
65
+ model_output = (model_output + self.ets[-1]) / 2
66
+ sample = self.cur_sample
67
+ # self.cur_sample = None
68
+ elif len(self.ets) == 2:
69
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
70
+ elif len(self.ets) == 3:
71
+ model_output = (
72
+ 23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]
73
+ ) / 12
74
+ else:
75
+ model_output = (1 / 24) * (
76
+ 55 * self.ets[-1]
77
+ - 59 * self.ets[-2]
78
+ + 37 * self.ets[-3]
79
+ - 9 * self.ets[-4]
80
+ )
81
+
82
+ prev_sample = self._get_prev_sample(
83
+ sample, timestep, prev_timestep, model_output
84
+ )
85
+ self.counter += 1
86
+
87
+ if not return_dict:
88
+ return (prev_sample,)
89
+
90
+ return SchedulerOutput(prev_sample=prev_sample)
91
+
92
+ def step_back(
93
+ self,
94
+ current_samples: torch.FloatTensor,
95
+ noise: torch.FloatTensor,
96
+ current_timesteps: torch.IntTensor,
97
+ target_timesteps: torch.IntTensor,
98
+ ):
99
+ """Custom function for stepping back in the diffusion process."""
100
+
101
+ assert current_timesteps <= target_timesteps
102
+ alphas_cumprod = self.alphas_cumprod.to(
103
+ device=current_samples.device, dtype=current_samples.dtype
104
+ )
105
+ target_timesteps = target_timesteps.to(current_samples.device)
106
+ current_timesteps = current_timesteps.to(current_samples.device)
107
+ alpha_prod_target = alphas_cumprod[target_timesteps]
108
+ alpha_prod_target = alpha_prod_target.flatten()
109
+ alpha_prod_current = alphas_cumprod[current_timesteps]
110
+ alpha_prod_current = alpha_prod_current.flatten()
111
+ alpha_prod = alpha_prod_target / alpha_prod_current
112
+
113
+ sqrt_alpha_prod = alpha_prod**0.5
114
+ sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5
115
+
116
+ while len(sqrt_alpha_prod.shape) < len(current_samples.shape):
117
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
118
+ while len(sqrt_one_minus_alpha_prod.shape) < len(current_samples.shape):
119
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
120
+
121
+ noisy_samples = (
122
+ sqrt_alpha_prod * current_samples + sqrt_one_minus_alpha_prod * noise
123
+ )
124
+ self.counter -= 1
125
+
126
+ return noisy_samples
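Note: in standard diffusion notation, `step_back` re-applies the forward transition. With $s$ the latent's current (less noisy) timestep (`current_timesteps`, i.e. `timesteps[i]`) and $t > s$ the target (noisier) timestep (`target_timesteps`, i.e. `timesteps[i - 1]`), the returned sample is

$$x_t = \sqrt{\tfrac{\bar\alpha_t}{\bar\alpha_s}}\, x_s + \sqrt{1 - \tfrac{\bar\alpha_t}{\bar\alpha_s}}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$

which is exactly $q(x_t \mid x_s)$ for the variance-preserving forward process, since $\bar\alpha_t / \bar\alpha_s = \prod_{k=s+1}^{t} \alpha_k$.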
utils/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (28.2 kB). View file
 
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (8.55 kB). View file
 
utils/dataset.py ADDED
@@ -0,0 +1,1304 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import random
5
+ import timeit
6
+ import numpy as np
7
+ import skimage
8
+ import yaml
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+ import torchvision.transforms.functional as TF
12
+ from PIL import Image
13
+ from torch.utils.data import Dataset
14
+ from torch.distributions import Normal
15
+
16
+ # from utils.utils import RGB2YCbCr
17
+
18
+
19
+ class RandomGammaCorrection(object):
20
+ def __init__(self, gamma=None):
21
+ self.gamma = gamma
22
+
23
+ def __call__(self, image):
24
+ if self.gamma is None:
25
+ # more chances of selecting 0 (original image)
26
+ gammas = [0.5, 1, 2]
27
+ self.gamma = random.choice(gammas)
28
+ return TF.adjust_gamma(image, self.gamma, gain=1)
29
+ elif isinstance(self.gamma, tuple):
30
+ gamma = random.uniform(*self.gamma)
31
+ return TF.adjust_gamma(image, gamma, gain=1)
32
+ elif self.gamma == 0:
33
+ return image
34
+ else:
35
+ return TF.adjust_gamma(image, self.gamma, gain=1)
36
+
37
+
38
+ def remove_background(image):
39
+ # the input image is a PIL.Image in [H, W, C] form
40
+ image = np.float32(np.array(image))
41
+ _EPS = 1e-7
42
+ rgb_max = np.max(image, (0, 1))
43
+ rgb_min = np.min(image, (0, 1))
44
+ image = (image - rgb_min) * rgb_max / (rgb_max - rgb_min + _EPS)
45
+ image = torch.from_numpy(image)
46
+ return image
47
+
48
+
49
+ def glod_from_folder(folder_list, index_list):
50
+ ext = ["png", "jpeg", "jpg", "bmp", "tif"]
51
+ index_dict = {}
52
+ for i, folder_name in enumerate(folder_list):
53
+ data_list = []
54
+ [data_list.extend(glob.glob(folder_name + "/*." + e)) for e in ext]
55
+ data_list.sort()
56
+ index_dict[index_list[i]] = data_list
57
+ return index_dict
58
+
59
+
60
+ class Flare_Image_Loader(Dataset):
61
+ def __init__(self, image_path, transform_base, transform_flare, mask_type=None):
62
+ self.ext = ["png", "jpeg", "jpg", "bmp", "tif"]
63
+ self.data_list = []
64
+ [self.data_list.extend(glob.glob(image_path + "/*." + e)) for e in self.ext]
65
+ self.flare_dict = {}
66
+ self.flare_list = []
67
+ self.flare_name_list = []
68
+
69
+ self.reflective_flag = False
70
+ self.reflective_dict = {}
71
+ self.reflective_list = []
72
+ self.reflective_name_list = []
73
+
74
+ self.light_flag = False
75
+ self.light_dict = {}
76
+ self.light_list = []
77
+ self.light_name_list = []
78
+
79
+ self.mask_type = (
80
+ mask_type  # None or a str: "luminance", "color", "flare", or "light"
81
+ )
82
+
83
+ self.img_size = transform_base["img_size"]
84
+
85
+ self.transform_base = transforms.Compose(
86
+ [
87
+ transforms.RandomCrop(
88
+ (self.img_size, self.img_size),
89
+ pad_if_needed=True,
90
+ padding_mode="reflect",
91
+ ),
92
+ transforms.RandomHorizontalFlip(),
93
+ # transforms.RandomVerticalFlip(),
94
+ ]
95
+ )
96
+
97
+ self.transform_flare = transforms.Compose(
98
+ [
99
+ transforms.RandomAffine(
100
+ degrees=(0, 360),
101
+ scale=(transform_flare["scale_min"], transform_flare["scale_max"]),
102
+ translate=(
103
+ transform_flare["translate"] / 1440,
104
+ transform_flare["translate"] / 1440,
105
+ ),
106
+ shear=(-transform_flare["shear"], transform_flare["shear"]),
107
+ ),
108
+ transforms.CenterCrop((self.img_size, self.img_size)),
109
+ transforms.RandomHorizontalFlip(),
110
+ transforms.RandomVerticalFlip(),
111
+ ]
112
+ )
113
+
114
+ self.normalize = transforms.Compose(
115
+ [
116
+ transforms.Normalize([0.5], [0.5]),
117
+ ]
118
+ )
119
+
120
+ self.data_ratio = []
121
+
122
+ def lightsource_crop(self, matrix):
123
+ """Find the largest rectangle of 1s in a binary matrix."""
124
+
125
+ def largestRectangleArea(heights):
126
+ heights.append(0)
127
+ stack = [-1]
128
+ max_area = 0
129
+ max_rectangle = (0, 0, 0, 0) # (area, left, right, height)
130
+
131
+ for i in range(len(heights)):
132
+ while heights[i] < heights[stack[-1]]:
133
+ h = heights[stack.pop()]
134
+ w = i - stack[-1] - 1
135
+ area = h * w
136
+ if area > max_area:
137
+ max_area = area
138
+ max_rectangle = (area, stack[-1] + 1, i - 1, h)
139
+ stack.append(i)
140
+
141
+ heights.pop()
142
+ return max_rectangle
143
+
144
+ max_area = 0
145
+ max_rectangle = [0, 0, 0, 0] # (left, right, top, bottom)
146
+ heights = torch.zeros(matrix.shape[1])
147
+
148
+ for row in range(matrix.shape[0]):
149
+ temp = 1 - matrix[row]
150
+ heights = (heights + temp) * temp
151
+
152
+ area, left, right, height = largestRectangleArea(heights.tolist())
153
+ if area > max_area:
154
+ max_area = area
155
+ max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]
156
+
157
+ return torch.tensor(max_rectangle)
158
+
159
+ def __getitem__(self, index):
160
+ # load base image
161
+ img_path = self.data_list[index]
162
+ base_img = Image.open(img_path).convert("RGB")
163
+
164
+ gamma = np.random.uniform(1.8, 2.2)
165
+ to_tensor = transforms.ToTensor()
166
+ adjust_gamma = RandomGammaCorrection(gamma)
167
+ adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
168
+ color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
169
+ if self.transform_base is not None:
170
+ base_img = to_tensor(base_img)
171
+ base_img = adjust_gamma(base_img)
172
+ base_img = self.transform_base(base_img)
173
+ else:
174
+ base_img = to_tensor(base_img)
175
+ base_img = adjust_gamma(base_img)
176
+ sigma_chi = 0.01 * np.random.chisquare(df=1)
177
+ base_img = Normal(base_img, sigma_chi).sample()
178
+ gain = np.random.uniform(0.5, 1.2)
179
+ flare_DC_offset = np.random.uniform(-0.02, 0.02)
180
+ base_img = gain * base_img
181
+ base_img = torch.clamp(base_img, min=0, max=1)
182
+
183
+ choice_dataset = random.choices(
184
+ [i for i in range(len(self.flare_list))], self.data_ratio
185
+ )[0]
186
+ choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
187
+
188
+ # load flare and light source image
189
+ if self.light_flag:
190
+ assert len(self.flare_list) == len(
191
+ self.light_list
192
+ ), "Error, number of light source and flares dataset no match!"
193
+ for i in range(len(self.flare_list)):
194
+ assert len(self.flare_list[i]) == len(
195
+ self.light_list[i]
196
+ ), f"Error, number of light source and flares no match in {i} dataset!"
197
+ flare_path = self.flare_list[choice_dataset][choice_index]
198
+ light_path = self.light_list[choice_dataset][choice_index]
199
+ light_img = Image.open(light_path).convert("RGB")
200
+ light_img = to_tensor(light_img)
201
+ light_img = adjust_gamma(light_img)
202
+ else:
203
+ flare_path = self.flare_list[choice_dataset][choice_index]
204
+ flare_img = Image.open(flare_path).convert("RGB")
205
+ if self.reflective_flag:
206
+ reflective_path_list = self.reflective_list[choice_dataset]
207
+ if len(reflective_path_list) != 0:
208
+ reflective_path = random.choice(reflective_path_list)
209
+ reflective_img = Image.open(reflective_path).convert("RGB")
210
+ else:
211
+ reflective_img = None
212
+
213
+ flare_img = to_tensor(flare_img)
214
+ flare_img = adjust_gamma(flare_img)
215
+
216
+ if self.reflective_flag and reflective_img is not None:
217
+ reflective_img = to_tensor(reflective_img)
218
+ reflective_img = adjust_gamma(reflective_img)
219
+ flare_img = torch.clamp(flare_img + reflective_img, min=0, max=1)
220
+
221
+ flare_img = remove_background(flare_img)
222
+
223
+ if self.transform_flare is not None:
224
+ if self.light_flag:
225
+ flare_merge = torch.cat((flare_img, light_img), dim=0)
226
+ flare_merge = self.transform_flare(flare_merge)
227
+ else:
228
+ flare_img = self.transform_flare(flare_img)
229
+
230
+ # change color
231
+ if self.light_flag:
232
+ # flare_merge=color_jitter(flare_merge)
233
+ flare_img, light_img = torch.split(flare_merge, 3, dim=0)
234
+ else:
235
+ flare_img = color_jitter(flare_img)
236
+
237
+ # flare blur
238
+ blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
239
+ flare_img = blur_transform(flare_img)
240
+ # flare_img=flare_img+flare_DC_offset
241
+ flare_img = torch.clamp(flare_img, min=0, max=1)
242
+
243
+ # merge image
244
+ merge_img = flare_img + base_img
245
+ merge_img = torch.clamp(merge_img, min=0, max=1)
246
+ if self.light_flag:
247
+ base_img = base_img + light_img
248
+ base_img = torch.clamp(base_img, min=0, max=1)
249
+ flare_img = flare_img - light_img
250
+ flare_img = torch.clamp(flare_img, min=0, max=1)
251
+
252
+ flare_mask = None
253
+ if self.mask_type is None:
254
+ return {
255
+ "gt": adjust_gamma_reverse(base_img),
256
+ "flare": adjust_gamma_reverse(flare_img),
257
+ "lq": adjust_gamma_reverse(merge_img),
258
+ "gamma": gamma,
259
+ }
260
+
261
+ elif self.mask_type == "luminance":
262
+ # calculate mask (the mask is 3 channel)
263
+ one = torch.ones_like(base_img)
264
+ zero = torch.zeros_like(base_img)
265
+
266
+ luminance = 0.3 * flare_img[0] + 0.59 * flare_img[1] + 0.11 * flare_img[2]
267
+ threshold_value = 0.99**gamma
268
+ flare_mask = torch.where(luminance > threshold_value, one, zero)
269
+
270
+ elif self.mask_type == "color":
271
+ one = torch.ones_like(base_img)
272
+ zero = torch.zeros_like(base_img)
273
+
274
+ threshold_value = 0.99**gamma
275
+ flare_mask = torch.where(merge_img > threshold_value, one, zero)
276
+
277
+ elif self.mask_type == "flare":
278
+ one = torch.ones_like(base_img)
279
+ zero = torch.zeros_like(base_img)
280
+
281
+ threshold_value = 0.7**gamma
282
+ flare_mask = torch.where(flare_img > threshold_value, one, zero)
283
+
284
+ elif self.mask_type == "light":
285
+ # Deprecated: the light mask is no longer needed
286
+ one = torch.ones_like(base_img)
287
+ zero = torch.zeros_like(base_img)
288
+
289
+ luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
290
+ threshold_value = 0.01
291
+ flare_mask = torch.where(luminance > threshold_value, one, zero)
292
+
293
+ light_source_cond = (flare_mask[0] + flare_mask[1] + flare_mask[2]) > 0
295
+ light_source_cond = light_source_cond.float()
296
+ light_source_cond = torch.repeat_interleave(
297
+ light_source_cond[None, ...], 3, dim=0
298
+ )
299
+
300
+ # box = self.crop(light_source_cond[0])
301
+ box = self.lightsource_crop(light_source_cond[0])
302
+
303
+ # random integer between -15 and 15
304
+ margin = random.randint(-15, 15)
305
+
306
+ if box[0] - margin >= 0:
307
+ box[0] -= margin
308
+ if box[1] + margin < self.img_size:
309
+ box[1] += margin
310
+ if box[2] - margin >= 0:
311
+ box[2] -= margin
312
+ if box[3] + margin < self.img_size:
313
+ box[3] += margin
314
+
315
+ top, bottom, left, right = box[2], box[3], box[0], box[1]
316
+
317
+ merge_img = adjust_gamma_reverse(merge_img)
318
+
319
+ cropped_mask = torch.ones((self.img_size, self.img_size))
320
+ cropped_mask[top : bottom + 1, left : right + 1] = False
321
+ cropped_mask = torch.repeat_interleave(cropped_mask[None, ...], 1, dim=0)
322
+
323
+ channel3_mask = cropped_mask.repeat(3, 1, 1)
324
+ masked_img = merge_img * (1 - channel3_mask)
325
+ masked_img[channel3_mask == 1] = 0.5
326
+
327
+ return {
328
+ # add
329
+ "pixel_values": self.normalize(merge_img),
330
+ "masks": cropped_mask,
331
+ "masked_images": self.normalize(masked_img),
332
+ "conditioning_pixel_values": light_source_cond,
333
+ }
334
+
335
+ def __len__(self):
336
+ return len(self.data_list)
337
+
338
+ def load_scattering_flare(self, flare_name, flare_path):
339
+ flare_list = []
340
+ [flare_list.extend(glob.glob(flare_path + "/*." + e)) for e in self.ext]
341
+ flare_list = sorted(flare_list)
342
+ self.flare_name_list.append(flare_name)
343
+ self.flare_dict[flare_name] = flare_list
344
+ self.flare_list.append(flare_list)
345
+ len_flare_list = len(self.flare_dict[flare_name])
346
+ if len_flare_list == 0:
347
+ print("ERROR: scattering flare images are not loaded properly")
348
+ else:
349
+ print(
350
+ "Scattering Flare Image:",
351
+ flare_name,
352
+ " is loaded successfully with examples",
353
+ str(len_flare_list),
354
+ )
355
+ # print("Now we have", len(self.flare_list), "scattering flare images")
356
+
357
+ def load_light_source(self, light_name, light_path):
358
+ # The number of the light source images should match the number of scattering flares
359
+ light_list = []
360
+ [light_list.extend(glob.glob(light_path + "/*." + e)) for e in self.ext]
361
+ light_list = sorted(light_list)
362
+ self.flare_name_list.append(light_name)
363
+ self.light_dict[light_name] = light_list
364
+ self.light_list.append(light_list)
365
+ len_light_list = len(self.light_dict[light_name])
366
+
367
+ if len_light_list == 0:
368
+ print("ERROR: Light Source images are not loaded properly")
369
+ else:
370
+ self.light_flag = True
371
+ print(
372
+ "Light Source Image:",
373
+ light_name,
374
+ " is loaded successfully with examples",
375
+ str(len_light_list),
376
+ )
377
+ # print("Now we have", len(self.light_list), "light source images")
378
+
379
+ def load_reflective_flare(self, reflective_name, reflective_path):
380
+ if reflective_path is None:
381
+ reflective_list = []
382
+ else:
383
+ reflective_list = []
384
+ [
385
+ reflective_list.extend(glob.glob(reflective_path + "/*." + e))
386
+ for e in self.ext
387
+ ]
388
+ reflective_list = sorted(reflective_list)
389
+ self.reflective_name_list.append(reflective_name)
390
+ self.reflective_dict[reflective_name] = reflective_list
391
+ self.reflective_list.append(reflective_list)
392
+ len_reflective_list = len(self.reflective_dict[reflective_name])
393
+ if len_reflective_list == 0 and reflective_path is not None:
394
+ print("ERROR: reflective flare images are not loaded properly")
395
+ else:
396
+ self.reflective_flag = True
397
+ print(
398
+ "Reflective Flare Image:",
399
+ reflective_name,
400
+ " is loaded successfully with examples",
401
+ str(len_reflective_list),
402
+ )
403
+ # print("Now we have", len(self.reflective_list), "refelctive flare images")
404
+
405
+
406
+ class Flare7kpp_Pair_Loader(Flare_Image_Loader):
407
+ def __init__(self, config):
408
+ Flare_Image_Loader.__init__(
409
+ self,
410
+ config["image_path"],
411
+ config["transform_base"],
412
+ config["transform_flare"],
413
+ config["mask_type"],
414
+ )
415
+ scattering_dict = config["scattering_dict"]
416
+ reflective_dict = config["reflective_dict"]
417
+ light_dict = config["light_dict"]
418
+
419
+ # default to a uniform data ratio if config["data_ratio"] is not declared
420
+ if "data_ratio" not in config or len(config["data_ratio"]) == 0:
421
+ self.data_ratio = [1] * len(scattering_dict)
422
+ else:
423
+ self.data_ratio = config["data_ratio"]
424
+
425
+ if len(scattering_dict) != 0:
426
+ for key in scattering_dict.keys():
427
+ self.load_scattering_flare(key, scattering_dict[key])
428
+ if len(reflective_dict) != 0:
429
+ for key in reflective_dict.keys():
430
+ self.load_reflective_flare(key, reflective_dict[key])
431
+ if len(light_dict) != 0:
432
+ for key in light_dict.keys():
433
+ self.load_light_source(key, light_dict[key])
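A minimal configuration sketch for this pair loader. Every path, transform value, and dictionary key below is an illustrative assumption, not something shipped in this commit.

# Hypothetical config for Flare7kpp_Pair_Loader -- all paths and values are placeholders.
config = {
    "image_path": "dataset/Flickr24K",            # background images
    "transform_base": {"img_size": 512},          # assumed structure of the base transform config
    "transform_flare": {"scale_min": 0.8, "scale_max": 1.5, "shear": 20},
    "mask_type": "light",
    "scattering_dict": {"Flare7K": "dataset/Flare7K/Scattering_Flare/Compound_Flare"},
    "reflective_dict": {},                        # reflective flares are optional
    "light_dict": {"Flare7K": "dataset/Flare7K/Scattering_Flare/Light_Source"},
    "data_ratio": [1],                            # one sampling weight per scattering dataset
}
# loader = Flare7kpp_Pair_Loader(config)
# sample = loader[0]  # with mask_type="light": pixel_values, masks, masked_images, conditioning_pixel_values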
434
+
435
+
436
+ class Lightsource_Regress_Loader(Flare7kpp_Pair_Loader):
437
+ def __init__(self, config, num_lights=4):
438
+ Flare7kpp_Pair_Loader.__init__(self, config)
439
+ self.transform_flare = transforms.Compose(
440
+ [
441
+ transforms.RandomAffine(
442
+ degrees=(0, 360),
443
+ scale=(
444
+ config["transform_flare"]["scale_min"],
445
+ config["transform_flare"]["scale_max"],
446
+ ),
447
+ shear=(
448
+ -config["transform_flare"]["shear"],
449
+ config["transform_flare"]["shear"],
450
+ ),
451
+ ),
452
+ # transforms.CenterCrop((self.img_size, self.img_size)),
453
+ ]
454
+ )
455
+
456
+ self.mask_type = "light"
457
+ self.num_lights = num_lights
458
+
459
+ def __getitem__(self, index):
460
+ # load base image
461
+ img_path = self.data_list[index]
462
+ base_img = Image.open(img_path).convert("RGB")
463
+
464
+ gamma = np.random.uniform(1.8, 2.2)
465
+ to_tensor = transforms.ToTensor()
466
+ adjust_gamma = RandomGammaCorrection(gamma)
467
+ adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
468
+ color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
469
+
470
+ base_img = to_tensor(base_img)
471
+ base_img = adjust_gamma(base_img)
472
+ if self.transform_base is not None:
473
+ base_img = self.transform_base(base_img)
474
+
475
+ sigma_chi = 0.01 * np.random.chisquare(df=1)
476
+ base_img = Normal(base_img, sigma_chi).sample()
477
+ gain = np.random.uniform(0.5, 1.2)
478
+ base_img = gain * base_img
479
+ base_img = torch.clamp(base_img, min=0, max=1)
480
+
481
+ # init flare and light imgs
482
+ flare_imgs = []
483
+ light_imgs = []
484
+ position = [
485
+ [[-224, 0], [-224, 0]],
486
+ [[-224, 0], [0, 224]],
487
+ [[0, 224], [-224, 0]],
488
+ [[0, 224], [0, 224]],
489
+ ]
490
+ axis = random.sample(range(4), 4)
491
+ axis[-1] = axis[0]
492
+ flare_nums = int(
493
+ random.random() * self.num_lights + 1
494
+ ) # random number of flares, from 1 to self.num_lights
495
+
496
+ for fn in range(flare_nums):
497
+ choice_dataset = random.choices(
498
+ [i for i in range(len(self.flare_list))], self.data_ratio
499
+ )[0]
500
+ choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
501
+
502
+ flare_path = self.flare_list[choice_dataset][choice_index]
503
+ flare_img = Image.open(flare_path).convert("RGB")
504
+ flare_img = to_tensor(flare_img)
505
+ flare_img = adjust_gamma(flare_img)
506
+ flare_img = remove_background(flare_img)
507
+
508
+ if self.light_flag:
509
+ light_path = self.light_list[choice_dataset][choice_index]
510
+ light_img = Image.open(light_path).convert("RGB")
511
+ light_img = to_tensor(light_img)
512
+ light_img = adjust_gamma(light_img)
513
+
514
+ if self.transform_flare is not None:
515
+ if self.light_flag:
516
+ flare_merge = torch.cat((flare_img, light_img), dim=0)
517
+
518
+ if flare_nums == 1:
519
+ dx = random.randint(-224, 224)
520
+ dy = random.randint(-224, 224)
521
+ else:
522
+ dx = random.randint(
523
+ position[axis[fn]][0][0], position[axis[fn]][0][1]
524
+ )
525
+ dy = random.randint(
526
+ position[axis[fn]][1][0], position[axis[fn]][1][1]
527
+ )
528
+ if -160 < dx < 160 and -160 < dy < 160:
529
+ if random.random() < 0.5:
530
+ dx = 160 if dx > 0 else -160
531
+ else:
532
+ dy = 160 if dy > 0 else -160
533
+
534
+ flare_merge = self.transform_flare(flare_merge)
535
+ flare_merge = TF.affine(
536
+ flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
537
+ )
538
+ flare_merge = TF.center_crop(
539
+ flare_merge, (self.img_size, self.img_size)
540
+ )
541
+ else:
542
+ flare_img = self.transform_flare(flare_img)
543
+
544
+ # change color
545
+ if self.light_flag:
546
+ flare_img, light_img = torch.split(flare_merge, 3, dim=0)
547
+ else:
548
+ flare_img = color_jitter(flare_img)
549
+
550
+ flare_imgs.append(flare_img)
551
+ if self.light_flag:
552
+ light_img = torch.clamp(light_img, min=0, max=1)
553
+ light_imgs.append(light_img)
554
+
555
+ flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
556
+ flare_img = torch.clamp(flare_img, min=0, max=1)
557
+
558
+ # flare blur
559
+ blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
560
+ flare_img = blur_transform(flare_img)
561
+ flare_img = torch.clamp(flare_img, min=0, max=1)
562
+
563
+ merge_img = torch.clamp(flare_img + base_img, min=0, max=1)
564
+
565
+ if self.light_flag:
566
+ light_img = torch.sum(torch.stack(light_imgs), dim=0)
567
+ light_img = torch.clamp(light_img, min=0, max=1)
568
+ base_img = torch.clamp(base_img + light_img, min=0, max=1)
569
+ flare_img = torch.clamp(flare_img - light_img, min=0, max=1)
570
+
571
+ flare_mask = None
572
+ if self.mask_type is None:
573
+ return {
574
+ "gt": adjust_gamma_reverse(base_img),
575
+ "flare": adjust_gamma_reverse(flare_img),
576
+ "lq": adjust_gamma_reverse(merge_img),
577
+ "gamma": gamma,
578
+ }
579
+
580
+ elif self.mask_type == "light":
581
+ one = torch.ones_like(base_img)
582
+ zero = torch.zeros_like(base_img)
583
+ threshold_value = 0.01
584
+
585
+ # flare_masks_list = []
586
+ XYRs = torch.zeros((self.num_lights, 4))
587
+ for i in range(flare_nums):
588
+ luminance = (
589
+ 0.3 * light_imgs[i][0]
590
+ + 0.59 * light_imgs[i][1]
591
+ + 0.11 * light_imgs[i][2]
592
+ )
593
+ flare_mask = torch.where(luminance > threshold_value, one, zero)
594
+
595
+ light_source_cond = (flare_mask.sum(dim=0) > 0).float()
596
+
597
+ x, y, r = self.find_circle_properties(light_source_cond, i)
598
+ XYRs[i] = torch.tensor([x, y, r, 1.0])
599
+
600
+ XYRs[:, :3] = XYRs[:, :3] / self.img_size
601
+
602
+ luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
603
+ flare_mask = torch.where(luminance > threshold_value, one, zero)
604
+
605
+ light_source_cond = (flare_mask.sum(dim=0) > 0).float()
606
+
607
+ light_source_cond = torch.repeat_interleave(
608
+ light_source_cond[None, ...], 1, dim=0
609
+ )
610
+
611
+ # box = self.crop(light_source_cond[0])
612
+ box = self.lightsource_crop(light_source_cond[0])
613
+
614
+ # random integer between 0 and 15
615
+ margin = random.randint(0, 15)
616
+ if box[0] - margin >= 0:
617
+ box[0] -= margin
618
+ if box[1] + margin < self.img_size:
619
+ box[1] += margin
620
+ if box[2] - margin >= 0:
621
+ box[2] -= margin
622
+ if box[3] + margin < self.img_size:
623
+ box[3] += margin
624
+
625
+ top, bottom, left, right = box[2], box[3], box[0], box[1]
626
+
627
+ merge_img = adjust_gamma_reverse(merge_img)
628
+
629
+ cropped_mask = torch.full(
630
+ (self.img_size, self.img_size), True, dtype=torch.bool
631
+ )
632
+ cropped_mask[top : bottom + 1, left : right + 1] = False
633
+ channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)
634
+
635
+ masked_img = merge_img * (1 - channel3_mask.float())
636
+ masked_img[channel3_mask] = 0.5
637
+
638
+ return {
639
+ # add
640
+ "input": self.normalize(masked_img), # normalize to [-1, 1]
641
+ "light_masks": light_source_cond,
642
+ "xyrs": XYRs,
643
+ }
644
+
645
+ def find_circle_properties(self, mask, i, method="minEnclosingCircle"):
646
+ """
647
+ Find the properties of the light source circle in the mask.
648
+ """
649
+
650
+ _mask = (mask.numpy() * 255).astype(np.uint8)
651
+ _, binary_mask = cv2.threshold(_mask, 127, 255, cv2.THRESH_BINARY)
652
+ contours, _ = cv2.findContours(
653
+ binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
654
+ )
655
+
656
+ if len(contours) == 0:
657
+ return 0.0, 0.0, 0.0
658
+
659
+ largest_contour = max(contours, key=cv2.contourArea)
660
+
661
+ if method == "minEnclosingCircle":
662
+ (x, y), radius = cv2.minEnclosingCircle(largest_contour)
663
+
664
+ elif method == "area_based":
665
+ M = cv2.moments(largest_contour)
666
+ if M["m00"] == 0: # if the contour is too small
667
+ return 0.0, 0.0, 0.0
668
+
669
+ x = M["m10"] / M["m00"]
670
+ y = M["m01"] / M["m00"]
671
+ area = cv2.contourArea(largest_contour)
672
+ radius = np.sqrt(area / np.pi)
673
+
674
+ # # draw
675
+ # cv2.circle(_mask, (int(x), int(y)), int(radius), 128, 2)
676
+ # cv2.imwrite(f"mask_{i}.png", _mask)
677
+
678
+ return x, y, radius
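A quick sanity check of the circle fit used above, run on a synthetic mask; the 64x64 size and the disk parameters are arbitrary.

import cv2
import numpy as np

# A filled disk of radius 10 centred at (x=32, y=20); minEnclosingCircle should recover it.
toy = np.zeros((64, 64), dtype=np.uint8)
cv2.circle(toy, (32, 20), 10, 255, thickness=-1)
_, binary = cv2.threshold(toy, 127, 255, cv2.THRESH_BINARY)
contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
(cx, cy), radius = cv2.minEnclosingCircle(max(contours, key=cv2.contourArea))
print(cx, cy, radius)  # roughly 32, 20, 10 (up to pixel discretisation)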
679
+
680
+
681
+ class Lightsource_3Maps_Loader(Lightsource_Regress_Loader):
682
+ def __init__(self, config, num_lights=4):
683
+ Lightsource_Regress_Loader.__init__(self, config, num_lights=num_lights)
684
+
685
+ def build_gt_maps(self, coords, radii, H, W, kappa=0.4):
686
+ yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
687
+ prob_gt = torch.zeros((H, W))
688
+ rad_gt = torch.zeros((H, W))
689
+
690
+ eps = 1e-6
691
+ for x_i, y_i, r_i in zip(coords[:, 0], coords[:, 1], radii):
692
+ if r_i < 1.0:
693
+ continue
694
+
695
+ sigma = kappa * r_i
696
+ g = torch.exp(-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * sigma**2))
697
+ g_prime = torch.exp(
698
+ -((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * (sigma / 1.414) ** 2)
699
+ )
700
+ prob_gt = torch.maximum(prob_gt, g)
701
+ rad_gt = torch.maximum(rad_gt, g_prime * r_i)
702
+
703
+ rad_gt = rad_gt / (prob_gt + eps)
704
+ return prob_gt, rad_gt
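A standalone sketch of what build_gt_maps produces for a single light source; the centre, radius, and map size are made-up values.

import torch

H = W = 256
kappa, eps = 0.4, 1e-6
x_i, y_i, r_i = 100.0, 60.0, 12.0  # assumed light-source centre (pixels) and radius

yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
sigma = kappa * r_i
prob_gt = torch.exp(-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * sigma**2))  # peaks at 1.0 on the centre
rad_num = torch.exp(-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * (sigma / 1.414) ** 2)) * r_i
rad_gt = rad_num / (prob_gt + eps)  # about r_i near the centre, decaying with distance
print(prob_gt.max().item(), rad_gt[60, 100].item())  # ~1.0 and ~12.0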
705
+
706
+ def __getitem__(self, index):
707
+ # load base image
708
+ img_path = self.data_list[index]
709
+ base_img = Image.open(img_path).convert("RGB")
710
+
711
+ gamma = np.random.uniform(1.8, 2.2)
712
+ to_tensor = transforms.ToTensor()
713
+ adjust_gamma = RandomGammaCorrection(gamma)
714
+ adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
715
+ color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
716
+
717
+ base_img = to_tensor(base_img)
718
+ base_img = adjust_gamma(base_img)
719
+ if self.transform_base is not None:
720
+ base_img = self.transform_base(base_img)
721
+
722
+ sigma_chi = 0.01 * np.random.chisquare(df=1)
723
+ base_img = Normal(base_img, sigma_chi).sample()
724
+ gain = np.random.uniform(0.5, 1.2)
725
+ base_img = gain * base_img
726
+ base_img = torch.clamp(base_img, min=0, max=1)
727
+
728
+ # init flare and light imgs
729
+ flare_imgs = []
730
+ light_imgs = []
731
+ position = [
732
+ [[-224, 0], [-224, 0]],
733
+ [[-224, 0], [0, 224]],
734
+ [[0, 224], [-224, 0]],
735
+ [[0, 224], [0, 224]],
736
+ ]
737
+ axis = random.sample(range(4), 4)
738
+ axis[-1] = axis[0]
739
+ flare_nums = int(
740
+ random.random() * self.num_lights + 1
741
+ ) # random number of flares, from 1 to self.num_lights
742
+
743
+ for fn in range(flare_nums):
744
+ choice_dataset = random.choices(
745
+ [i for i in range(len(self.flare_list))], self.data_ratio
746
+ )[0]
747
+ choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
748
+
749
+ flare_path = self.flare_list[choice_dataset][choice_index]
750
+ flare_img = Image.open(flare_path).convert("RGB")
751
+ flare_img = to_tensor(flare_img)
752
+ flare_img = adjust_gamma(flare_img)
753
+ flare_img = remove_background(flare_img)
754
+
755
+ if self.light_flag:
756
+ light_path = self.light_list[choice_dataset][choice_index]
757
+ light_img = Image.open(light_path).convert("RGB")
758
+ light_img = to_tensor(light_img)
759
+ light_img = adjust_gamma(light_img)
760
+
761
+ if self.transform_flare is not None:
762
+ if self.light_flag:
763
+ flare_merge = torch.cat((flare_img, light_img), dim=0)
764
+
765
+ if flare_nums == 1:
766
+ dx = random.randint(-224, 224)
767
+ dy = random.randint(-224, 224)
768
+ else:
769
+ dx = random.randint(
770
+ position[axis[fn]][0][0], position[axis[fn]][0][1]
771
+ )
772
+ dy = random.randint(
773
+ position[axis[fn]][1][0], position[axis[fn]][1][1]
774
+ )
775
+ if -160 < dx < 160 and -160 < dy < 160:
776
+ if random.random() < 0.5:
777
+ dx = 160 if dx > 0 else -160
778
+ else:
779
+ dy = 160 if dy > 0 else -160
780
+
781
+ flare_merge = self.transform_flare(flare_merge)
782
+ flare_merge = TF.affine(
783
+ flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
784
+ )
785
+ flare_merge = TF.center_crop(
786
+ flare_merge, (self.img_size, self.img_size)
787
+ )
788
+ else:
789
+ flare_img = self.transform_flare(flare_img)
790
+
791
+ # change color
792
+ if self.light_flag:
793
+ flare_img, light_img = torch.split(flare_merge, 3, dim=0)
794
+ else:
795
+ flare_img = color_jitter(flare_img)
796
+
797
+ flare_imgs.append(flare_img)
798
+ if self.light_flag:
799
+ light_img = torch.clamp(light_img, min=0, max=1)
800
+ light_imgs.append(light_img)
801
+
802
+ flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
803
+ flare_img = torch.clamp(flare_img, min=0, max=1)
804
+
805
+ # flare blur
806
+ blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
807
+ flare_img = blur_transform(flare_img)
808
+ flare_img = torch.clamp(flare_img, min=0, max=1)
809
+
810
+ merge_img = torch.clamp(flare_img + base_img, min=0, max=1)
811
+
812
+ if self.light_flag:
813
+ light_img = torch.sum(torch.stack(light_imgs), dim=0)
814
+ light_img = torch.clamp(light_img, min=0, max=1)
815
+ base_img = torch.clamp(base_img + light_img, min=0, max=1)
816
+ flare_img = torch.clamp(flare_img - light_img, min=0, max=1)
817
+
818
+ flare_mask = None
819
+ if self.mask_type is None:
820
+ return {
821
+ "gt": adjust_gamma_reverse(base_img),
822
+ "flare": adjust_gamma_reverse(flare_img),
823
+ "lq": adjust_gamma_reverse(merge_img),
824
+ "gamma": gamma,
825
+ }
826
+
827
+ elif self.mask_type == "light":
828
+ one = torch.ones_like(base_img)
829
+ zero = torch.zeros_like(base_img)
830
+ threshold_value = 0.01
831
+
832
+ # flare_masks_list = []
833
+ XYRs = torch.zeros((self.num_lights, 4))
834
+ for i in range(flare_nums):
835
+ luminance = (
836
+ 0.3 * light_imgs[i][0]
837
+ + 0.59 * light_imgs[i][1]
838
+ + 0.11 * light_imgs[i][2]
839
+ )
840
+ flare_mask = torch.where(luminance > threshold_value, one, zero)
841
+
842
+ light_source_cond = (flare_mask.sum(dim=0) > 0).float()
843
+
844
+ x, y, r = self.find_circle_properties(light_source_cond, i)
845
+ XYRs[i] = torch.tensor([x, y, r, 1.0])
846
+
847
+ gt_prob, gt_rad = self.build_gt_maps(
848
+ XYRs[:, :2], XYRs[:, 2], self.img_size, self.img_size
849
+ )
850
+ gt_prob = gt_prob.unsqueeze(0) # shape: (1, H, W)
851
+ gt_rad = gt_rad.unsqueeze(0)
852
+ gt_rad /= self.img_size
853
+ gt_maps = torch.cat((gt_prob, gt_rad), dim=0) # shape: (2, H, W)
854
+
855
+ XYRs[:, :3] = XYRs[:, :3] / self.img_size
856
+
857
+ luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
858
+ flare_mask = torch.where(luminance > threshold_value, one, zero)
859
+
860
+ light_source_cond = (flare_mask.sum(dim=0) > 0).float()
861
+
862
+ light_source_cond = torch.repeat_interleave(
863
+ light_source_cond[None, ...], 1, dim=0
864
+ )
865
+
866
+ # box = self.crop(light_source_cond[0])
867
+ box = self.lightsource_crop(light_source_cond[0])
868
+
869
+ # random integer between 0 and 15
870
+ margin = random.randint(0, 15)
871
+ if box[0] - margin >= 0:
872
+ box[0] -= margin
873
+ if box[1] + margin < self.img_size:
874
+ box[1] += margin
875
+ if box[2] - margin >= 0:
876
+ box[2] -= margin
877
+ if box[3] + margin < self.img_size:
878
+ box[3] += margin
879
+
880
+ top, bottom, left, right = box[2], box[3], box[0], box[1]
881
+
882
+ merge_img = adjust_gamma_reverse(merge_img)
883
+
884
+ cropped_mask = torch.full(
885
+ (self.img_size, self.img_size), True, dtype=torch.bool
886
+ )
887
+ cropped_mask[top : bottom + 1, left : right + 1] = False
888
+ channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)
889
+
890
+ masked_img = merge_img * (1 - channel3_mask.float())
891
+ masked_img[channel3_mask] = 0.5
892
+
893
+ return {
894
+ # add
895
+ "input": self.normalize(masked_img), # normalize to [-1, 1]
896
+ "light_masks": light_source_cond,
897
+ "xyrs": gt_maps,
898
+ }
899
+
900
+
901
+ class TestImageLoader(Dataset):
902
+ def __init__(
903
+ self,
904
+ dataroot_gt,
905
+ dataroot_input,
906
+ dataroot_mask,
907
+ margin=0,
908
+ img_size=512,
909
+ noise_matching=False,
910
+ ):
911
+ super(TestImageLoader, self).__init__()
912
+ self.gt_folder = dataroot_gt
913
+ self.input_folder = dataroot_input
914
+ self.mask_folder = dataroot_mask
915
+ self.paths = glod_from_folder(
916
+ [self.input_folder, self.gt_folder, self.mask_folder],
917
+ ["input", "gt", "mask"],
918
+ )
919
+
920
+ self.margin = margin
921
+ self.img_size = img_size
922
+ self.noise_matching = noise_matching
923
+
924
+ def __len__(self):
925
+ return len(self.paths["input"])
926
+
927
+ def __getitem__(self, index):
928
+ img_name = self.paths["input"][index].split("/")[-1]
929
+ num = img_name.split("_")[1].split(".")[0]
930
+
931
+ # preprocess light source mask
932
+ light_mask = np.array(Image.open(self.paths["mask"][index]))
933
+ tmp_light_mask = np.zeros_like(light_mask[:, :, 0])
934
+ tmp_light_mask[light_mask[:, :, 2] > 0] = 255
935
+ cond = (light_mask[:, :, 0] > 0) & (light_mask[:, :, 1] > 0)
936
+ tmp_light_mask[cond] = 0
937
+ light_mask = tmp_light_mask
938
+
939
+ # img for controlnet input
940
+ control_img = np.repeat(light_mask[:, :, None], 3, axis=2)
941
+
942
+ # crop region
943
+ box = self.lightsource_crop(light_mask)
944
+
945
+ if box[0] - self.margin >= 0:
946
+ box[0] -= self.margin
947
+ if box[1] + self.margin < self.img_size:
948
+ box[1] += self.margin
949
+ if box[2] - self.margin >= 0:
950
+ box[2] -= self.margin
951
+ if box[3] + self.margin < self.img_size:
952
+ box[3] += self.margin
953
+
954
+ # input image to be outpainted
955
+ input_img = np.array(Image.open(self.paths["input"][index]))
956
+ cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
957
+ cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
958
+ input_img[cropped_region == 1] = 128
959
+
960
+ # image for blip
961
+ blip_img = input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :]
962
+
963
+ # noise matching
964
+ input_img_matching = None
965
+ if self.noise_matching:
966
+ np_src_img = input_img / 255.0
967
+ np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
968
+ np.float32
969
+ )
970
+ matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
971
+ input_img_matching = (matched_noise * 255).astype(np.uint8)
972
+
973
+ # mask image
974
+ mask_img = (cropped_region * 255).astype(np.uint8)
975
+
976
+ return {
977
+ "blip_img": blip_img,
978
+ "input_img": Image.fromarray(input_img),
979
+ "input_img_matching": (
980
+ Image.fromarray(input_img_matching)
981
+ if input_img_matching is not None
982
+ else Image.fromarray(input_img)
983
+ ),
984
+ "mask_img": Image.fromarray(mask_img),
985
+ "control_img": Image.fromarray(control_img),
986
+ "box": box,
987
+ "output_name": "output_" + num + ".png",
988
+ }
989
+
990
+ def lightsource_crop(self, matrix):
991
+ """Find the largest rectangle of 1s in a binary matrix."""
992
+
993
+ def largestRectangleArea(heights):
994
+ heights.append(0)
995
+ stack = [-1]
996
+ max_area = 0
997
+ max_rectangle = (0, 0, 0, 0) # (area, left, right, height)
998
+ for i in range(len(heights)):
999
+ while heights[i] < heights[stack[-1]]:
1000
+ h = heights[stack.pop()]
1001
+ w = i - stack[-1] - 1
1002
+ area = h * w
1003
+ if area > max_area:
1004
+ max_area = area
1005
+ max_rectangle = (area, stack[-1] + 1, i - 1, h)
1006
+ stack.append(i)
1007
+ heights.pop()
1008
+ return max_rectangle
1009
+
1010
+ max_area = 0
1011
+ max_rectangle = [0, 0, 0, 0] # (left, right, top, bottom)
1012
+ heights = [0] * len(matrix[0])
1013
+ for row in range(len(matrix)):
1014
+ for i, val in enumerate(matrix[row]):
1015
+ heights[i] = heights[i] + 1 if val == 0 else 0
1016
+
1017
+ area, left, right, height = largestRectangleArea(heights)
1018
+ if area > max_area:
1019
+ max_area = area
1020
+ max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]
1021
+
1022
+ return list(max_rectangle)
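A toy check of the crop above. Since lightsource_crop never touches self, it can be exercised without building a dataset; the 8x8 mask is arbitrary.

import numpy as np

# Light pixels fill the top-left 2x2 corner, so the largest all-zero rectangle is columns 2-7 over all rows.
toy_mask = np.zeros((8, 8), dtype=np.float32)
toy_mask[0:2, 0:2] = 1.0
box = TestImageLoader.lightsource_crop(None, toy_mask)  # self is unused inside the method
print(box)  # [2, 7, 0, 7] -> [left, right, top, bottom]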
1023
+
1024
+ # this function is taken from https://github.com/parlance-zz/g-diffuser-bot
1025
+ def get_matched_noise(
1026
+ self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
1027
+ ):
1028
+ # helper fft routines that keep ortho normalization and auto-shift before and after fft
1029
+ def _fft2(data):
1030
+ if data.ndim > 2: # has channels
1031
+ out_fft = np.zeros(
1032
+ (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
1033
+ )
1034
+ for c in range(data.shape[2]):
1035
+ c_data = data[:, :, c]
1036
+ out_fft[:, :, c] = np.fft.fft2(
1037
+ np.fft.fftshift(c_data), norm="ortho"
1038
+ )
1039
+ out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
1040
+ else: # one channel
1041
+ out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
1042
+ out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
1043
+ out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
1044
+
1045
+ return out_fft
1046
+
1047
+ def _ifft2(data):
1048
+ if data.ndim > 2: # has channels
1049
+ out_ifft = np.zeros(
1050
+ (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
1051
+ )
1052
+ for c in range(data.shape[2]):
1053
+ c_data = data[:, :, c]
1054
+ out_ifft[:, :, c] = np.fft.ifft2(
1055
+ np.fft.fftshift(c_data), norm="ortho"
1056
+ )
1057
+ out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
1058
+ else: # one channel
1059
+ out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
1060
+ out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
1061
+ out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
1062
+
1063
+ return out_ifft
1064
+
1065
+ def _get_gaussian_window(width, height, std=3.14, mode=0):
1066
+ window_scale_x = float(width / min(width, height))
1067
+ window_scale_y = float(height / min(width, height))
1068
+
1069
+ window = np.zeros((width, height))
1070
+ x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
1071
+ for y in range(height):
1072
+ fy = (y / height * 2.0 - 1.0) * window_scale_y
1073
+ if mode == 0:
1074
+ window[:, y] = np.exp(-(x**2 + fy**2) * std)
1075
+ else:
1076
+ window[:, y] = (1 / ((x**2 + 1.0) * (fy**2 + 1.0))) ** (
1077
+ std / 3.14
1078
+ ) # hey wait a minute that's not gaussian
1079
+
1080
+ return window
1081
+
1082
+ def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
1083
+ np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
1084
+ if hardness != 1.0:
1085
+ hardened = np_mask_grey[:] ** hardness
1086
+ else:
1087
+ hardened = np_mask_grey[:]
1088
+ for c in range(3):
1089
+ np_mask_rgb[:, :, c] = hardened[:]
1090
+ return np_mask_rgb
1091
+
1092
+ width = _np_src_image.shape[0]
1093
+ height = _np_src_image.shape[1]
1094
+ num_channels = _np_src_image.shape[2]
1095
+
1096
+ _np_src_image[:] * (1.0 - np_mask_rgb)
1097
+ np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
1098
+ img_mask = np_mask_grey > 1e-6
1099
+ ref_mask = np_mask_grey < 1e-3
1100
+
1101
+ windowed_image = _np_src_image * (1.0 - _get_masked_window_rgb(np_mask_grey))
1102
+ windowed_image /= np.max(windowed_image)
1103
+ windowed_image += (
1104
+ np.average(_np_src_image) * np_mask_rgb
1105
+ ) # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
1106
+
1107
+ src_fft = _fft2(windowed_image) # get feature statistics from masked src img
1108
+ src_dist = np.absolute(src_fft)
1109
+ src_phase = src_fft / src_dist
1110
+
1111
+ # create a generator with a static seed to make outpainting deterministic / only follow global seed
1112
+ rng = np.random.default_rng(0)
1113
+
1114
+ noise_window = _get_gaussian_window(
1115
+ width, height, mode=1
1116
+ ) # start with simple gaussian noise
1117
+ noise_rgb = rng.random((width, height, num_channels))
1118
+ noise_grey = np.sum(noise_rgb, axis=2) / 3.0
1119
+ noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
1120
+ for c in range(num_channels):
1121
+ noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
1122
+
1123
+ noise_fft = _fft2(noise_rgb)
1124
+ for c in range(num_channels):
1125
+ noise_fft[:, :, c] *= noise_window
1126
+ noise_rgb = np.real(_ifft2(noise_fft))
1127
+ shaped_noise_fft = _fft2(noise_rgb)
1128
+ shaped_noise_fft[:, :, :] = (
1129
+ np.absolute(shaped_noise_fft[:, :, :]) ** 2
1130
+ * (src_dist**noise_q)
1131
+ * src_phase
1132
+ ) # perform the actual shaping
1133
+
1134
+ brightness_variation = 0.0 # color_variation # todo: temporarily tying brightness variation to color variation for now
1135
+ contrast_adjusted_np_src = (
1136
+ _np_src_image[:] * (brightness_variation + 1.0) - brightness_variation * 2.0
1137
+ )
1138
+
1139
+ # scikit-image is used for histogram matching, very convenient!
1140
+ shaped_noise = np.real(_ifft2(shaped_noise_fft))
1141
+ shaped_noise -= np.min(shaped_noise)
1142
+ shaped_noise /= np.max(shaped_noise)
1143
+ shaped_noise[img_mask, :] = skimage.exposure.match_histograms(
1144
+ shaped_noise[img_mask, :] ** 1.0,
1145
+ contrast_adjusted_np_src[ref_mask, :],
1146
+ channel_axis=1,
1147
+ )
1148
+ shaped_noise = (
1149
+ _np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
1150
+ )
1151
+
1152
+ matched_noise = shaped_noise[:]
1153
+
1154
+ return np.clip(matched_noise, 0.0, 1.0)
1155
+
1156
+
1157
+ class CustomImageLoader(Dataset):
1158
+ def __init__(
1159
+ self, dataroot_input, left_outpaint, right_outpaint, up_outpaint, down_outpaint
1160
+ ):
1161
+ self.dataroot_input = dataroot_input
1162
+ self.left_outpaint = left_outpaint
1163
+ self.right_outpaint = right_outpaint
1164
+ self.up_outpaint = up_outpaint
1165
+ self.down_outpaint = down_outpaint
1166
+
1167
+ self.H = 512 - (up_outpaint + down_outpaint)
1168
+ self.W = 512 - (left_outpaint + right_outpaint)
1169
+ self.img_size = 512
1170
+
1171
+ self.img_lists = [
1172
+ os.path.join(dataroot_input, f)
1173
+ for f in os.listdir(dataroot_input)
1174
+ if f.endswith(".png") or f.endswith(".jpg")
1175
+ ]
1176
+
1177
+ def __len__(self):
1178
+ return len(self.img_lists)
1179
+
1180
+ def __getitem__(self, index):
1181
+ img_name = self.img_lists[index].split("/")[-1]
1182
+
1183
+ # crop region
1184
+ box = [
1185
+ self.left_outpaint,
1186
+ 511 - self.right_outpaint,
1187
+ self.up_outpaint,
1188
+ 511 - self.down_outpaint,
1189
+ ] # [left, right, top, bottom]
1190
+
1191
+ # box = self.lightsource_crop(light_mask)
1192
+ # if box[0] - self.margin >= 0:
1193
+ # box[0] -= self.margin
1194
+ # if box[1] + self.margin < self.img_size:
1195
+ # box[1] += self.margin
1196
+ # if box[2] - self.margin >= 0:
1197
+ # box[2] -= self.margin
1198
+ # if box[3] + self.margin < self.img_size:
1199
+ # box[3] += self.margin
1200
+
1201
+ # input image to be outpainted
1202
+ input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
1203
+ paste_img = np.array(
1204
+ Image.open(self.img_lists[index]).resize((self.W, self.H), Image.LANCZOS)
1205
+ )
1206
+ input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
1207
+ cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
1208
+ cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
1209
+ input_img[cropped_region == 1] = 128
1210
+
1211
+ # image for blip
1212
+ blip_img = np.array(Image.open(self.img_lists[index]))
1213
+
1214
+ # # noise matching
1215
+ # input_img_matching = None
1216
+ # if self.noise_matching:
1217
+ # np_src_img = input_img / 255.0
1218
+ # np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
1219
+ # np.float32
1220
+ # )
1221
+ # matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
1222
+ # input_img_matching = (matched_noise * 255).astype(np.uint8)
1223
+
1224
+ # mask image
1225
+ mask_img = (cropped_region * 255).astype(np.uint8)
1226
+
1227
+ return {
1228
+ "blip_img": blip_img,
1229
+ "input_img": Image.fromarray(input_img),
1230
+ # "input_img": (
1231
+ # Image.fromarray(input_img_matching)
1232
+ # if input_img_matching is not None
1233
+ # else Image.fromarray(input_img)
1234
+ # ),
1235
+ "mask_img": Image.fromarray(mask_img),
1236
+ "box": box,
1237
+ "output_name": img_name,
1238
+ }
1239
+
1240
+
1241
+
1242
+ class HFCustomImageLoader(Dataset):
1243
+ def __init__(
1244
+ self, img_data, left_outpaint=64, right_outpaint=64, up_outpaint=64, down_outpaint=64
1245
+ ):
1246
+ self.left_outpaint = left_outpaint
1247
+ self.right_outpaint = right_outpaint
1248
+ self.up_outpaint = up_outpaint
1249
+ self.down_outpaint = down_outpaint
1250
+
1251
+ self.H = 512 - (up_outpaint + down_outpaint)
1252
+ self.W = 512 - (left_outpaint + right_outpaint)
1253
+ self.img_size = 512
1254
+
1255
+ self.img_lists = [img_data]
1256
+
1257
+ def __len__(self):
1258
+ return len(self.img_lists)
1259
+
1260
+ def __getitem__(self, index):
1261
+ # img_name = self.img_lists[index].split("/")[-1]
1262
+
1263
+ # crop region
1264
+ box = [
1265
+ self.left_outpaint,
1266
+ 511 - self.right_outpaint,
1267
+ self.up_outpaint,
1268
+ 511 - self.down_outpaint,
1269
+ ] # [left, right, top, bottom]
1270
+
1271
+ # input image to be outpainted
1272
+ input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
1273
+ paste_img = np.array(self.img_lists[index].resize((self.W, self.H), Image.LANCZOS))
1274
+ input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
1275
+ cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
1276
+ cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
1277
+ input_img[cropped_region == 1] = 128
1278
+
1279
+ # image for blip
1280
+ blip_img = np.array(self.img_lists[index])
1281
+
1282
+ # # noise matching
1283
+ # input_img_matching = None
1284
+ # if self.noise_matching:
1285
+ # np_src_img = input_img / 255.0
1286
+ # np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
1287
+ # np.float32
1288
+ # )
1289
+ # matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
1290
+ # input_img_matching = (matched_noise * 255).astype(np.uint8)
1291
+
1292
+ # mask image
1293
+ mask_img = (cropped_region * 255).astype(np.uint8)
1294
+
1295
+ return {
1296
+ "blip_img": blip_img,
1297
+ "input_img": Image.fromarray(input_img),
1298
+ "mask_img": Image.fromarray(mask_img),
1299
+ "box": box,
1300
+ }
1301
+
1302
+
1303
+ if __name__ == "__main__":
1304
+ pass
utils/loss.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from scipy.optimize import linear_sum_assignment
5
+
6
+
7
+ class uncertainty_light_pos_loss(nn.Module):
8
+ def __init__(self):
9
+ super(uncertainty_light_pos_loss, self).__init__()
10
+ self.log_var_xyr = nn.Parameter(torch.tensor(1.0, requires_grad=True))
11
+ self.log_var_p = nn.Parameter(torch.tensor(1.0, requires_grad=True))
12
+
13
+ def forward(self, logits, targets):
14
+ B, N, P = logits.shape # (B, 4, 4)
15
+
16
+ position_loss = 0
17
+ confidence_loss = 0
18
+
19
+ w_xyr = 0.5 / (self.log_var_xyr**2) # uncertainty weight for position loss
20
+ w_p = 0.5 / (self.log_var_p**2) # uncertainty weight for confidence loss
21
+ weights = torch.tensor([1, 1, 2], device=logits.device) # weights for x, y, r
22
+
23
+ for b in range(B):
24
+ pred_xyr = logits[b, :, :3] # (N, 3)
25
+ pred_p = logits[b, :, 3] # (N,)
26
+
27
+ gt_xyr = targets[b, :, :3] # (N, 3)
28
+ gt_p = targets[b, :, 3] # (N,)
29
+
30
+ cost_matrix = torch.cdist(gt_xyr, pred_xyr, p=2) # (N, N)
31
+
32
+ with torch.no_grad():
33
+ row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().numpy())
34
+
35
+ matched_pred_xyr = pred_xyr[col_ind]
36
+ matched_gt_xyr = gt_xyr[row_ind]
37
+ matched_pred_p = pred_p[col_ind]
38
+ matched_gt_p = gt_p[row_ind]
39
+
40
+ valid_mask = matched_gt_p > 0
41
+ valid_cnt = valid_mask.sum().clamp(min=1)
42
+
43
+ xyr_loss = (
44
+ F.smooth_l1_loss(
45
+ matched_pred_xyr[valid_mask],
46
+ matched_gt_xyr[valid_mask],
47
+ reduction="none",
48
+ )
49
+ * weights
50
+ ).sum()
51
+
52
+ p_loss = F.binary_cross_entropy(
53
+ matched_pred_p, matched_gt_p, reduction="mean"
54
+ )
55
+
56
+ position_loss += xyr_loss / valid_cnt
57
+ confidence_loss += p_loss
58
+
59
+ position_loss = w_xyr * (position_loss / B) + torch.log(1 + self.log_var_xyr**2)
60
+ confidence_loss = w_p * (confidence_loss / B) + torch.log(1 + self.log_var_p**2)
61
+
62
+ return position_loss, confidence_loss
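A usage sketch for the loss above, assuming the class is importable from utils.loss; the (B, 4, 4) layout follows the comment in forward, all tensor values are dummies, and predicted confidences must already lie in [0, 1] because binary_cross_entropy is applied directly.

import torch

criterion = uncertainty_light_pos_loss()
logits = torch.rand(2, 4, 4)                         # predicted (x, y, r, p), all in [0, 1]
targets = torch.zeros(2, 4, 4)
targets[:, 0] = torch.tensor([0.5, 0.5, 0.1, 1.0])   # one valid light per image, rest padded with p = 0
pos_loss, conf_loss = criterion(logits, targets)
(pos_loss + conf_loss).backward()                    # gradients also reach the learned log-variance terms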
63
+
64
+
65
+ class unet_3maps_loss(nn.Module):
66
+ def __init__(self):
67
+ super(unet_3maps_loss, self).__init__()
68
+
69
+ def forward(self, pred_prob, pred_rad, prob_gt, rad_gt):
70
+ bce = nn.BCELoss()  # plain (non-focal) binary cross-entropy
71
+ L_prob = bce(pred_prob, prob_gt)
72
+
73
+ pos_mask = prob_gt > 0.5
74
+ L_rad = (
75
+ nn.functional.smooth_l1_loss(pred_rad[pos_mask], rad_gt[pos_mask])
76
+ if pos_mask.any()
77
+ else pred_rad.sum() * 0
78
+ )
79
+
80
+ return L_prob + 10.0 * L_rad, L_prob, L_rad
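And a matching sketch for unet_3maps_loss; the sizes are arbitrary, and the predictions are passed through a sigmoid because BCELoss expects probabilities.

import torch

criterion = unet_3maps_loss()
raw_prob = torch.randn(2, 1, 64, 64, requires_grad=True)
raw_rad = torch.randn(2, 1, 64, 64, requires_grad=True)
pred_prob = torch.sigmoid(raw_prob)                 # predicted light-probability map
pred_rad = torch.sigmoid(raw_rad)                   # predicted (normalised) radius map
gt_prob = (torch.rand(2, 1, 64, 64) > 0.9).float()  # dummy ground-truth probability map
gt_rad = torch.rand(2, 1, 64, 64) * gt_prob         # radii only where a light exists
total, l_prob, l_rad = criterion(pred_prob, pred_rad, gt_prob, gt_rad)
total.backward()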
utils/utils.py ADDED
@@ -0,0 +1,311 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import skimage
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as transforms
8
+ import torchvision.transforms.functional as TF
9
+ from PIL import Image
10
+ from skimage.draw import disk
11
+ from skimage import morphology
12
+ from collections import OrderedDict
13
+
14
+
15
+ def load_mfdnet_checkpoint(model, weights):
16
+ checkpoint = torch.load(weights, map_location=lambda storage, loc: storage.cuda(0))
17
+ new_state_dict = OrderedDict()
18
+ for key, value in checkpoint["state_dict"].items():
19
+ if key.startswith("module"):
20
+ name = key[7:]
21
+ else:
22
+ name = key
23
+ new_state_dict[name] = value
24
+ model.load_state_dict(new_state_dict)
25
+
26
+
27
+ def adjust_gamma(image: torch.Tensor, gamma):
28
+ # image is in shape of [B,C,H,W] and gamma is in shape [B]
29
+ gamma = gamma.float().cuda()
30
+ gamma_tensor = torch.ones_like(image)
31
+ gamma_tensor = gamma.view(-1, 1, 1, 1) * gamma_tensor
32
+ image = torch.pow(image, gamma_tensor)
33
+ out = torch.clamp(image, 0.0, 1.0)
34
+ return out
35
+
36
+
37
+ def adjust_gamma_reverse(image: torch.Tensor, gamma):
38
+ # gamma=torch.Tensor([gamma]).cuda()
39
+ gamma = 1 / gamma.float().cuda()
40
+ gamma_tensor = torch.ones_like(image)
41
+ gamma_tensor = gamma.view(-1, 1, 1, 1) * gamma_tensor
42
+ image = torch.pow(image, gamma_tensor)
43
+ out = torch.clamp(image, 0.0, 1.0)
44
+ return out
45
+
46
+
47
+ def predict_flare_from_6_channel(input_tensor, gamma):
48
+ # the input is a tensor in [B,C,H,W], the C here is 6
49
+
50
+ deflare_img = input_tensor[:, :3, :, :]
51
+ flare_img_predicted = input_tensor[:, 3:, :, :]
52
+
53
+ merge_img_predicted_linear = adjust_gamma(deflare_img, gamma) + adjust_gamma(
54
+ flare_img_predicted, gamma
55
+ )
56
+ merge_img_predicted = adjust_gamma_reverse(
57
+ torch.clamp(merge_img_predicted_linear, 1e-7, 1.0), gamma
58
+ )
59
+ return deflare_img, flare_img_predicted, merge_img_predicted
60
+
61
+
62
+ def predict_flare_from_3_channel(
63
+ input_tensor, flare_mask, base_img, flare_img, merge_img, gamma
64
+ ):
65
+ # the input is a tensor in [B,C,H,W], the C here is 3
66
+
67
+ input_tensor_linear = adjust_gamma(input_tensor, gamma)
68
+ merge_tensor_linear = adjust_gamma(merge_img, gamma)
69
+ flare_img_predicted = adjust_gamma_reverse(
70
+ torch.clamp(merge_tensor_linear - input_tensor_linear, 1e-7, 1.0), gamma
71
+ )
72
+
73
+ masked_deflare_img = input_tensor * (1 - flare_mask) + base_img * flare_mask
74
+ masked_flare_img_predicted = (
75
+ flare_img_predicted * (1 - flare_mask) + flare_img * flare_mask
76
+ )
77
+
78
+ return masked_deflare_img, masked_flare_img_predicted
79
+
80
+
81
+ def get_highlight_mask(image, threshold=0.99, luminance_mode=False):
82
+ """Get the area close to the exposure
83
+ Args:
84
+ image: the image tensor in [B,C,H,W]. For inference, B is set as 1.
85
+ threshold: the threshold of luminance/greyscale of exposure region
86
+ luminance_mode: use luminance or greyscale
87
+ Return:
88
+ Binary image in [B,H,W]
89
+ """
90
+ if luminance_mode:
91
+ # 3 channels in RGB
92
+ luminance = (
93
+ 0.2126 * image[:, 0, :, :]
94
+ + 0.7152 * image[:, 1, :, :]
95
+ + 0.0722 * image[:, 2, :, :]
96
+ )
97
+ binary_mask = luminance > threshold
98
+ else:
99
+ binary_mask = image.mean(dim=1, keepdim=True) > threshold
100
+ binary_mask = binary_mask.to(image.dtype)
101
+ return binary_mask
102
+
103
+
104
+ def refine_mask(mask, morph_size=0.01):
105
+ """Refines a mask by applying mophological operations.
106
+ Args:
107
+ mask: A float array of shape [H, W]
108
+ morph_size: Size of the morphological kernel relative to the long side of
109
+ the image.
110
+
111
+ Returns:
112
+ Refined mask of shape [H, W].
113
+ """
114
+ mask_size = max(np.shape(mask))
115
+ kernel_radius = 0.5 * morph_size * mask_size
116
+ kernel = morphology.disk(np.ceil(kernel_radius))
117
+ opened = morphology.binary_opening(mask, kernel)
118
+ return opened
119
+
120
+
121
+ def _create_disk_kernel(kernel_size):
122
+ _EPS = 1e-7
123
+ x = np.arange(kernel_size) - (kernel_size - 1) / 2
124
+ xx, yy = np.meshgrid(x, x)
125
+ rr = np.sqrt(xx**2 + yy**2)
126
+ kernel = np.float32(rr <= np.max(x)) + _EPS
127
+ kernel = kernel / np.sum(kernel)
128
+ return kernel
129
+
130
+
131
+ def blend_light_source(input_scene, pred_scene, threshold=0.99, luminance_mode=False):
132
+ binary_mask = (
133
+ get_highlight_mask(
134
+ input_scene, threshold=threshold, luminance_mode=luminance_mode
135
+ )
136
+ > 0.5
137
+ ).to("cpu", torch.bool)
138
+ binary_mask = binary_mask.squeeze() # (h, w)
139
+ binary_mask = binary_mask.numpy()
140
+ binary_mask = refine_mask(binary_mask)
141
+
142
+ labeled = skimage.measure.label(binary_mask)
143
+ properties = skimage.measure.regionprops(labeled)
144
+ max_diameter = 0
145
+ for p in properties:
146
+ # The diameter of a circle with the same area as the region.
147
+ max_diameter = max(max_diameter, p["equivalent_diameter"])
148
+
149
+ mask = np.float32(binary_mask)
150
+ kernel_size = round(1.5 * max_diameter) # default is 1.5
151
+ if kernel_size > 0:
152
+ kernel = _create_disk_kernel(kernel_size)
153
+ mask = cv2.filter2D(mask, -1, kernel)
154
+ mask = np.clip(mask * 3.0, 0.0, 1.0)
155
+ mask_rgb = np.stack([mask] * 3, axis=0)
156
+
157
+ mask_rgb = torch.from_numpy(mask_rgb).to(input_scene.device, torch.float32)
158
+ blend = input_scene * mask_rgb + pred_scene * (1 - mask_rgb)
159
+ else:
160
+ blend = pred_scene
161
+ return blend
162
+
163
+
164
+ def blend_with_alpha(result, input_img, box, blur_size=31):
165
+ """
166
+ Apply alpha blending to paste the specified box region from input_img onto the result image
167
+ to reduce boundary artifacts and make the blending more natural.
168
+
169
+ Args:
170
+ result (np.array): inpainting generated image
171
+ input_img (np.array): original image
172
+ box (tuple): (x_min, x_max, y_min, y_max) representing the paste-back region from the original image
173
+ blur_size (int): blur range for the mask, larger values create smoother transitions (recommended 15~50)
174
+
175
+ Returns:
176
+ np.array: image after alpha blending
177
+ """
178
+
179
+ x_min, x_max, y_min, y_max = box
180
+
181
+ # alpha mask
182
+ mask = np.zeros_like(result, dtype=np.float32)
183
+ mask[y_min : y_max + 1, x_min : x_max + 1] = 1.0
184
+
185
+ # gaussian blur
186
+ mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
187
+
188
+ # alpha blending
189
+ blended = (mask * input_img + (1 - mask) * result).astype(np.uint8)
190
+
191
+ return blended
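A short usage sketch: paste a 512x512 original back over an outpainted result with a soft seam. The arrays are synthetic placeholders, and the helper is assumed to be importable from this module.

import numpy as np

result = np.full((512, 512, 3), 128, dtype=np.uint8)   # stand-in for the outpainted image
original = np.zeros((512, 512, 3), dtype=np.uint8)     # stand-in for the source image
box = (64, 447, 64, 447)                               # (x_min, x_max, y_min, y_max) of the kept region
blended = blend_with_alpha(result, original, box, blur_size=31)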
192
+
193
+
194
+ def IoU(pred, target):
195
+ assert pred.shape == target.shape, "Prediction and target must have the same shape."
196
+
197
+ intersection = np.logical_and(pred, target).sum()
198
+ union = np.logical_or(pred, target).sum()
199
+
200
+ if union == 0:
201
+ return 1.0 if intersection == 0 else 0.0
202
+
203
+ return intersection / union
204
+
205
+
206
+ def mean_IoU(y_true, y_pred, num_classes):
207
+ """
208
+ Calculate the mean Intersection over Union (mIoU) score.
209
+
210
+ Args:
211
+ y_true (np.ndarray): Ground truth labels (integer class values).
212
+ y_pred (np.ndarray): Predicted labels (integer class values).
213
+ num_classes (int): Number of classes.
214
+
215
+ Returns:
216
+ float: The mean IoU score across all classes.
217
+ """
218
+ iou_scores = []
219
+
220
+ for cls in range(num_classes):
221
+ # Create binary masks for the current class
222
+ true_mask = y_true == cls
223
+ pred_mask = y_pred == cls
224
+
225
+ # Calculate intersection and union
226
+ intersection = np.logical_and(true_mask, pred_mask)
227
+ union = np.logical_or(true_mask, pred_mask)
228
+
229
+ # Compute IoU for the current class
230
+ if np.sum(union) == 0:
231
+ # Handle edge case: no samples for this class
232
+ iou_scores.append(np.nan)
233
+ else:
234
+ iou_scores.append(np.sum(intersection) / np.sum(union))
235
+
236
+ # Calculate mean IoU, ignoring NaN values (classes without samples)
237
+ mean_iou = np.nanmean(iou_scores)
238
+ return mean_iou
239
+
240
+
241
+ def RGB2YCbCr(img):
242
+ img = img * 255.0
243
+ r, g, b = torch.split(img, 1, dim=0)
244
+ y = torch.zeros_like(r)
245
+ cb = torch.zeros_like(r)
246
+ cr = torch.zeros_like(r)
247
+
248
+ y = 0.257 * r + 0.504 * g + 0.098 * b + 16
249
+ y = y / 255.0
250
+
251
+ cb = -0.148 * r - 0.291 * g + 0.439 * b + 128
252
+ cb = cb / 255.0
253
+
254
+ cr = 0.439 * r - 0.368 * g - 0.071 * b + 128
255
+ cr = cr / 255.0
256
+
257
+ # NOTE: only the Y (luma) channel is kept; it is replicated across the 3 output channels
+ img = torch.cat([y, y, y], dim=0)
258
+ return img
259
+
260
+
261
+ def extract_peaks(prob_map, thr=0.5, pool=7):
262
+ """
263
+ prob_map: (H, W) after sigmoid
264
+ return: tensor of peak coordinates [K, 2] (x, y)
265
+ """
266
+ # binary mask
267
+ pos = prob_map > thr
268
+
269
+ # non‑maximum suppression
270
+ nms = F.max_pool2d(
271
+ prob_map.unsqueeze(0).unsqueeze(0),
272
+ kernel_size=pool,
273
+ stride=1,
274
+ padding=pool // 2,
275
+ )
276
+ peaks = (prob_map == nms.squeeze()) & pos
277
+ ys, xs = torch.nonzero(peaks, as_tuple=True)
278
+ return torch.stack([xs, ys], dim=1) # (K, 2)
279
+
280
+
281
+ def pick_radius(radius_map, centers, ksize=3):
282
+ """
283
+ radius_map: (H, W) ∈ [0, 1]
284
+ centers: (K, 2) x,y
285
+ return: (K,) radii in pixel
286
+ """
287
+ # H, W = radius_map.shape
288
+ pad = ksize // 2
289
+ padded = F.pad(
290
+ radius_map.unsqueeze(0).unsqueeze(0), (pad, pad, pad, pad), mode="reflect"
291
+ )
292
+
293
+ radii = []
294
+ for x, y in centers:
295
+ patch = padded[..., y : y + ksize, x : x + ksize]
296
+ radii.append(patch.mean()) # 3×3 mean
297
+ return torch.stack(radii)
298
+
299
+
300
+ def draw_mask(centers, radii, H, W):
301
+ """
302
+ centers: (K, 2) (x, y)
303
+ radii: (K,)
304
+ return: (H, W) uint8 mask
305
+ """
306
+ radii *= 256
307
+ mask = np.zeros((H, W), dtype=np.float32)
308
+ for (x, y), r in zip(centers, radii):
309
+ rr, cc = disk((y.item(), x.item()), r.item(), shape=mask.shape)
310
+ mask[rr, cc] = 1
311
+ return mask
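A hedged end-to-end sketch of the three helpers above: threshold a probability map into peak coordinates, read per-peak radii, and rasterise them into a mask. The maps are synthetic, and note that draw_mask scales normalised radii by a fixed 256 and mutates its radii argument in place.

import torch

H = W = 256
prob_map = torch.zeros(H, W)
prob_map[60, 100] = 0.9                       # a single confident peak at (x=100, y=60)
radius_map = torch.full((H, W), 12.0 / 256)   # normalised radius everywhere, for simplicity

centers = extract_peaks(prob_map, thr=0.5, pool=7)  # -> tensor([[100, 60]])
radii = pick_radius(radius_map, centers, ksize=3)   # -> roughly 12/256 for the single peak
mask = draw_mask(centers, radii.clone(), H, W)      # .clone() keeps radii from being scaled in place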
weights/light_outpaint_lora/pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04aeb7148ae4d8c59f0d0260ee813c2fe41a8392d826c4941dfda9ed7cf7090d
3
+ size 3358448
weights/light_regress/model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c4e2ac2d23180814361ec04bcb22cc92adb761fb5ccc761b5c3874a297fed18
3
+ size 85314151
weights/net_g_last.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75f0fc77ab43703c7a9c7876621f8a651d6ce3a0cfb7c6e2377b3c8e2331b0e2
3
+ size 82605273