Commit a856109 · Ray-1026 committed · 1 Parent(s): ef36a49

update
Browse files
- .gitattributes +5 -0
- SIFR_models/flare7kpp/__pycache__/model.cpython-39.pyc +0 -0
- SIFR_models/flare7kpp/model.py +0 -0
- SIFR_models/mfdnet/backbone.py +285 -0
- SIFR_models/mfdnet/blocks.py +164 -0
- SIFR_models/mfdnet/model.py +786 -0
- app.py +195 -45
- requirements.txt +15 -2
- src/models/__pycache__/light_source_regressor.cpython-39.pyc +0 -0
- src/models/__pycache__/unet.cpython-39.pyc +0 -0
- src/models/light_source_regressor.py +124 -0
- src/models/unet.py +129 -0
- src/pipelines/__pycache__/pipeline_controlnet_outpaint.cpython-39.pyc +0 -0
- src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc +0 -0
- src/pipelines/pipeline_controlnet_outpaint.py +448 -0
- src/pipelines/pipeline_stable_diffusion_outpaint.py +517 -0
- src/schedulers/__pycache__/scheduling_pndm.cpython-39.pyc +0 -0
- src/schedulers/scheduling_pndm.py +126 -0
- utils/__pycache__/dataset.cpython-39.pyc +0 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/dataset.py +1304 -0
- utils/loss.py +80 -0
- utils/utils.py +311 -0
- weights/light_outpaint_lora/pytorch_lora_weights.safetensors +3 -0
- weights/light_regress/model.pth +3 -0
- weights/net_g_last.pth +3 -0
.gitattributes
CHANGED
@@ -35,3 +35,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 assets/* filter=lfs diff=lfs merge=lfs -text
 assets/exp.png filter=lfs diff=lfs merge=lfs -text
+weights/light_outpaint_lora filter=lfs diff=lfs merge=lfs -text
+weights/light_regress filter=lfs diff=lfs merge=lfs -text
+weights/net_g_last.pth filter=lfs diff=lfs merge=lfs -text
+weights/light_outpaint_lora/pytorch_lora_weights.safetensors filter=lfs diff=lfs merge=lfs -text
+weights/light_regress/model.pth filter=lfs diff=lfs merge=lfs -text
SIFR_models/flare7kpp/__pycache__/model.cpython-39.pyc
ADDED
Binary file (52.1 kB).
SIFR_models/flare7kpp/model.py
ADDED
The diff for this file is too large to render.
SIFR_models/mfdnet/backbone.py
ADDED
@@ -0,0 +1,285 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return (
            gx,
            (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
            grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
            None,
        )


class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)


class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(
            in_channels=c,
            out_channels=dw_channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )
        self.conv2 = nn.Conv2d(
            in_channels=dw_channel,
            out_channels=dw_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=dw_channel,
            bias=True,
        )
        self.conv3 = nn.Conv2d(
            in_channels=dw_channel // 2,
            out_channels=c,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(
                in_channels=dw_channel // 2,
                out_channels=dw_channel // 2,
                kernel_size=1,
                padding=0,
                stride=1,
                groups=1,
                bias=True,
            ),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(
            in_channels=c,
            out_channels=ffn_channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )
        self.conv5 = nn.Conv2d(
            in_channels=ffn_channel // 2,
            out_channels=c,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = (
            nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
        )
        self.dropout2 = (
            nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
        )

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class NAFNet(nn.Module):

    def __init__(
        self,
        img_channel=3,
        width=32,
        middle_blk_num=12,
        enc_blk_nums=[2, 2, 4, 8],
        dec_blk_nums=[2, 2, 2, 2],
    ):
        super().__init__()

        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True,
        )
        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=img_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True,
        )

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2

        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)]
        )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


if __name__ == "__main__":
    img_channel = 3
    width = 32

    enc_blks = [2, 2, 4, 8]
    middle_blk_num = 12
    dec_blks = [2, 2, 2, 2]

    print(
        "enc blks",
        enc_blks,
        "middle blk num",
        middle_blk_num,
        "dec blks",
        dec_blks,
        "width",
        width,
    )

    # using('start . ')
    model = NAFNet(
        img_channel=img_channel,
        width=width,
        middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks,
        dec_blk_nums=dec_blks,
    ).cuda()

    model.eval()
    input = torch.randn(1, 3, 15, 22).cuda()
    # input = torch.randn(1, 3, 32, 32)
    y = model(input)
    print(y.size())
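For quick verification, a minimal CPU sketch of the NAFNet defined above (the in-file __main__ block assumes a CUDA device); it assumes the repository root is on PYTHONPATH so that SIFR_models.mfdnet is importable.

import torch

from SIFR_models.mfdnet.backbone import NAFNet

# Default configuration from the file above; eval mode, CPU, no gradients.
model = NAFNet(
    img_channel=3,
    width=32,
    middle_blk_num=12,
    enc_blk_nums=[2, 2, 4, 8],
    dec_blk_nums=[2, 2, 2, 2],
).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 15, 22)  # odd sizes are padded internally to a multiple of 2**4
    y = model(x)

print(y.shape)  # torch.Size([1, 3, 15, 22]) -- the output is cropped back to the input size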
SIFR_models/mfdnet/blocks.py
ADDED
@@ -0,0 +1,164 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation=1,
        bias=True,
        groups=1,
        norm="in",
        nonlinear="relu",
    ):
        super(ConvLayer, self).__init__()
        reflection_padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=groups,
            bias=bias,
            dilation=dilation,
        )
        self.norm = norm
        self.nonlinear = nonlinear

        if norm == "bn":
            self.normalization = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.normalization = nn.InstanceNorm2d(out_channels, affine=False)
        else:
            self.normalization = None

        if nonlinear == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif nonlinear == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2)
        elif nonlinear == "PReLU":
            self.activation = nn.PReLU()
        else:
            self.activation = None

    def forward(self, x):
        out = self.conv2d(self.reflection_pad(x))
        if self.normalization is not None:
            out = self.normalization(out)
        if self.activation is not None:
            out = self.activation(out)

        return out


class Aggreation(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(Aggreation, self).__init__()
        self.attention = SelfAttention(in_channels, k=8, nonlinear="relu")
        self.conv = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=1,
            nonlinear="leakyrelu",
            norm=None,
        )

    def forward(self, x):
        return self.conv(self.attention(x))


class SelfAttention(nn.Module):
    def __init__(self, channels, k, nonlinear="relu"):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.k = k
        self.nonlinear = nonlinear

        self.linear1 = nn.Linear(channels, channels // k)
        self.linear2 = nn.Linear(channels // k, channels)
        self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))

        if nonlinear == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif nonlinear == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2)
        elif nonlinear == "PReLU":
            self.activation = nn.PReLU()
        else:
            raise ValueError

    def attention(self, x):
        N, C, H, W = x.size()
        out = torch.flatten(self.global_pooling(x), 1)
        out = self.activation(self.linear1(out))
        out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1)

        return out.mul(x)

    def forward(self, x):
        return self.attention(x)


class SPP(nn.Module):
    def __init__(
        self, in_channels, out_channels, num_layers=4, interpolation_type="bilinear"
    ):
        super(SPP, self).__init__()
        self.conv = nn.ModuleList()
        self.num_layers = num_layers
        self.interpolation_type = interpolation_type

        for _ in range(self.num_layers):
            self.conv.append(
                ConvLayer(
                    in_channels,
                    in_channels,
                    kernel_size=1,
                    stride=1,
                    dilation=1,
                    nonlinear="leakyrelu",
                    norm=None,
                )
            )

        self.fusion = ConvLayer(
            (in_channels * (self.num_layers + 1)),
            out_channels,
            kernel_size=3,
            stride=1,
            norm="False",
            nonlinear="leakyrelu",
        )

    def forward(self, x):

        N, C, H, W = x.size()
        out = []

        for level in range(self.num_layers):
            out.append(
                F.interpolate(
                    self.conv[level](
                        F.avg_pool2d(
                            x,
                            kernel_size=2 * 2 ** (level + 1),
                            stride=2 * 2 ** (level + 1),
                            padding=2 * 2 ** (level + 1) % 2,
                        )
                    ),
                    size=(H, W),
                    mode=self.interpolation_type,
                )
            )

        out.append(x)

        return self.fusion(torch.cat(out, dim=1))
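A small sketch of how the blocks above compose, assuming the module path added in this commit and the repository root on PYTHONPATH. SPP pools the input at four scales and fuses everything back at the original resolution, so spatial size is preserved; Aggreation is channel self-attention followed by a 3x3 ConvLayer (the 9-to-3 channel setting matches how TransHigh in model.py uses it).

import torch

from SIFR_models.mfdnet.blocks import SPP, Aggreation

x = torch.randn(1, 3, 128, 128)

# Four-level spatial pyramid pooling, fused back to 3 channels at 128x128.
spp = SPP(in_channels=3, out_channels=3, num_layers=4, interpolation_type="bicubic")
print(spp(x).shape)  # torch.Size([1, 3, 128, 128])

# Channel self-attention + 3x3 ConvLayer, 9 -> 3 channels.
agg = Aggreation(in_channels=9, out_channels=3)
print(agg(torch.randn(1, 9, 128, 128)).shape)  # torch.Size([1, 3, 128, 128])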
SIFR_models/mfdnet/model.py
ADDED
@@ -0,0 +1,786 @@
import numbers

import einops
from einops import rearrange

from .backbone import *
from .blocks import *


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_features, in_features, 3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)


def gauss_kernel(channels=3):
    kernel = torch.tensor(
        [
            [1.0, 4.0, 6.0, 4.0, 1],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [6.0, 24.0, 36.0, 24.0, 6.0],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [1.0, 4.0, 6.0, 4.0, 1.0],
        ]
    )
    kernel /= 256.0
    kernel = kernel.repeat(channels, 1, 1, 1)
    return kernel


class LapPyramidConv(nn.Module):
    def __init__(self, num_high=4):
        super(LapPyramidConv, self).__init__()

        self.num_high = num_high
        self.kernel = gauss_kernel()

    def downsample(self, x):
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        cc = torch.cat(
            [
                x,
                torch.zeros(
                    x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device
                ),
            ],
            dim=3,
        )
        cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
        cc = cc.permute(0, 1, 3, 2)
        cc = torch.cat(
            [
                cc,
                torch.zeros(
                    x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device
                ),
            ],
            dim=3,
        )
        cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
        x_up = cc.permute(0, 1, 3, 2)
        return self.conv_gauss(x_up, 4 * self.kernel)

    def conv_gauss(self, img, kernel):
        # pad the last two dimensions (left, right, top, bottom)
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect")
        # grouped (per-channel) convolution
        out = torch.nn.functional.conv2d(
            img, kernel.to(img.device), groups=img.shape[1]
        )
        return out

    def pyramid_decom(self, img):
        current = img
        pyr = []
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, self.kernel)
            down = self.downsample(filtered)
            up = self.upsample(down)
            if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
                up = nn.functional.interpolate(
                    up, size=(current.shape[2], current.shape[3])
                )
            diff = current - up
            pyr.append(diff)
            current = down
        pyr.append(current)
        return pyr

    def pyramid_recons(self, pyr):
        image = pyr[-1]
        for level in reversed(pyr[:-1]):
            up = self.upsample(image)
            if up.shape[2] != level.shape[2] or up.shape[3] != level.shape[3]:
                up = nn.functional.interpolate(
                    up, size=(level.shape[2], level.shape[3])
                )
            image = up + level
        return image


class TransHigh(nn.Module):
    def __init__(self, num_residual_blocks, num_high=3):
        super(TransHigh, self).__init__()

        self.num_high = num_high

        blocks = [nn.Conv2d(9, 64, 3, padding=1), nn.LeakyReLU()]

        for _ in range(num_residual_blocks):
            blocks += [ResidualBlock(64)]

        blocks += [nn.Conv2d(64, 3, 3, padding=1)]

        self.model = nn.Sequential(*blocks)

        channels = 3
        # Stage1
        self.block1_1 = ConvLayer(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            dilation=2,
            norm=None,
            nonlinear="leakyrelu",
        )
        self.block1_2 = ConvLayer(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            dilation=4,
            norm=None,
            nonlinear="leakyrelu",
        )
        self.aggreation1_rgb = Aggreation(
            in_channels=channels * 3, out_channels=channels
        )
        # Stage2
        self.block2_1 = ConvLayer(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            dilation=8,
            norm=None,
            nonlinear="leakyrelu",
        )
        self.block2_2 = ConvLayer(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            dilation=16,
            norm=None,
            nonlinear="leakyrelu",
        )
        self.aggreation2_rgb = Aggreation(
            in_channels=channels * 3, out_channels=channels
        )
        # Stage3
        self.block3_1 = ConvLayer(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            dilation=32,
            norm=None,
            nonlinear="leakyrelu",
        )
        self.block3_2 = ConvLayer(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            dilation=64,
            norm=None,
            nonlinear="leakyrelu",
        )
        self.aggreation3_rgb = Aggreation(
            in_channels=channels * 3, out_channels=channels
        )
        # self.block_3 = NAFNet(middle_blk_num=2, enc_blk_nums=[
        #     1,1], dec_blk_nums=[1,1])
        self.trans_mask_block_1 = nn.Sequential(
            nn.Conv2d(3, 16, 1), nn.LeakyReLU(), nn.Conv2d(16, 3, 1)
        )
        self.trans_mask_block_2 = nn.Sequential(
            nn.Conv2d(3, 16, 1), nn.LeakyReLU(), nn.Conv2d(16, 3, 1)
        )

        # self.trans_mask_block = NAFNet(
        #     middle_blk_num=1, enc_blk_nums=[1], dec_blk_nums=[1])
        # Stage3
        self.spp_img = SPP(
            in_channels=channels,
            out_channels=channels,
            num_layers=4,
            interpolation_type="bicubic",
        )
        self.block4_1 = nn.Conv2d(
            in_channels=channels, out_channels=3, kernel_size=1, stride=1
        )

    def forward(self, x, pyr_original, fake_low):
        pyr_result = [fake_low]
        mask = self.model(x)

        mask = nn.functional.interpolate(
            mask, size=(pyr_original[-2].shape[2], pyr_original[-2].shape[3])
        )
        mask = self.trans_mask_block_1(mask)
        result_highfreq = torch.mul(pyr_original[-2], mask) + pyr_original[-2]

        # result_highfreq = self.block_3(result_highfreq)
        out1_1 = self.block1_1(result_highfreq)
        out1_2 = self.block1_2(out1_1)
        agg1_rgb = self.aggreation1_rgb(
            torch.cat((result_highfreq, out1_1, out1_2), dim=1)
        )
        pyr_result.append(agg1_rgb)

        mask = nn.functional.interpolate(
            mask, size=(pyr_original[-3].shape[2], pyr_original[-3].shape[3])
        )
        mask = self.trans_mask_block_2(mask)
        result_highfreq = torch.mul(pyr_original[-3], mask) + pyr_original[-3]

        # result_highfreq = self.block_3(result_highfreq)
        out2_1 = self.block2_1(result_highfreq)
        out2_2 = self.block2_2(out2_1)
        agg2_rgb = self.aggreation2_rgb(
            torch.cat((result_highfreq, out2_1, out2_2), dim=1)
        )

        out3_1 = self.block3_1(agg2_rgb)
        out3_2 = self.block3_2(out3_1)
        agg3_rgb = self.aggreation3_rgb(torch.cat((agg2_rgb, out3_1, out3_2), dim=1))

        spp_rgb = self.spp_img(agg3_rgb)
        out_rgb = self.block4_1(spp_rgb)

        pyr_result.append(out_rgb)
        pyr_result.reverse()

        return pyr_result


# Layer Norm


def to_3d(x):
    return rearrange(x, "b c h w -> b (h w) c")


def to_4d(x, h, w):
    return rearrange(x, "b (h w) c -> b c h w", h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == "BiasFree":
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


# Axis-based Multi-head Self-Attention


class NextAttentionImplZ(nn.Module):
    def __init__(self, num_dims, num_heads, bias) -> None:
        super().__init__()
        self.num_dims = num_dims
        self.num_heads = num_heads
        self.q1 = nn.Conv2d(num_dims, num_dims * 3, kernel_size=1, bias=bias)
        self.q2 = nn.Conv2d(
            num_dims * 3,
            num_dims * 3,
            kernel_size=3,
            padding=1,
            groups=num_dims * 3,
            bias=bias,
        )
        self.q3 = nn.Conv2d(
            num_dims * 3,
            num_dims * 3,
            kernel_size=3,
            padding=1,
            groups=num_dims * 3,
            bias=bias,
        )

        self.fac = nn.Parameter(torch.ones(1))
        self.fin = nn.Conv2d(num_dims, num_dims, kernel_size=1, bias=bias)
        return

    def forward(self, x):
        # x: [n, c, h, w]
        n, c, h, w = x.size()
        n_heads, dim_head = self.num_heads, c // self.num_heads

        def reshape(x):
            return einops.rearrange(
                x, "n (nh dh) h w -> (n nh h) w dh", nh=n_heads, dh=dim_head
            )

        qkv = self.q3(self.q2(self.q1(x)))
        q, k, v = map(reshape, qkv.chunk(3, dim=1))
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # fac = dim_head ** -0.5
        res = k.transpose(-2, -1)
        res = torch.matmul(q, res) * self.fac
        res = torch.softmax(res, dim=-1)

        res = torch.matmul(res, v)
        res = einops.rearrange(
            res, "(n nh h) w dh -> n (nh dh) h w", nh=n_heads, dh=dim_head, n=n, h=h
        )
        res = self.fin(res)

        return res


# Axis-based Multi-head Self-Attention (row and col attention)
class NextAttentionZ(nn.Module):
    def __init__(self, num_dims, num_heads=1, bias=True) -> None:
        super().__init__()
        assert num_dims % num_heads == 0
        self.num_dims = num_dims
        self.num_heads = num_heads
        self.row_att = NextAttentionImplZ(num_dims, num_heads, bias)
        self.col_att = NextAttentionImplZ(num_dims, num_heads, bias)
        return

    def forward(self, x: torch.Tensor):
        assert len(x.size()) == 4

        x = self.row_att(x)
        x = x.transpose(-2, -1)
        x = self.col_att(x)
        x = x.transpose(-2, -1)

        return x


# Dual Gated Feed-Forward Network
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x2) * x1 + F.gelu(x1) * x2
        x = self.project_out(x)
        return x


# Axis-based Transformer Block
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=1,
        ffn_expansion_factor=2.66,
        bias=True,
        LayerNorm_type="WithBias",
    ):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = NextAttentionZ(dim, num_heads)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


##########################################################################
# Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(
            in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias
        )

    def forward(self, x):
        x = self.proj(x)

        return x


##########################################################################
# Resizing modules
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(
                n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.PixelUnshuffle(2),
        )

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(
                n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.PixelShuffle(2),
        )

    def forward(self, x):
        return self.body(x)


# Cross-layer Attention Fusion Block
class LAM_Module_v2(nn.Module):
    """Layer attention module"""

    def __init__(self, in_dim, bias=True):
        super(LAM_Module_v2, self).__init__()
        self.chanel_in = in_dim

        self.temperature = nn.Parameter(torch.ones(1))

        self.qkv = nn.Conv2d(
            self.chanel_in, self.chanel_in * 3, kernel_size=1, bias=bias
        )
        self.qkv_dwconv = nn.Conv2d(
            self.chanel_in * 3,
            self.chanel_in * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=self.chanel_in * 3,
            bias=bias,
        )
        self.project_out = nn.Conv2d(
            self.chanel_in, self.chanel_in, kernel_size=1, bias=bias
        )

    def forward(self, x):
        """
        inputs :
            x : input feature maps( B X N X C X H X W)
        returns :
            out : attention value + input feature
            attention: B X N X N
        """
        m_batchsize, N, C, height, width = x.size()

        x_input = x.view(m_batchsize, N * C, height, width)
        qkv = self.qkv_dwconv(self.qkv(x_input))
        q, k, v = qkv.chunk(3, dim=1)
        q = q.view(m_batchsize, N, -1)
        k = k.view(m_batchsize, N, -1)
        v = v.view(m_batchsize, N, -1)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out_1 = attn @ v
        out_1 = out_1.view(m_batchsize, -1, height, width)

        out_1 = self.project_out(out_1)
        out_1 = out_1.view(m_batchsize, N, C, height, width)

        out = out_1 + x
        out = out.view(m_batchsize, -1, height, width)
        return out


##########################################################################
# ---------- LLFormer -----------------------
class Backbone(nn.Module):
    def __init__(
        self,
        inp_channels=3,
        out_channels=3,
        dim=3,
        num_blocks=[1, 2, 4, 8],
        num_refinement_blocks=1,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        attention=True,
    ):
        super(Backbone, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_1 = nn.Sequential(
            *[
                TransformerBlock(
                    dim=dim,
                    num_heads=heads[0],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(num_blocks[0])
            ]
        )

        self.encoder_2 = nn.Sequential(
            *[
                TransformerBlock(
                    dim=int(dim),
                    num_heads=heads[0],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(num_blocks[0])
            ]
        )

        self.encoder_3 = nn.Sequential(
            *[
                TransformerBlock(
                    dim=int(dim),
                    num_heads=heads[0],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(num_blocks[0])
            ]
        )

        self.layer_fussion = LAM_Module_v2(in_dim=int(dim * 3))
        self.conv_fuss = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias)

        # self.latent = nn.Sequential(*[
        #     TransformerBlock(dim=int(dim), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
        #                      LayerNorm_type=LayerNorm_type) for _ in range(num_blocks[0])])

        # self.trans_low = NAFNet()

        # self.coefficient_1_0 = nn.Parameter(torch.ones(
        #     (2, int(int(dim)))), requires_grad=attention)

        self.latent_1 = nn.Sequential(
            *[
                TransformerBlock(
                    dim=int(dim),
                    num_heads=heads[0],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(num_blocks[0])
            ]
        )
        """
        self.latent_2 = nn.Sequential(*[
            TransformerBlock(dim=int(dim), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
                             LayerNorm_type=LayerNorm_type) for _ in range(num_blocks[0])])
        """
        self.trans_low_1 = NAFNet(
            middle_blk_num=10, enc_blk_nums=[1, 2, 4], dec_blk_nums=[4, 2, 1]
        )
        # self.trans_low_2 = NAFNet()

        self.coefficient_1_0 = nn.Parameter(
            torch.ones((2, int(int(dim)))), requires_grad=attention
        )

        # self.coefficient_2_0 = nn.Parameter(torch.ones(
        #     (2, int(int(dim)))), requires_grad=attention)

        self.refinement_1 = nn.Sequential(
            *[
                TransformerBlock(
                    dim=int(dim),
                    num_heads=heads[0],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(num_refinement_blocks)
            ]
        )
        self.refinement_2 = nn.Sequential(
            *[
                TransformerBlock(
                    dim=int(dim),
                    num_heads=heads[0],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(num_refinement_blocks)
            ]
        )
        self.refinement_3 = nn.Sequential(
            *[
                TransformerBlock(
                    dim=int(dim),
                    num_heads=heads[0],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                    LayerNorm_type=LayerNorm_type,
                )
                for _ in range(num_refinement_blocks)
            ]
        )

        self.layer_fussion_2 = LAM_Module_v2(in_dim=int(dim * 3))
        self.conv_fuss_2 = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias)

        self.output = nn.Conv2d(
            int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )

    def forward(self, inp):
        inp_enc_encoder1 = self.patch_embed(inp)
        out_enc_encoder1 = self.encoder_1(inp_enc_encoder1)
        out_enc_encoder2 = self.encoder_2(out_enc_encoder1)
        out_enc_encoder3 = self.encoder_3(out_enc_encoder2)

        inp_fusion_123 = torch.cat(
            [
                out_enc_encoder1.unsqueeze(1),
                out_enc_encoder2.unsqueeze(1),
                out_enc_encoder3.unsqueeze(1),
            ],
            dim=1,
        )

        out_fusion_123 = self.layer_fussion(inp_fusion_123)
        out_fusion_123 = self.conv_fuss(out_fusion_123)

        # out_enc = self.trans_low(out_fusion_123)

        # out_fusion_123 = self.latent(out_fusion_123)

        # out = self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123 + self.coefficient_1_0[1, :][None, :,None, None] * out_enc

        out_enc_1 = self.trans_low_1(out_fusion_123)

        out_fusion_123_1 = self.latent_1(out_fusion_123)

        out = (
            self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123_1
            + self.coefficient_1_0[1, :][None, :, None, None] * out_enc_1
        )
        # out_enc_2 = self.trans_low_2(out)

        # out_fusion_123_2 = self.latent_2(out)

        # out = self.coefficient_2_0[0, :][None, :, None, None] * out_fusion_123_2 + self.coefficient_2_0[1, :][None, :,None, None] * out_enc_2
        out_1 = self.refinement_1(out)
        out_2 = self.refinement_2(out_1)
        out_3 = self.refinement_3(out_2)

        inp_fusion = torch.cat(
            [out_1.unsqueeze(1), out_2.unsqueeze(1), out_3.unsqueeze(1)], dim=1
        )
        out_fusion_123 = self.layer_fussion_2(inp_fusion)
        out = self.conv_fuss_2(out_fusion_123)
        result = self.output(out)

        return result


class Model(nn.Module):
    def __init__(self, depth=2):
        super(Model, self).__init__()
        self.backbone = Backbone()
        self.lap_pyramid = LapPyramidConv(depth)
        self.trans_high = TransHigh(3, num_high=depth)

    def forward(self, inp):
        pyr_inp = self.lap_pyramid.pyramid_decom(img=inp)
        out_low = self.backbone(pyr_inp[-1])

        inp_up = nn.functional.interpolate(
            pyr_inp[-1], size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3])
        )
        out_up = nn.functional.interpolate(
            out_low, size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3])
        )
        high_with_low = torch.cat([pyr_inp[-2], inp_up, out_up], 1)

        pyr_inp_trans = self.trans_high(high_with_low, pyr_inp, out_low)

        result = self.lap_pyramid.pyramid_recons(pyr_inp_trans)

        return result


if __name__ == "__main__":
    tensor = torch.randn(1, 3, 1024, 1024).cuda()
    model = Model().cuda()
    output = model(tensor)
    print(output.shape)
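A minimal CPU sketch of the full Model above (the in-file __main__ assumes a 1024x1024 CUDA tensor): the input is split into a 2-level Laplacian pyramid, the low-frequency band goes through Backbone, the high-frequency bands through TransHigh, and the pyramid is reassembled at the input resolution. It assumes the repository root is on PYTHONPATH so the package import resolves.

import torch

from SIFR_models.mfdnet.model import Model

model = Model(depth=2).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))

print(out.shape)  # torch.Size([1, 3, 256, 256]) -- reconstruction matches the input size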
app.py
CHANGED
|
@@ -2,13 +2,30 @@ import numpy as np
|
|
| 2 |
import gradio as gr
|
| 3 |
import numpy as np
|
| 4 |
import random
|
| 5 |
-
|
| 6 |
-
# import torch
|
| 7 |
import spaces
|
| 8 |
import os
|
| 9 |
import base64
|
| 10 |
import json
|
|
|
|
| 11 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
intro = """
|
| 14 |
<div style="text-align:center">
|
|
@@ -142,19 +159,44 @@ def encode_image(pil_image):
|
|
| 142 |
# raise Exception(f"Failed to post: {response}")
|
| 143 |
|
| 144 |
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
#
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
# --- UI Constants and Helpers ---
|
| 160 |
MAX_SEED = np.iinfo(np.int32).max
|
|
@@ -164,47 +206,137 @@ MAX_SEED = np.iinfo(np.int32).max
|
|
| 164 |
@spaces.GPU(duration=120)
|
| 165 |
def infer(
|
| 166 |
image,
|
| 167 |
-
seed=
|
| 168 |
-
|
| 169 |
num_inference_steps=50,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
progress=gr.Progress(track_tqdm=True),
|
| 171 |
):
|
| 172 |
"""
|
| 173 |
-
Generates an image
|
| 174 |
"""
|
| 175 |
-
#
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
|
| 179 |
-
# seed = 42
|
| 180 |
-
# seed = random.randint(0, MAX_SEED)
|
| 181 |
|
| 182 |
-
|
| 183 |
-
# generator = torch.Generator(device=device).manual_seed(seed)
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
)
|
| 190 |
-
# if rewrite_prompt:
|
| 191 |
-
# # prompt = polish_prompt(prompt, image)
|
| 192 |
-
# print(f"Rewritten Prompt: {prompt}")
|
| 193 |
|
| 194 |
-
#
|
| 195 |
-
|
| 196 |
-
#
|
| 197 |
-
#
|
| 198 |
-
|
| 199 |
-
#
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
-
|
| 206 |
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
|
| 210 |
# --- Examples and UI Layout ---
|
|
@@ -243,6 +375,20 @@ with gr.Blocks(css=css) as demo:
|
|
| 243 |
value=42,
|
| 244 |
)
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
# randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 247 |
|
| 248 |
with gr.Row():
|
|
@@ -306,8 +452,12 @@ with gr.Blocks(css=css) as demo:
|
|
| 306 |
seed,
|
| 307 |
true_guidance_scale,
|
| 308 |
num_inference_steps,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
],
|
| 310 |
-
outputs=[outpainted_result, flarefree_result
|
| 311 |
)
|
| 312 |
|
| 313 |
if __name__ == "__main__":
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import numpy as np
|
| 4 |
import random
|
| 5 |
+
import torch
|
|
|
|
| 6 |
import spaces
|
| 7 |
import os
|
| 8 |
import base64
|
| 9 |
import json
|
| 10 |
+
import torchvision
|
| 11 |
from PIL import Image
|
| 12 |
+
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
| 13 |
+
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
| 14 |
+
|
| 15 |
+
from src.pipelines.pipeline_stable_diffusion_outpaint import OutpaintPipeline
|
| 16 |
+
from src.pipelines.pipeline_controlnet_outpaint import ControlNetOutpaintPipeline
|
| 17 |
+
from src.schedulers.scheduling_pndm import CustomScheduler
|
| 18 |
+
from src.models.unet import U_Net
|
| 19 |
+
from src.models.light_source_regressor import LightSourceRegressor
|
| 20 |
+
from utils.dataset import HFCustomImageLoader
|
| 21 |
+
from utils.utils import (
|
| 22 |
+
blend_with_alpha,
|
| 23 |
+
load_mfdnet_checkpoint,
|
| 24 |
+
predict_flare_from_6_channel,
|
| 25 |
+
predict_flare_from_3_channel,
|
| 26 |
+
blend_light_source,
|
| 27 |
+
)
|
| 28 |
+
from SIFR_models.flare7kpp.model import Uformer
|
| 29 |
|
| 30 |
intro = """
|
| 31 |
<div style="text-align:center">
|
|
|
|
    # raise Exception(f"Failed to post: {response}")


+## --- Model Loading --- ##
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if device == "cuda" else torch.float32
+print(f"Using device: {device}")
+
+# controlnet
+controlnet = ControlNetModel.from_pretrained(
+    "RayTsai-030/LightsOut-controlnet", torch_dtype=dtype
+)
+
+# outpainter
+pipe = ControlNetOutpaintPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-inpainting", controlnet=controlnet, torch_dtype=dtype
+).to(device)
+pipe.scheduler = CustomScheduler.from_config(pipe.scheduler.config)
+pipe.unet.load_attn_procs("./weights/light_outpaint_lora", use_safetensors=True)
+
+# blip
+processor = Blip2Processor.from_pretrained(
+    "Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24"
+)
+blip2 = Blip2ForConditionalGeneration.from_pretrained(
+    "Salesforce/blip2-opt-2.7b",
+    torch_dtype=dtype,
+    revision="51572668da0eb669e01a189dc22abe6088589a24",
+)
+blip2 = blip2.to(device)
+
+# light regressor
+lsr_module = LightSourceRegressor()
+ckpt = torch.load("./weights/light_regress/model.pth")
+lsr_module.load_state_dict(ckpt["model"])
+lsr_module.to(device)
+lsr_module.eval()
+
+# SIFR model
+sifr_model = Uformer(img_size=512, img_ch=3, output_ch=6).to(device)
+sifr_model.load_state_dict(torch.load("./weights/net_g_last.pth"))

# --- UI Constants and Helpers ---
MAX_SEED = np.iinfo(np.int32).max
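A quick sanity check, offered here only as a sketch and not part of this commit, that the custom scheduler and the LoRA attention processors are actually wired into the outpainting pipeline after loading (`attn_processors` is a standard diffusers UNet attribute):

# Sketch only: confirm CustomScheduler is installed and LoRA processors were attached.
def check_pipeline_setup(pipe):
    print("scheduler:", pipe.scheduler.__class__.__name__)  # expect CustomScheduler
    lora_procs = [
        name
        for name, proc in pipe.unet.attn_processors.items()
        if "lora" in proc.__class__.__name__.lower()
    ]
    print(f"{len(lora_procs)} attention processors carry LoRA weights")

check_pipeline_setup(pipe)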
@spaces.GPU(duration=120)
def infer(
    image,
+    seed=42,
+    cfg=7.5,
    num_inference_steps=50,
+    left_outpaint=64,
+    right_outpaint=64,
+    up_outpaint=64,
+    down_outpaint=64,
    progress=gr.Progress(track_tqdm=True),
):
    """
+    Generates an image
    """
+    # dataset
+    dataset = HFCustomImageLoader(image, left_outpaint, right_outpaint, up_outpaint, down_outpaint)
+    data = dataset[0]
+
+    # generator
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    # transformation
+    transform = torchvision.transforms.Compose(
+        [
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
+        ]
+    )
+    sifr_transform = torchvision.transforms.Compose(
+        [
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Resize((512, 512)),
+        ]
+    )
+
+    threshold = 0.5
+
+    with torch.no_grad():
+        input_img = data["input_img"]
+
+        input_img = transform(input_img).unsqueeze(0).to(device)

+        pred_mask = lsr_module.forward_render(input_img)

+        pred_mask = (pred_mask > threshold).float()

+    if pred_mask.device != "cpu":
+        pred_mask = pred_mask.cpu()
+    pred_mask = pred_mask.numpy()
+
+    data["control_img"] = Image.fromarray(
+        (pred_mask[0, 0] * 255).astype(np.uint8)
+    )
+
+    # print("Finish light source detection...")
+
+    # prepare text prompt
+    inputs = processor(data["blip_img"], return_tensors="pt").to(
+        device=device, dtype=dtype
+    )
+    generate_id = blip2.generate(**inputs, max_new_tokens=20)
+    generated_text = processor.batch_decode(generate_id, skip_special_tokens=True)[
+        0
+    ].strip()
+
+    generated_text += (
+        ", dynamic lighting, intense light source, prominent lens flare, best quality, high resolution, masterpiece, intricate details"
+        # ", full light sources with lens flare, best quality, high resolution"
    )

+    # print(f"Generated text prompt: {generated_text}")
+
+    # Blur mask
+    # data["mask_img"] = data["mask_img"].filter(ImageFilter.GaussianBlur(15))
+
+    # denoise
+    outpaint_result = pipe(
+        prompt=generated_text,
+        negative_prompt="NSFW, (word:1.5), watermark, blurry, missing body, amputation, mutilation",
+        image=data["input_img"],
+        mask_image=data["mask_img"],
+        control_image=data["control_img"],
+        num_inference_steps=num_inference_steps,
+        guidance_scale=cfg,
+        generator=generator,
+        repeat_time=4,
+    ).images[0]
+
+    # save result
+    outpaint_result = np.array(outpaint_result)
+    input_img = np.array(data["input_img"])
+    box = data["box"]
+
+    input_img2 = outpaint_result.copy()
+    input_img2[box[2] : box[3] + 1, box[0] : box[1] + 1] = input_img[
+        box[2] : box[3] + 1, box[0] : box[1] + 1
+    ]

+    outpaint_result = blend_with_alpha(outpaint_result, input_img2, box, blur_size=31)

+    outpaint_result = Image.fromarray(outpaint_result.astype(np.uint8))
+
+    # print("Finish outpainting...")
+
+    # flare removal
+    img = sifr_transform(outpaint_result).unsqueeze(0).cuda()
+
+    with torch.no_grad():
+        output_img = sifr_model(img)
+
+    gamma = torch.Tensor([2.2])
+
+    # flare7k++
+    deflare_result, _, _ = predict_flare_from_6_channel(output_img, gamma)
+
+    # # mfdnet
+    # flare_mask = torch.zeros_like(img)
+    # deflare_img, _ = predict_flare_from_3_channel(
+    #     output_img, flare_mask, output_img, img, img, gamma
+    # )
+    # deflare_img = blend_light_source(img, deflare_img, 0.999)
+
+    if deflare_result.device != "cpu":
+        deflare_result = deflare_result.cpu()
+    deflare_result = deflare_result.squeeze(0).permute(1, 2, 0).numpy()
+    deflare_result = np.clip(deflare_result, 0.0, 1.0)
+    deflare_result = (deflare_result * 255).astype(np.uint8)
+    deflare_result = deflare_result[box[2] : box[3] + 1, box[0] : box[1] + 1, :]
+    deflare_result = Image.fromarray(deflare_result).resize((512, 512), Image.LANCZOS)
+
+    # print("Finish flare removal...")
+
+    return outpaint_result, deflare_result


# --- Examples and UI Layout ---
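For local debugging, a minimal way to exercise `infer` outside the Gradio UI. This is a sketch under assumptions: that `HFCustomImageLoader` accepts a PIL image like the Gradio component provides, that `assets/exp.png` from this repo is a suitable test input, and that the `@spaces.GPU` decorator and the default `gr.Progress` object behave as pass-throughs outside a running Space.

# Sketch: call infer() directly and save both outputs.
from PIL import Image as PILImage

test_img = PILImage.open("assets/exp.png").convert("RGB")
outpainted, deflared = infer(
    test_img, seed=42, cfg=7.5, num_inference_steps=30,
    left_outpaint=64, right_outpaint=64, up_outpaint=64, down_outpaint=64,
)
outpainted.save("outpainted.png")
deflared.save("deflared.png")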
@@ -243,6 +375,20 @@ with gr.Blocks(css=css) as demo:
                    value=42,
                )

+                with gr.Column():
+                    left_outpaint = gr.Slider(
+                        label="Left outpaint (px)", minimum=0, maximum=128, step=1, value=64
+                    )
+                    right_outpaint = gr.Slider(
+                        label="Right outpaint (px)", minimum=0, maximum=128, step=1, value=64
+                    )
+                    up_outpaint = gr.Slider(
+                        label="Up outpaint (px)", minimum=0, maximum=128, step=1, value=64
+                    )
+                    down_outpaint = gr.Slider(
+                        label="Down outpaint (px)", minimum=0, maximum=128, step=1, value=64
+                    )
+
                # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                with gr.Row():

@@ -306,8 +452,12 @@ with gr.Blocks(css=css) as demo:
            seed,
            true_guidance_scale,
            num_inference_steps,
+            left_outpaint,
+            right_outpaint,
+            up_outpaint,
+            down_outpaint,
        ],
-        outputs=[outpainted_result, flarefree_result
+        outputs=[outpainted_result, flarefree_result],
    )

if __name__ == "__main__":

requirements.txt
CHANGED
@@ -1,2 +1,15 @@
-gradio
-pydantic==2.10.6
+gradio==4.44.1
+pydantic==2.10.6
+accelerate==0.21.0
+diffusers==0.23.0
+einops==0.8.0
+huggingface-hub==0.25.2
+imageio==2.36.0
+numpy==1.24.1
+opencv-python==4.10.0.84
+scikit-image==0.24.0
+timm==1.0.11
+transformers==4.36.0
+xformers==0.0.20
+spaces
+pillow
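An optional runtime check, a sketch rather than part of the Space, that the core pins above resolve to what actually got installed, using only the standard library:

# Sketch: compare installed package versions against the pins in requirements.txt.
from importlib.metadata import version

for pkg, expected in [("diffusers", "0.23.0"), ("transformers", "4.36.0"), ("gradio", "4.44.1")]:
    installed = version(pkg)
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{pkg}: {installed} {status}")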
src/models/__pycache__/light_source_regressor.cpython-39.pyc
ADDED
Binary file (3.41 kB)

src/models/__pycache__/unet.cpython-39.pyc
ADDED
Binary file (3.75 kB)

src/models/light_source_regressor.py
ADDED
@@ -0,0 +1,124 @@
import torch
from torch import nn
from torch.nn import init
from torchvision.models import resnet34, resnet50
import torchvision.models.vision_transformer as vit


class LightSourceRegressor(nn.Module):
    def __init__(self, num_lights=4, alpha=2.0, beta=8.0, **kwargs):
        super(LightSourceRegressor, self).__init__()

        self.num_lights = num_lights
        self.alpha = alpha
        self.beta = beta

        self.model = resnet34(pretrained=True)
        # self.model = resnet50(pretrained=True)
        # self.model = vit.vit_b_16(pretrained=True)
        self.init_resnet()
        # self.init_vit()

        self.xyr_mlp = nn.Sequential(
            nn.Linear(self.last_dim, 3 * self.num_lights),
        )
        self.p_mlp = nn.Sequential(
            nn.Linear(self.last_dim, self.num_lights),
            nn.Sigmoid(),  # ensure p is in [0, 1]
        )

    def init_resnet(self):
        self.last_dim = self.model.fc.in_features
        self.model.fc = nn.Identity()

    def init_vit(self):
        self.model.image_size = 512
        old_pos_embed = self.model.encoder.pos_embedding
        num_patches_old = (224 // 16) ** 2
        num_patches_new = (512 // 16) ** 2

        if num_patches_new != num_patches_old:
            old_pos_embed = old_pos_embed[:, 1:]
            old_pos_embed = nn.functional.interpolate(
                old_pos_embed.permute(0, 2, 1), size=(num_patches_new,), mode="linear"
            )
            old_pos_embed = old_pos_embed.permute(0, 2, 1)

            # new positional embedding
            self.model.encoder.pos_embedding = nn.Parameter(
                torch.cat(
                    [self.model.encoder.pos_embedding[:, :1], old_pos_embed], dim=1
                )
            )

        # num_classes = 4 * self.num_lights  # x, y, r, p
        # self.model.heads.head = nn.Linear(self.model.hidden_dim, num_classes)

        # remove the head
        self.last_dim = self.model.hidden_dim
        self.model.heads.head = nn.Identity()

    def forward(self, x, height=512, width=512, smoothness=0.1, merge=False):
        _x = self.model(x)  # [B, last_dim]

        _xyr = self.xyr_mlp(_x)
        _xyr = _xyr.view(-1, self.num_lights, 3)

        _p = self.p_mlp(_x)
        _p = _p.view(-1, self.num_lights)

        output = torch.cat([_xyr, _p.unsqueeze(-1)], dim=-1)

        return output

    def forward_render(self, x, height=512, width=512, smoothness=0.1, merge=False):
        _x = self.forward(x)

        _xy = _x[:, :, :2]
        _r = _x[:, :, 2]
        _p = _x[:, :, 3]

        masks = None
        masks_merge = None
        for b in range(_x.size(0)):
            x, y, r = _xy[b, :, 0] * width, _xy[b, :, 1] * width, _r[b] * width / 2
            p = _p[b]

            mask_list = []
            for i in range(self.num_lights):
                if r[i] < 0 or r[i] > width or p[i] < 0.5:
                    continue

                y_coords, x_coords = torch.meshgrid(
                    torch.arange(height, device=x.device),
                    torch.arange(width, device=x.device),
                    indexing="ij",
                )

                distances = torch.sqrt((x_coords - x[i]) ** 2 + (y_coords - y[i]) ** 2)
                mask_i = torch.sigmoid(smoothness * (r[i] - distances))
                mask_list.append(mask_i)

            if len(mask_list) == 0:
                _mask_merge = torch.zeros(1, 1, height, width, device=x.device)
            else:
                _mask_merge = torch.stack(mask_list, dim=0).sum(dim=0).unsqueeze(0)
                _mask_merge = _mask_merge.unsqueeze(0)

            masks_merge = (
                _mask_merge
                if masks_merge is None
                else torch.cat([masks_merge, _mask_merge], dim=0)
            )

        masks_merge = torch.clamp(masks_merge, 0, 1)

        return masks_merge  # [B, 1, H, W]


if __name__ == "__main__":
    # pydiffvg.set_use_gpu(torch.cuda.is_available())
    model = LightSourceRegressor(num_lights=4).cuda()
    x = torch.randn(8, 3, 512, 512, device="cuda")
    y = model.forward_render(x)
    print(y.shape)
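To connect this module to the rest of the pipeline: the soft mask from `forward_render` is thresholded into the binary light-source map that the outpainting ControlNet consumes. A small CPU sketch mirroring what app.py does (the 0.5 threshold is the value used there):

# Sketch: render light-source circles and turn them into a ControlNet condition image.
import numpy as np
import torch
from PIL import Image

regressor = LightSourceRegressor(num_lights=4).eval()
with torch.no_grad():
    soft_mask = regressor.forward_render(torch.rand(1, 3, 512, 512))  # [1, 1, 512, 512]
binary = (soft_mask > 0.5).float()[0, 0].cpu().numpy()
control_img = Image.fromarray((binary * 255).astype(np.uint8))  # fed to the ControlNet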
src/models/unet.py
ADDED
@@ -0,0 +1,129 @@
from torch import nn
from torch.nn import init
import torch
import torch.nn.functional as F


class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.up(x)
        return x


class U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, multi_stage=False):
        super(U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
        self.activation = nn.Sequential(nn.Sigmoid())
        # init_weights(self)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        init_type = "normal"
        gain = 0.02
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (
            classname.find("Conv") != -1 or classname.find("Linear") != -1
        ):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        d1 = self.activation(d1)
        return d1
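A quick shape check for the U-Net above (a sketch, not part of the file): the symmetric encoder/decoder preserves spatial size and the sigmoid head keeps values in [0, 1], which is what makes it usable for per-pixel masks.

import torch

net = U_Net(img_ch=3, output_ch=1)
with torch.no_grad():
    out = net(torch.rand(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1, 256, 256]); same spatial size as the input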
src/pipelines/__pycache__/pipeline_controlnet_outpaint.cpython-39.pyc
ADDED
Binary file (7.49 kB)

src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc
ADDED
Binary file (16 kB)

src/pipelines/pipeline_controlnet_outpaint.py
ADDED
@@ -0,0 +1,448 @@
import torch

from typing import List, Union, Dict, Any, Callable, Optional, Tuple
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel
from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion.pipeline_output import (
    StableDiffusionPipelineOutput,
)


class ControlNetOutpaintPipeline(StableDiffusionControlNetInpaintPipeline):
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        control_image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 1.0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        ## add
        repeat_time: int = 4,
        ##
        **kwargs: Any,
    ):
        r""" """
        controlnet = (
            self.controlnet._orig_mod
            if is_compiled_module(self.controlnet)
            else self.controlnet
        )

        # self.init_filter()

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            mult = (
                len(controlnet.nets)
                if isinstance(controlnet, MultiControlNetModel)
                else 1
            )
            control_guidance_start, control_guidance_end = mult * [
                control_guidance_start
            ], mult * [control_guidance_end]

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            control_image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(
            controlnet_conditioning_scale, float
        ):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
                controlnet.nets
            )

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            control_image = self.prepare_control_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
        elif isinstance(controlnet, MultiControlNetModel):
            control_images = []

            for control_image_ in control_image:
                control_image_ = self.prepare_control_image(
                    image=control_image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                control_images.append(control_image_)

            control_image = control_images
        else:
            assert False

        # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
        init_image = self.image_processor.preprocess(image, height=height, width=width)
        init_image = init_image.to(dtype=torch.float32)

        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)

        masked_image = init_image * (mask < 0.5)
        _, _, height, width = init_image.shape

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps=num_inference_steps, strength=strength, device=device
        )
        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        num_channels_unet = self.unet.config.in_channels
        return_image_latents = True

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        # 7. Prepare mask latent variables
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(
                keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
            )

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for i, t in enumerate(timesteps):

            ## modify
            i = 0
            reinject = repeat_time
            while i < len(timesteps):
                # expand the latents if we are doing classifier free guidance
                t = timesteps[i]
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # controlnet(s) inference
                if guess_mode and do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(
                        control_model_input, t
                    )
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [
                        c * s
                        for c, s in zip(
                            controlnet_conditioning_scale, controlnet_keep[i]
                        )
                    ]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [
                        torch.cat([torch.zeros_like(d), d])
                        for d in down_block_res_samples
                    ]
                    mid_block_res_sample = torch.cat(
                        [
                            torch.zeros_like(mid_block_res_sample),
                            mid_block_res_sample,
                        ]
                    )

                # predict the noise residual
                if num_channels_unet == 9:
                    latent_model_input = torch.cat(
                        [latent_model_input, mask, masked_image_latents], dim=1
                    )

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                if num_channels_unet == 4:
                    init_latents_proper = image_latents
                    if do_classifier_free_guidance:
                        init_mask, _ = mask.chunk(2)
                    else:
                        init_mask = mask

                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper,
                            noise,
                            torch.tensor([noise_timestep]),
                        )

                    latents = (
                        1 - init_mask
                    ) * init_latents_proper + init_mask * latents

                i += 1

                ## noise reinjection
                if i > 0 and i < int(len(timesteps) - 1) and reinject > 0:
                    current_timestep = timesteps[i]
                    target_timestep = timesteps[i - 1]
                    new_noise = torch.randn_like(latents)

                    # step back x_t-1 -> x_t
                    latents = self.scheduler.step_back(
                        latents,
                        new_noise,
                        torch.tensor([current_timestep]),
                        torch.tensor([target_timestep]),
                    )
                    i -= 1
                    reinject -= 1
                else:
                    # reinject = repeat_time

                    # schedule
                    if i >= int(len(timesteps) * 0.8):
                        reinject = 0
                    elif i >= int(len(timesteps) * 0.6):
                        reinject = max(0, repeat_time - 3)
                    elif i >= int(len(timesteps) * 0.4):
                        reinject = max(0, repeat_time - 2)
                    elif i >= int(len(timesteps) * 0.2):
                        reinject = max(0, repeat_time - 1)
                    else:
                        reinject = repeat_time

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(self.scheduler, "order", 1)
                            callback(step_idx, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor,
                return_dict=False,
                generator=generator,
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
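The noise-reinjection budget in the loop above decays as denoising progresses: more step-back/re-noise repeats are allowed early, none in the last 20% of the schedule. A standalone helper (a sketch, not part of the file) that reproduces the same thresholds makes the decay explicit:

# Sketch: the per-step reinjection budget used by the loop, factored out for clarity.
def reinjection_budget(i: int, total_steps: int, repeat_time: int = 4) -> int:
    """Remaining step-back/re-noise repeats allowed once step i is reached."""
    if i >= int(total_steps * 0.8):
        return 0
    if i >= int(total_steps * 0.6):
        return max(0, repeat_time - 3)
    if i >= int(total_steps * 0.4):
        return max(0, repeat_time - 2)
    if i >= int(total_steps * 0.2):
        return max(0, repeat_time - 1)
    return repeat_time

# e.g. with 50 timesteps and repeat_time=4: 4 repeats before step 10, tapering to 0 after step 40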
src/pipelines/pipeline_stable_diffusion_outpaint.py
ADDED
@@ -0,0 +1,517 @@
import torch

from typing import List, Union, Dict, Any, Callable, Optional, Tuple
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import make_image_grid, load_image, deprecate
from diffusers.models import AsymmetricAutoencoderKL
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion.pipeline_output import (
    StableDiffusionPipelineOutput,
)


class OutpaintPipeline(StableDiffusionInpaintPipeline):
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        control_image: PipelineImageInput = None,
        masked_image_latents: torch.FloatTensor = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 1.0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        ## add
        repeat_time: int = 4,
        ##
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
                be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
                tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
                expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
                expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
                if passing latents directly it is not encoded again.
            mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
                color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
                H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
                1)`, or `(H, W)`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            strength (`float`, *optional*, defaults to 1.0):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
        Examples:

        ```py
        >>> import PIL
        >>> import requests
        >>> import torch
        >>> from io import BytesIO

        >>> from diffusers import StableDiffusionInpaintPipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

        >>> init_image = download_image(img_url).resize((512, 512))
        >>> mask_image = download_image(mask_url).resize((512, 512))

        >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
        >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
        ```

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            height,
            width,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps=num_inference_steps, strength=strength, device=device
        )
        # check that number of inference steps is not < 1 - as this doesn't make sense
        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

        # 5. Preprocess mask and image

        init_image = self.image_processor.preprocess(image, height=height, width=width)
        init_image = init_image.to(dtype=torch.float32)

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        num_channels_unet = self.unet.config.in_channels
        return_image_latents = num_channels_unet == 4

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        # 7. Prepare mask latent variables
        mask_condition = self.mask_processor.preprocess(
            mask_image, height=height, width=width
        )

        if masked_image_latents is None:
            masked_image = init_image * (mask_condition < 0.5)
        else:
            masked_image = masked_image_latents

        mask, masked_image_latents = self.prepare_mask_latents(
            mask_condition,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            self.do_classifier_free_guidance,
        )

        # 8. Check that sizes of mask, masked image and latents match
        if num_channels_unet == 9:
            # default case for runwayml/stable-diffusion-inpainting
            num_channels_mask = mask.shape[1]
            num_channels_masked_image = masked_image_latents.shape[1]
            if (
                num_channels_latents + num_channels_mask + num_channels_masked_image
                != self.unet.config.in_channels
            ):
                raise ValueError(
                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                    " `pipeline.unet` or your `mask_image` or `image` input."
                )
        elif num_channels_unet != 4:
            raise ValueError(
                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
            )

        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9.5 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 10. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
|
| 337 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 338 |
+
# for i in range(len(timesteps)):
|
| 339 |
+
|
| 340 |
+
## modify
|
| 341 |
+
i = 0
|
| 342 |
+
reinject = repeat_time
|
| 343 |
+
while i < len(timesteps):
|
| 344 |
+
# expand the latents if we are doing classifier free guidance
|
| 345 |
+
latent_model_input = (
|
| 346 |
+
torch.cat([latents] * 2)
|
| 347 |
+
if self.do_classifier_free_guidance
|
| 348 |
+
else latents
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# concat latents, mask, masked_image_latents in the channel dimension
|
| 352 |
+
latent_model_input = self.scheduler.scale_model_input(
|
| 353 |
+
latent_model_input, timesteps[i]
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
if num_channels_unet == 9:
|
| 357 |
+
latent_model_input = torch.cat(
|
| 358 |
+
[latent_model_input, mask, masked_image_latents], dim=1
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# predict the noise residual
|
| 362 |
+
noise_pred = self.unet(
|
| 363 |
+
latent_model_input,
|
| 364 |
+
timesteps[i],
|
| 365 |
+
encoder_hidden_states=prompt_embeds,
|
| 366 |
+
timestep_cond=timestep_cond,
|
| 367 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 368 |
+
return_dict=False,
|
| 369 |
+
)[0]
|
| 370 |
+
|
| 371 |
+
# perform guidance
|
| 372 |
+
if self.do_classifier_free_guidance:
|
| 373 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 374 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
| 375 |
+
noise_pred_text - noise_pred_uncond
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 379 |
+
latents = self.scheduler.step(
|
| 380 |
+
noise_pred,
|
| 381 |
+
timesteps[i],
|
| 382 |
+
latents,
|
| 383 |
+
**extra_step_kwargs,
|
| 384 |
+
return_dict=False,
|
| 385 |
+
)[0]
|
| 386 |
+
if num_channels_unet == 4:
|
| 387 |
+
init_latents_proper = image_latents
|
| 388 |
+
if self.do_classifier_free_guidance:
|
| 389 |
+
init_mask, _ = mask.chunk(2)
|
| 390 |
+
else:
|
| 391 |
+
init_mask = mask
|
| 392 |
+
|
| 393 |
+
if i < len(timesteps) - 1:
|
| 394 |
+
noise_timestep = timesteps[i + 1]
|
| 395 |
+
init_latents_proper = self.scheduler.add_noise(
|
| 396 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
latents = (
|
| 400 |
+
1 - init_mask
|
| 401 |
+
) * init_latents_proper + init_mask * latents
|
| 402 |
+
|
| 403 |
+
if callback_on_step_end is not None:
|
| 404 |
+
callback_kwargs = {}
|
| 405 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 406 |
+
callback_kwargs[k] = locals()[k]
|
| 407 |
+
callback_outputs = callback_on_step_end(
|
| 408 |
+
self, i, timesteps[i], callback_kwargs
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
latents = callback_outputs.pop("latents", latents)
|
| 412 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 413 |
+
negative_prompt_embeds = callback_outputs.pop(
|
| 414 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
| 415 |
+
)
|
| 416 |
+
mask = callback_outputs.pop("mask", mask)
|
| 417 |
+
masked_image_latents = callback_outputs.pop(
|
| 418 |
+
"masked_image_latents", masked_image_latents
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
# # call the callback, if provided
|
| 422 |
+
# if i == len(timesteps) - 1 or (
|
| 423 |
+
# (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 424 |
+
# ):
|
| 425 |
+
# progress_bar.update()
|
| 426 |
+
# if callback is not None and i % callback_steps == 0:
|
| 427 |
+
# step_idx = i // getattr(self.scheduler, "order", 1)
|
| 428 |
+
# callback(step_idx, timesteps[i], latents)
|
| 429 |
+
|
| 430 |
+
i += 1
|
| 431 |
+
|
| 432 |
+
## noise reinjection
|
| 433 |
+
if i > 0 and i < int(len(timesteps) - 1) and reinject != 0:
|
| 434 |
+
current_timestep = timesteps[i]
|
| 435 |
+
target_timestep = timesteps[i - 1]
|
| 436 |
+
new_nosie = torch.randn_like(latents)
|
| 437 |
+
|
| 438 |
+
# step back x_t-1 -> x_t
|
| 439 |
+
latents = self.scheduler.step_back(
|
| 440 |
+
latents,
|
| 441 |
+
new_nosie,
|
| 442 |
+
torch.tensor([current_timestep]),
|
| 443 |
+
torch.tensor([target_timestep]),
|
| 444 |
+
)
|
| 445 |
+
i -= 1
|
| 446 |
+
reinject -= 1
|
| 447 |
+
else:
|
| 448 |
+
# reinject = repeat_time
|
| 449 |
+
|
| 450 |
+
# schedule
|
| 451 |
+
if i >= int(len(timesteps) * 0.85):
|
| 452 |
+
reinject = 0
|
| 453 |
+
elif i >= int(len(timesteps) * 0.8):
|
| 454 |
+
reinject = 1
|
| 455 |
+
elif i >= int(len(timesteps) * 0.7):
|
| 456 |
+
reinject = 2
|
| 457 |
+
elif i >= int(len(timesteps) * 0.5):
|
| 458 |
+
reinject = 3
|
| 459 |
+
else:
|
| 460 |
+
reinject = 4
|
| 461 |
+
|
| 462 |
+
# call the callback, if provided
|
| 463 |
+
if i == len(timesteps) - 1 or (
|
| 464 |
+
(i + 1) > num_warmup_steps
|
| 465 |
+
and (i + 1) % self.scheduler.order == 0
|
| 466 |
+
):
|
| 467 |
+
progress_bar.update()
|
| 468 |
+
if callback is not None and i % callback_steps == 0:
|
| 469 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 470 |
+
callback(step_idx, timesteps[i], latents)
|
| 471 |
+
|
| 472 |
+
if not output_type == "latent":
|
| 473 |
+
condition_kwargs = {}
|
| 474 |
+
if isinstance(self.vae, AsymmetricAutoencoderKL):
|
| 475 |
+
init_image = init_image.to(
|
| 476 |
+
device=device, dtype=masked_image_latents.dtype
|
| 477 |
+
)
|
| 478 |
+
init_image_condition = init_image.clone()
|
| 479 |
+
init_image = self._encode_vae_image(init_image, generator=generator)
|
| 480 |
+
mask_condition = mask_condition.to(
|
| 481 |
+
device=device, dtype=masked_image_latents.dtype
|
| 482 |
+
)
|
| 483 |
+
condition_kwargs = {
|
| 484 |
+
"image": init_image_condition,
|
| 485 |
+
"mask": mask_condition,
|
| 486 |
+
}
|
| 487 |
+
image = self.vae.decode(
|
| 488 |
+
latents / self.vae.config.scaling_factor,
|
| 489 |
+
return_dict=False,
|
| 490 |
+
generator=generator,
|
| 491 |
+
**condition_kwargs,
|
| 492 |
+
)[0]
|
| 493 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
| 494 |
+
image, device, prompt_embeds.dtype
|
| 495 |
+
)
|
| 496 |
+
else:
|
| 497 |
+
image = latents
|
| 498 |
+
has_nsfw_concept = None
|
| 499 |
+
|
| 500 |
+
if has_nsfw_concept is None:
|
| 501 |
+
do_denormalize = [True] * image.shape[0]
|
| 502 |
+
else:
|
| 503 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 504 |
+
|
| 505 |
+
image = self.image_processor.postprocess(
|
| 506 |
+
image, output_type=output_type, do_denormalize=do_denormalize
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
# Offload all models
|
| 510 |
+
self.maybe_free_model_hooks()
|
| 511 |
+
|
| 512 |
+
if not return_dict:
|
| 513 |
+
return (image, has_nsfw_concept)
|
| 514 |
+
|
| 515 |
+
return StableDiffusionPipelineOutput(
|
| 516 |
+
images=image, nsfw_content_detected=has_nsfw_concept
|
| 517 |
+
)
|
src/schedulers/__pycache__/scheduling_pndm.cpython-39.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
src/schedulers/scheduling_pndm.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
from diffusers import PNDMScheduler
|
| 4 |
+
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CustomScheduler(PNDMScheduler):
|
| 8 |
+
def step_plms(
|
| 9 |
+
self,
|
| 10 |
+
model_output: torch.FloatTensor,
|
| 11 |
+
timestep: int,
|
| 12 |
+
sample: torch.FloatTensor,
|
| 13 |
+
return_dict: bool = True,
|
| 14 |
+
) -> Union[SchedulerOutput, Tuple]:
|
| 15 |
+
"""
|
| 16 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
| 17 |
+
the linear multistep method. It performs one forward pass multiple times to approximate the solution.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
model_output (`torch.FloatTensor`):
|
| 21 |
+
The direct output from learned diffusion model.
|
| 22 |
+
timestep (`int`):
|
| 23 |
+
The current discrete timestep in the diffusion chain.
|
| 24 |
+
sample (`torch.FloatTensor`):
|
| 25 |
+
A current instance of a sample created by the diffusion process.
|
| 26 |
+
return_dict (`bool`):
|
| 27 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
| 31 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
| 32 |
+
tuple is returned where the first element is the sample tensor.
|
| 33 |
+
|
| 34 |
+
"""
|
| 35 |
+
if self.num_inference_steps is None:
|
| 36 |
+
raise ValueError(
|
| 37 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
if not self.config.skip_prk_steps and len(self.ets) < 3:
|
| 41 |
+
raise ValueError(
|
| 42 |
+
f"{self.__class__} can only be run AFTER scheduler has been run "
|
| 43 |
+
"in 'prk' mode for at least 12 iterations "
|
| 44 |
+
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
|
| 45 |
+
"for more information."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
prev_timestep = (
|
| 49 |
+
timestep - self.config.num_train_timesteps // self.num_inference_steps
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
if self.counter != 1:
|
| 53 |
+
self.ets = self.ets[-3:]
|
| 54 |
+
self.ets.append(model_output)
|
| 55 |
+
else:
|
| 56 |
+
prev_timestep = timestep
|
| 57 |
+
timestep = (
|
| 58 |
+
timestep + self.config.num_train_timesteps // self.num_inference_steps
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
if len(self.ets) == 1 and self.counter == 0:
|
| 62 |
+
model_output = model_output
|
| 63 |
+
self.cur_sample = sample
|
| 64 |
+
elif len(self.ets) == 1 and self.counter == 1:
|
| 65 |
+
model_output = (model_output + self.ets[-1]) / 2
|
| 66 |
+
sample = self.cur_sample
|
| 67 |
+
# self.cur_sample = None
|
| 68 |
+
elif len(self.ets) == 2:
|
| 69 |
+
model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
|
| 70 |
+
elif len(self.ets) == 3:
|
| 71 |
+
model_output = (
|
| 72 |
+
23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]
|
| 73 |
+
) / 12
|
| 74 |
+
else:
|
| 75 |
+
model_output = (1 / 24) * (
|
| 76 |
+
55 * self.ets[-1]
|
| 77 |
+
- 59 * self.ets[-2]
|
| 78 |
+
+ 37 * self.ets[-3]
|
| 79 |
+
- 9 * self.ets[-4]
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
prev_sample = self._get_prev_sample(
|
| 83 |
+
sample, timestep, prev_timestep, model_output
|
| 84 |
+
)
|
| 85 |
+
self.counter += 1
|
| 86 |
+
|
| 87 |
+
if not return_dict:
|
| 88 |
+
return (prev_sample,)
|
| 89 |
+
|
| 90 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
| 91 |
+
|
| 92 |
+
def step_back(
|
| 93 |
+
self,
|
| 94 |
+
current_samples: torch.FloatTensor,
|
| 95 |
+
noise: torch.FloatTensor,
|
| 96 |
+
current_timesteps: torch.IntTensor,
|
| 97 |
+
target_timesteps: torch.IntTensor,
|
| 98 |
+
):
|
| 99 |
+
"""Custom function for stepping back in the diffusion process."""
|
| 100 |
+
|
| 101 |
+
assert current_timesteps <= target_timesteps
|
| 102 |
+
alphas_cumprod = self.alphas_cumprod.to(
|
| 103 |
+
device=current_samples.device, dtype=current_samples.dtype
|
| 104 |
+
)
|
| 105 |
+
target_timesteps = target_timesteps.to(current_samples.device)
|
| 106 |
+
current_timesteps = current_timesteps.to(current_samples.device)
|
| 107 |
+
alpha_prod_target = alphas_cumprod[target_timesteps]
|
| 108 |
+
alpha_prod_target = alpha_prod_target.flatten()
|
| 109 |
+
alpha_prod_current = alphas_cumprod[current_timesteps]
|
| 110 |
+
alpha_prod_current = alpha_prod_current.flatten()
|
| 111 |
+
alpha_prod = alpha_prod_target / alpha_prod_current
|
| 112 |
+
|
| 113 |
+
sqrt_alpha_prod = alpha_prod**0.5
|
| 114 |
+
sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5
|
| 115 |
+
|
| 116 |
+
while len(sqrt_alpha_prod.shape) < len(current_samples.shape):
|
| 117 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
| 118 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(current_samples.shape):
|
| 119 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
| 120 |
+
|
| 121 |
+
noisy_samples = (
|
| 122 |
+
sqrt_alpha_prod * current_samples + sqrt_one_minus_alpha_prod * noise
|
| 123 |
+
)
|
| 124 |
+
self.counter -= 1
|
| 125 |
+
|
| 126 |
+
return noisy_samples
|
utils/__pycache__/dataset.cpython-39.pyc
ADDED
|
Binary file (28.2 kB). View file
|
|
|
utils/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (8.55 kB). View file
|
|
|
utils/dataset.py
ADDED
|
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import glob
|
| 4 |
+
import random
|
| 5 |
+
import timeit
|
| 6 |
+
import numpy as np
|
| 7 |
+
import skimage
|
| 8 |
+
import yaml
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
import torchvision.transforms.functional as TF
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
from torch.distributions import Normal
|
| 15 |
+
|
| 16 |
+
# from utils.utils import RGB2YCbCr
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RandomGammaCorrection(object):
|
| 20 |
+
def __init__(self, gamma=None):
|
| 21 |
+
self.gamma = gamma
|
| 22 |
+
|
| 23 |
+
def __call__(self, image):
|
| 24 |
+
if self.gamma == None:
|
| 25 |
+
# more chances of selecting 0 (original image)
|
| 26 |
+
gammas = [0.5, 1, 2]
|
| 27 |
+
self.gamma = random.choice(gammas)
|
| 28 |
+
return TF.adjust_gamma(image, self.gamma, gain=1)
|
| 29 |
+
elif isinstance(self.gamma, tuple):
|
| 30 |
+
gamma = random.uniform(*self.gamma)
|
| 31 |
+
return TF.adjust_gamma(image, gamma, gain=1)
|
| 32 |
+
elif self.gamma == 0:
|
| 33 |
+
return image
|
| 34 |
+
else:
|
| 35 |
+
return TF.adjust_gamma(image, self.gamma, gain=1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def remove_background(image):
|
| 39 |
+
# the input of the image is PIL.Image form with [H,W,C]
|
| 40 |
+
image = np.float32(np.array(image))
|
| 41 |
+
_EPS = 1e-7
|
| 42 |
+
rgb_max = np.max(image, (0, 1))
|
| 43 |
+
rgb_min = np.min(image, (0, 1))
|
| 44 |
+
image = (image - rgb_min) * rgb_max / (rgb_max - rgb_min + _EPS)
|
| 45 |
+
image = torch.from_numpy(image)
|
| 46 |
+
return image
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def glod_from_folder(folder_list, index_list):
|
| 50 |
+
ext = ["png", "jpeg", "jpg", "bmp", "tif"]
|
| 51 |
+
index_dict = {}
|
| 52 |
+
for i, folder_name in enumerate(folder_list):
|
| 53 |
+
data_list = []
|
| 54 |
+
[data_list.extend(glob.glob(folder_name + "/*." + e)) for e in ext]
|
| 55 |
+
data_list.sort()
|
| 56 |
+
index_dict[index_list[i]] = data_list
|
| 57 |
+
return index_dict
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Flare_Image_Loader(Dataset):
|
| 61 |
+
def __init__(self, image_path, transform_base, transform_flare, mask_type=None):
|
| 62 |
+
self.ext = ["png", "jpeg", "jpg", "bmp", "tif"]
|
| 63 |
+
self.data_list = []
|
| 64 |
+
[self.data_list.extend(glob.glob(image_path + "/*." + e)) for e in self.ext]
|
| 65 |
+
self.flare_dict = {}
|
| 66 |
+
self.flare_list = []
|
| 67 |
+
self.flare_name_list = []
|
| 68 |
+
|
| 69 |
+
self.reflective_flag = False
|
| 70 |
+
self.reflective_dict = {}
|
| 71 |
+
self.reflective_list = []
|
| 72 |
+
self.reflective_name_list = []
|
| 73 |
+
|
| 74 |
+
self.light_flag = False
|
| 75 |
+
self.light_dict = {}
|
| 76 |
+
self.light_list = []
|
| 77 |
+
self.light_name_list = []
|
| 78 |
+
|
| 79 |
+
self.mask_type = (
|
| 80 |
+
mask_type # It is a str which may be None,"luminance" or "color"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
self.img_size = transform_base["img_size"]
|
| 84 |
+
|
| 85 |
+
self.transform_base = transforms.Compose(
|
| 86 |
+
[
|
| 87 |
+
transforms.RandomCrop(
|
| 88 |
+
(self.img_size, self.img_size),
|
| 89 |
+
pad_if_needed=True,
|
| 90 |
+
padding_mode="reflect",
|
| 91 |
+
),
|
| 92 |
+
transforms.RandomHorizontalFlip(),
|
| 93 |
+
# transforms.RandomVerticalFlip(),
|
| 94 |
+
]
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
self.transform_flare = transforms.Compose(
|
| 98 |
+
[
|
| 99 |
+
transforms.RandomAffine(
|
| 100 |
+
degrees=(0, 360),
|
| 101 |
+
scale=(transform_flare["scale_min"], transform_flare["scale_max"]),
|
| 102 |
+
translate=(
|
| 103 |
+
transform_flare["translate"] / 1440,
|
| 104 |
+
transform_flare["translate"] / 1440,
|
| 105 |
+
),
|
| 106 |
+
shear=(-transform_flare["shear"], transform_flare["shear"]),
|
| 107 |
+
),
|
| 108 |
+
transforms.CenterCrop((self.img_size, self.img_size)),
|
| 109 |
+
transforms.RandomHorizontalFlip(),
|
| 110 |
+
transforms.RandomVerticalFlip(),
|
| 111 |
+
]
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
self.normalize = transforms.Compose(
|
| 115 |
+
[
|
| 116 |
+
transforms.Normalize([0.5], [0.5]),
|
| 117 |
+
]
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
self.data_ratio = []
|
| 121 |
+
|
| 122 |
+
def lightsource_crop(self, matrix):
|
| 123 |
+
"""Find the largest rectangle of 1s in a binary matrix."""
|
| 124 |
+
|
| 125 |
+
def largestRectangleArea(heights):
|
| 126 |
+
heights.append(0)
|
| 127 |
+
stack = [-1]
|
| 128 |
+
max_area = 0
|
| 129 |
+
max_rectangle = (0, 0, 0, 0) # (area, left, right, height)
|
| 130 |
+
|
| 131 |
+
for i in range(len(heights)):
|
| 132 |
+
while heights[i] < heights[stack[-1]]:
|
| 133 |
+
h = heights[stack.pop()]
|
| 134 |
+
w = i - stack[-1] - 1
|
| 135 |
+
area = h * w
|
| 136 |
+
if area > max_area:
|
| 137 |
+
max_area = area
|
| 138 |
+
max_rectangle = (area, stack[-1] + 1, i - 1, h)
|
| 139 |
+
stack.append(i)
|
| 140 |
+
|
| 141 |
+
heights.pop()
|
| 142 |
+
return max_rectangle
|
| 143 |
+
|
| 144 |
+
max_area = 0
|
| 145 |
+
max_rectangle = [0, 0, 0, 0] # (left, right, top, bottom)
|
| 146 |
+
heights = torch.zeros(matrix.shape[1])
|
| 147 |
+
|
| 148 |
+
for row in range(matrix.shape[0]):
|
| 149 |
+
temp = 1 - matrix[row]
|
| 150 |
+
heights = (heights + temp) * temp
|
| 151 |
+
|
| 152 |
+
area, left, right, height = largestRectangleArea(heights.tolist())
|
| 153 |
+
if area > max_area:
|
| 154 |
+
max_area = area
|
| 155 |
+
max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]
|
| 156 |
+
|
| 157 |
+
return torch.tensor(max_rectangle)
|
| 158 |
+
|
| 159 |
+
def __getitem__(self, index):
|
| 160 |
+
# load base image
|
| 161 |
+
img_path = self.data_list[index]
|
| 162 |
+
base_img = Image.open(img_path).convert("RGB")
|
| 163 |
+
|
| 164 |
+
gamma = np.random.uniform(1.8, 2.2)
|
| 165 |
+
to_tensor = transforms.ToTensor()
|
| 166 |
+
adjust_gamma = RandomGammaCorrection(gamma)
|
| 167 |
+
adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
|
| 168 |
+
color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
|
| 169 |
+
if self.transform_base is not None:
|
| 170 |
+
base_img = to_tensor(base_img)
|
| 171 |
+
base_img = adjust_gamma(base_img)
|
| 172 |
+
base_img = self.transform_base(base_img)
|
| 173 |
+
else:
|
| 174 |
+
base_img = to_tensor(base_img)
|
| 175 |
+
base_img = adjust_gamma(base_img)
|
| 176 |
+
sigma_chi = 0.01 * np.random.chisquare(df=1)
|
| 177 |
+
base_img = Normal(base_img, sigma_chi).sample()
|
| 178 |
+
gain = np.random.uniform(0.5, 1.2)
|
| 179 |
+
flare_DC_offset = np.random.uniform(-0.02, 0.02)
|
| 180 |
+
base_img = gain * base_img
|
| 181 |
+
base_img = torch.clamp(base_img, min=0, max=1)
|
| 182 |
+
|
| 183 |
+
choice_dataset = random.choices(
|
| 184 |
+
[i for i in range(len(self.flare_list))], self.data_ratio
|
| 185 |
+
)[0]
|
| 186 |
+
choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
|
| 187 |
+
|
| 188 |
+
# load flare and light source image
|
| 189 |
+
if self.light_flag:
|
| 190 |
+
assert len(self.flare_list) == len(
|
| 191 |
+
self.light_list
|
| 192 |
+
), "Error, number of light source and flares dataset no match!"
|
| 193 |
+
for i in range(len(self.flare_list)):
|
| 194 |
+
assert len(self.flare_list[i]) == len(
|
| 195 |
+
self.light_list[i]
|
| 196 |
+
), f"Error, number of light source and flares no match in {i} dataset!"
|
| 197 |
+
flare_path = self.flare_list[choice_dataset][choice_index]
|
| 198 |
+
light_path = self.light_list[choice_dataset][choice_index]
|
| 199 |
+
light_img = Image.open(light_path).convert("RGB")
|
| 200 |
+
light_img = to_tensor(light_img)
|
| 201 |
+
light_img = adjust_gamma(light_img)
|
| 202 |
+
else:
|
| 203 |
+
flare_path = self.flare_list[choice_dataset][choice_index]
|
| 204 |
+
flare_img = Image.open(flare_path).convert("RGB")
|
| 205 |
+
if self.reflective_flag:
|
| 206 |
+
reflective_path_list = self.reflective_list[choice_dataset]
|
| 207 |
+
if len(reflective_path_list) != 0:
|
| 208 |
+
reflective_path = random.choice(reflective_path_list)
|
| 209 |
+
reflective_img = Image.open(reflective_path).convert("RGB")
|
| 210 |
+
else:
|
| 211 |
+
reflective_img = None
|
| 212 |
+
|
| 213 |
+
flare_img = to_tensor(flare_img)
|
| 214 |
+
flare_img = adjust_gamma(flare_img)
|
| 215 |
+
|
| 216 |
+
if self.reflective_flag and reflective_img is not None:
|
| 217 |
+
reflective_img = to_tensor(reflective_img)
|
| 218 |
+
reflective_img = adjust_gamma(reflective_img)
|
| 219 |
+
flare_img = torch.clamp(flare_img + reflective_img, min=0, max=1)
|
| 220 |
+
|
| 221 |
+
flare_img = remove_background(flare_img)
|
| 222 |
+
|
| 223 |
+
if self.transform_flare is not None:
|
| 224 |
+
if self.light_flag:
|
| 225 |
+
flare_merge = torch.cat((flare_img, light_img), dim=0)
|
| 226 |
+
flare_merge = self.transform_flare(flare_merge)
|
| 227 |
+
else:
|
| 228 |
+
flare_img = self.transform_flare(flare_img)
|
| 229 |
+
|
| 230 |
+
# change color
|
| 231 |
+
if self.light_flag:
|
| 232 |
+
# flare_merge=color_jitter(flare_merge)
|
| 233 |
+
flare_img, light_img = torch.split(flare_merge, 3, dim=0)
|
| 234 |
+
else:
|
| 235 |
+
flare_img = color_jitter(flare_img)
|
| 236 |
+
|
| 237 |
+
# flare blur
|
| 238 |
+
blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
|
| 239 |
+
flare_img = blur_transform(flare_img)
|
| 240 |
+
# flare_img=flare_img+flare_DC_offset
|
| 241 |
+
flare_img = torch.clamp(flare_img, min=0, max=1)
|
| 242 |
+
|
| 243 |
+
# merge image
|
| 244 |
+
merge_img = flare_img + base_img
|
| 245 |
+
merge_img = torch.clamp(merge_img, min=0, max=1)
|
| 246 |
+
if self.light_flag:
|
| 247 |
+
base_img = base_img + light_img
|
| 248 |
+
base_img = torch.clamp(base_img, min=0, max=1)
|
| 249 |
+
flare_img = flare_img - light_img
|
| 250 |
+
flare_img = torch.clamp(flare_img, min=0, max=1)
|
| 251 |
+
|
| 252 |
+
flare_mask = None
|
| 253 |
+
if self.mask_type == None:
|
| 254 |
+
return {
|
| 255 |
+
"gt": adjust_gamma_reverse(base_img),
|
| 256 |
+
"flare": adjust_gamma_reverse(flare_img),
|
| 257 |
+
"lq": adjust_gamma_reverse(merge_img),
|
| 258 |
+
"gamma": gamma,
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
elif self.mask_type == "luminance":
|
| 262 |
+
# calculate mask (the mask is 3 channel)
|
| 263 |
+
one = torch.ones_like(base_img)
|
| 264 |
+
zero = torch.zeros_like(base_img)
|
| 265 |
+
|
| 266 |
+
luminance = 0.3 * flare_img[0] + 0.59 * flare_img[1] + 0.11 * flare_img[2]
|
| 267 |
+
threshold_value = 0.99**gamma
|
| 268 |
+
flare_mask = torch.where(luminance > threshold_value, one, zero)
|
| 269 |
+
|
| 270 |
+
elif self.mask_type == "color":
|
| 271 |
+
one = torch.ones_like(base_img)
|
| 272 |
+
zero = torch.zeros_like(base_img)
|
| 273 |
+
|
| 274 |
+
threshold_value = 0.99**gamma
|
| 275 |
+
flare_mask = torch.where(merge_img > threshold_value, one, zero)
|
| 276 |
+
|
| 277 |
+
elif self.mask_type == "flare":
|
| 278 |
+
one = torch.ones_like(base_img)
|
| 279 |
+
zero = torch.zeros_like(base_img)
|
| 280 |
+
|
| 281 |
+
threshold_value = 0.7**gamma
|
| 282 |
+
flare_mask = torch.where(flare_img > threshold_value, one, zero)
|
| 283 |
+
|
| 284 |
+
elif self.mask_type == "light":
|
| 285 |
+
# Depreciated: we dont need light mask anymore
|
| 286 |
+
one = torch.ones_like(base_img)
|
| 287 |
+
zero = torch.zeros_like(base_img)
|
| 288 |
+
|
| 289 |
+
luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
|
| 290 |
+
threshold_value = 0.01
|
| 291 |
+
flare_mask = torch.where(luminance > threshold_value, one, zero)
|
| 292 |
+
|
| 293 |
+
light_source_cond = torch.zeros_like(flare_mask[0])
|
| 294 |
+
light_source_cond = (flare_mask[0] + flare_mask[1] + flare_mask[2]) > 0
|
| 295 |
+
light_source_cond = light_source_cond.float()
|
| 296 |
+
light_source_cond = torch.repeat_interleave(
|
| 297 |
+
light_source_cond[None, ...], 3, dim=0
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# box = self.crop(light_source_cond[0])
|
| 301 |
+
box = self.lightsource_crop(light_source_cond[0])
|
| 302 |
+
|
| 303 |
+
# random int between -15 ~ 15
|
| 304 |
+
margin = random.randint(-15, 15)
|
| 305 |
+
|
| 306 |
+
if box[0] - margin >= 0:
|
| 307 |
+
box[0] -= margin
|
| 308 |
+
if box[1] + margin < self.img_size:
|
| 309 |
+
box[1] += margin
|
| 310 |
+
if box[2] - margin >= 0:
|
| 311 |
+
box[2] -= margin
|
| 312 |
+
if box[3] + margin < self.img_size:
|
| 313 |
+
box[3] += margin
|
| 314 |
+
|
| 315 |
+
top, bottom, left, right = box[2], box[3], box[0], box[1]
|
| 316 |
+
|
| 317 |
+
merge_img = adjust_gamma_reverse(merge_img)
|
| 318 |
+
|
| 319 |
+
cropped_mask = torch.ones((self.img_size, self.img_size))
|
| 320 |
+
cropped_mask[top : bottom + 1, left : right + 1] = False
|
| 321 |
+
cropped_mask = torch.repeat_interleave(cropped_mask[None, ...], 1, dim=0)
|
| 322 |
+
|
| 323 |
+
channel3_mask = cropped_mask.repeat(3, 1, 1)
|
| 324 |
+
masked_img = merge_img * (1 - channel3_mask)
|
| 325 |
+
masked_img[channel3_mask == 1] = 0.5
|
| 326 |
+
|
| 327 |
+
return {
|
| 328 |
+
# add
|
| 329 |
+
"pixel_values": self.normalize(merge_img),
|
| 330 |
+
"masks": cropped_mask,
|
| 331 |
+
"masked_images": self.normalize(masked_img),
|
| 332 |
+
"conditioning_pixel_values": light_source_cond,
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
def __len__(self):
|
| 336 |
+
return len(self.data_list)
|
| 337 |
+
|
| 338 |
+
def load_scattering_flare(self, flare_name, flare_path):
|
| 339 |
+
flare_list = []
|
| 340 |
+
[flare_list.extend(glob.glob(flare_path + "/*." + e)) for e in self.ext]
|
| 341 |
+
flare_list = sorted(flare_list)
|
| 342 |
+
self.flare_name_list.append(flare_name)
|
| 343 |
+
self.flare_dict[flare_name] = flare_list
|
| 344 |
+
self.flare_list.append(flare_list)
|
| 345 |
+
len_flare_list = len(self.flare_dict[flare_name])
|
| 346 |
+
if len_flare_list == 0:
|
| 347 |
+
print("ERROR: scattering flare images are not loaded properly")
|
| 348 |
+
else:
|
| 349 |
+
print(
|
| 350 |
+
"Scattering Flare Image:",
|
| 351 |
+
flare_name,
|
| 352 |
+
" is loaded successfully with examples",
|
| 353 |
+
str(len_flare_list),
|
| 354 |
+
)
|
| 355 |
+
# print("Now we have", len(self.flare_list), "scattering flare images")
|
| 356 |
+
|
| 357 |
+
def load_light_source(self, light_name, light_path):
|
| 358 |
+
# The number of the light source images should match the number of scattering flares
|
| 359 |
+
light_list = []
|
| 360 |
+
[light_list.extend(glob.glob(light_path + "/*." + e)) for e in self.ext]
|
| 361 |
+
light_list = sorted(light_list)
|
| 362 |
+
self.flare_name_list.append(light_name)
|
| 363 |
+
self.light_dict[light_name] = light_list
|
| 364 |
+
self.light_list.append(light_list)
|
| 365 |
+
len_light_list = len(self.light_dict[light_name])
|
| 366 |
+
|
| 367 |
+
if len_light_list == 0:
|
| 368 |
+
print("ERROR: Light Source images are not loaded properly")
|
| 369 |
+
else:
|
| 370 |
+
self.light_flag = True
|
| 371 |
+
print(
|
| 372 |
+
"Light Source Image:",
|
| 373 |
+
light_name,
|
| 374 |
+
" is loaded successfully with examples",
|
| 375 |
+
str(len_light_list),
|
| 376 |
+
)
|
| 377 |
+
# print("Now we have", len(self.light_list), "light source images")
|
| 378 |
+
|
| 379 |
+
def load_reflective_flare(self, reflective_name, reflective_path):
|
| 380 |
+
if reflective_path is None:
|
| 381 |
+
reflective_list = []
|
| 382 |
+
else:
|
| 383 |
+
reflective_list = []
|
| 384 |
+
[
|
| 385 |
+
reflective_list.extend(glob.glob(reflective_path + "/*." + e))
|
| 386 |
+
for e in self.ext
|
| 387 |
+
]
|
| 388 |
+
reflective_list = sorted(reflective_list)
|
| 389 |
+
self.reflective_name_list.append(reflective_name)
|
| 390 |
+
self.reflective_dict[reflective_name] = reflective_list
|
| 391 |
+
self.reflective_list.append(reflective_list)
|
| 392 |
+
len_reflective_list = len(self.reflective_dict[reflective_name])
|
| 393 |
+
if len_reflective_list == 0 and reflective_path is not None:
|
| 394 |
+
print("ERROR: reflective flare images are not loaded properly")
|
| 395 |
+
else:
|
| 396 |
+
self.reflective_flag = True
|
| 397 |
+
print(
|
| 398 |
+
"Reflective Flare Image:",
|
| 399 |
+
reflective_name,
|
| 400 |
+
" is loaded successfully with examples",
|
| 401 |
+
str(len_reflective_list),
|
| 402 |
+
)
|
| 403 |
+
# print("Now we have", len(self.reflective_list), "refelctive flare images")
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class Flare7kpp_Pair_Loader(Flare_Image_Loader):
|
| 407 |
+
def __init__(self, config):
|
| 408 |
+
Flare_Image_Loader.__init__(
|
| 409 |
+
self,
|
| 410 |
+
config["image_path"],
|
| 411 |
+
config["transform_base"],
|
| 412 |
+
config["transform_flare"],
|
| 413 |
+
config["mask_type"],
|
| 414 |
+
)
|
| 415 |
+
scattering_dict = config["scattering_dict"]
|
| 416 |
+
reflective_dict = config["reflective_dict"]
|
| 417 |
+
light_dict = config["light_dict"]
|
| 418 |
+
|
| 419 |
+
# defualt not use light mask if opt['use_light_mask'] is not declared
|
| 420 |
+
if "data_ratio" not in config or len(config["data_ratio"]) == 0:
|
| 421 |
+
self.data_ratio = [1] * len(scattering_dict)
|
| 422 |
+
else:
|
| 423 |
+
self.data_ratio = config["data_ratio"]
|
| 424 |
+
|
| 425 |
+
if len(scattering_dict) != 0:
|
| 426 |
+
for key in scattering_dict.keys():
|
| 427 |
+
self.load_scattering_flare(key, scattering_dict[key])
|
| 428 |
+
if len(reflective_dict) != 0:
|
| 429 |
+
for key in reflective_dict.keys():
|
| 430 |
+
self.load_reflective_flare(key, reflective_dict[key])
|
| 431 |
+
if len(light_dict) != 0:
|
| 432 |
+
for key in light_dict.keys():
|
| 433 |
+
self.load_light_source(key, light_dict[key])
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class Lightsource_Regress_Loader(Flare7kpp_Pair_Loader):
|
| 437 |
+
def __init__(self, config, num_lights=4):
|
| 438 |
+
Flare7kpp_Pair_Loader.__init__(self, config)
|
| 439 |
+
self.transform_flare = transforms.Compose(
|
| 440 |
+
[
|
| 441 |
+
transforms.RandomAffine(
|
| 442 |
+
degrees=(0, 360),
|
| 443 |
+
scale=(
|
| 444 |
+
config["transform_flare"]["scale_min"],
|
| 445 |
+
config["transform_flare"]["scale_max"],
|
| 446 |
+
),
|
| 447 |
+
shear=(
|
| 448 |
+
-config["transform_flare"]["shear"],
|
| 449 |
+
config["transform_flare"]["shear"],
|
| 450 |
+
),
|
| 451 |
+
),
|
| 452 |
+
# transforms.CenterCrop((self.img_size, self.img_size)),
|
| 453 |
+
]
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
self.mask_type = "light"
|
| 457 |
+
self.num_lights = num_lights
|
| 458 |
+
|
| 459 |
+
def __getitem__(self, index):
|
| 460 |
+
# load base image
|
| 461 |
+
img_path = self.data_list[index]
|
| 462 |
+
base_img = Image.open(img_path).convert("RGB")
|
| 463 |
+
|
| 464 |
+
gamma = np.random.uniform(1.8, 2.2)
|
| 465 |
+
to_tensor = transforms.ToTensor()
|
| 466 |
+
adjust_gamma = RandomGammaCorrection(gamma)
|
| 467 |
+
adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
|
| 468 |
+
color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
|
| 469 |
+
|
| 470 |
+
base_img = to_tensor(base_img)
|
| 471 |
+
base_img = adjust_gamma(base_img)
|
| 472 |
+
if self.transform_base is not None:
|
| 473 |
+
base_img = self.transform_base(base_img)
|
| 474 |
+
|
| 475 |
+
sigma_chi = 0.01 * np.random.chisquare(df=1)
|
| 476 |
+
base_img = Normal(base_img, sigma_chi).sample()
|
| 477 |
+
gain = np.random.uniform(0.5, 1.2)
|
| 478 |
+
base_img = gain * base_img
|
| 479 |
+
base_img = torch.clamp(base_img, min=0, max=1)
|
| 480 |
+
|
| 481 |
+
# init flare and light imgs
|
| 482 |
+
flare_imgs = []
|
| 483 |
+
light_imgs = []
|
| 484 |
+
position = [
|
| 485 |
+
[[-224, 0], [-224, 0]],
|
| 486 |
+
[[-224, 0], [0, 224]],
|
| 487 |
+
[[0, 224], [-224, 0]],
|
| 488 |
+
[[0, 224], [0, 224]],
|
| 489 |
+
]
|
| 490 |
+
axis = random.sample(range(4), 4)
|
| 491 |
+
axis[-1] = axis[0]
|
| 492 |
+
flare_nums = int(
|
| 493 |
+
random.random() * self.num_lights + 1
|
| 494 |
+
) # random number of flares from 1 to 4
|
| 495 |
+
|
| 496 |
+
for fn in range(flare_nums):
|
| 497 |
+
choice_dataset = random.choices(
|
| 498 |
+
[i for i in range(len(self.flare_list))], self.data_ratio
|
| 499 |
+
)[0]
|
| 500 |
+
choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
|
| 501 |
+
|
| 502 |
+
flare_path = self.flare_list[choice_dataset][choice_index]
|
| 503 |
+
flare_img = Image.open(flare_path).convert("RGB")
|
| 504 |
+
flare_img = to_tensor(flare_img)
|
| 505 |
+
flare_img = adjust_gamma(flare_img)
|
| 506 |
+
flare_img = remove_background(flare_img)
|
| 507 |
+
|
| 508 |
+
if self.light_flag:
|
| 509 |
+
light_path = self.light_list[choice_dataset][choice_index]
|
| 510 |
+
light_img = Image.open(light_path).convert("RGB")
|
| 511 |
+
light_img = to_tensor(light_img)
|
| 512 |
+
light_img = adjust_gamma(light_img)
|
| 513 |
+
|
| 514 |
+
if self.transform_flare is not None:
|
| 515 |
+
if self.light_flag:
|
| 516 |
+
flare_merge = torch.cat((flare_img, light_img), dim=0)
|
| 517 |
+
|
| 518 |
+
if flare_nums == 1:
|
| 519 |
+
dx = random.randint(-224, 224)
|
| 520 |
+
dy = random.randint(-224, 224)
|
| 521 |
+
else:
|
| 522 |
+
dx = random.randint(
|
| 523 |
+
position[axis[fn]][0][0], position[axis[fn]][0][1]
|
| 524 |
+
)
|
| 525 |
+
dy = random.randint(
|
| 526 |
+
position[axis[fn]][1][0], position[axis[fn]][1][1]
|
| 527 |
+
)
|
| 528 |
+
if -160 < dx < 160 and -160 < dy < 160:
|
| 529 |
+
if random.random() < 0.5:
|
| 530 |
+
dx = 160 if dx > 0 else -160
|
| 531 |
+
else:
|
| 532 |
+
dy = 160 if dy > 0 else -160
|
| 533 |
+
|
| 534 |
+
flare_merge = self.transform_flare(flare_merge)
|
| 535 |
+
flare_merge = TF.affine(
|
| 536 |
+
flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
|
| 537 |
+
)
|
| 538 |
+
flare_merge = TF.center_crop(
|
| 539 |
+
flare_merge, (self.img_size, self.img_size)
|
| 540 |
+
)
|
| 541 |
+
else:
|
| 542 |
+
flare_img = self.transform_flare(flare_img)
|
| 543 |
+
|
| 544 |
+
# change color
|
| 545 |
+
if self.light_flag:
|
| 546 |
+
flare_img, light_img = torch.split(flare_merge, 3, dim=0)
|
| 547 |
+
else:
|
| 548 |
+
flare_img = color_jitter(flare_img)
|
| 549 |
+
|
| 550 |
+
flare_imgs.append(flare_img)
|
| 551 |
+
if self.light_flag:
|
| 552 |
+
light_img = torch.clamp(light_img, min=0, max=1)
|
| 553 |
+
light_imgs.append(light_img)
|
| 554 |
+
|
| 555 |
+
flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
|
| 556 |
+
flare_img = torch.clamp(flare_img, min=0, max=1)
|
| 557 |
+
|
| 558 |
+
# flare blur
|
| 559 |
+
blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
|
| 560 |
+
flare_img = blur_transform(flare_img)
|
| 561 |
+
flare_img = torch.clamp(flare_img, min=0, max=1)
|
| 562 |
+
|
| 563 |
+
merge_img = torch.clamp(flare_img + base_img, min=0, max=1)
|
| 564 |
+
|
| 565 |
+
if self.light_flag:
|
| 566 |
+
light_img = torch.sum(torch.stack(light_imgs), dim=0)
|
| 567 |
+
light_img = torch.clamp(light_img, min=0, max=1)
|
| 568 |
+
base_img = torch.clamp(base_img + light_img, min=0, max=1)
|
| 569 |
+
flare_img = torch.clamp(flare_img - light_img, min=0, max=1)
|
| 570 |
+
|
| 571 |
+
flare_mask = None
|
| 572 |
+
if self.mask_type == None:
|
| 573 |
+
return {
|
| 574 |
+
"gt": adjust_gamma_reverse(base_img),
|
| 575 |
+
"flare": adjust_gamma_reverse(flare_img),
|
| 576 |
+
"lq": adjust_gamma_reverse(merge_img),
|
| 577 |
+
"gamma": gamma,
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
elif self.mask_type == "light":
|
| 581 |
+
one = torch.ones_like(base_img)
|
| 582 |
+
zero = torch.zeros_like(base_img)
|
| 583 |
+
threshold_value = 0.01
|
| 584 |
+
|
| 585 |
+
# flare_masks_list = []
|
| 586 |
+
XYRs = torch.zeros((self.num_lights, 4))
|
| 587 |
+
for i in range(flare_nums):
|
| 588 |
+
luminance = (
|
| 589 |
+
0.3 * light_imgs[i][0]
|
| 590 |
+
+ 0.59 * light_imgs[i][1]
|
| 591 |
+
+ 0.11 * light_imgs[i][2]
|
| 592 |
+
)
|
| 593 |
+
flare_mask = torch.where(luminance > threshold_value, one, zero)
|
| 594 |
+
|
| 595 |
+
light_source_cond = (flare_mask.sum(dim=0) > 0).float()
|
| 596 |
+
|
| 597 |
+
x, y, r = self.find_circle_properties(light_source_cond, i)
|
| 598 |
+
XYRs[i] = torch.tensor([x, y, r, 1.0])
|
| 599 |
+
|
| 600 |
+
XYRs[:, :3] = XYRs[:, :3] / self.img_size
|
| 601 |
+
|
| 602 |
+
luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
|
| 603 |
+
flare_mask = torch.where(luminance > threshold_value, one, zero)
|
| 604 |
+
|
| 605 |
+
light_source_cond = (flare_mask.sum(dim=0) > 0).float()
|
| 606 |
+
|
| 607 |
+
light_source_cond = torch.repeat_interleave(
|
| 608 |
+
light_source_cond[None, ...], 1, dim=0
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
# box = self.crop(light_source_cond[0])
|
| 612 |
+
box = self.lightsource_crop(light_source_cond[0])
|
| 613 |
+
|
| 614 |
+
# random int between 0 ~ 15
|
| 615 |
+
margin = random.randint(0, 15)
|
| 616 |
+
if box[0] - margin >= 0:
|
| 617 |
+
box[0] -= margin
|
| 618 |
+
if box[1] + margin < self.img_size:
|
| 619 |
+
box[1] += margin
|
| 620 |
+
if box[2] - margin >= 0:
|
| 621 |
+
box[2] -= margin
|
| 622 |
+
if box[3] + margin < self.img_size:
|
| 623 |
+
box[3] += margin
|
| 624 |
+
|
| 625 |
+
top, bottom, left, right = box[2], box[3], box[0], box[1]
|
| 626 |
+
|
| 627 |
+
merge_img = adjust_gamma_reverse(merge_img)
|
| 628 |
+
|
| 629 |
+
cropped_mask = torch.full(
|
| 630 |
+
(self.img_size, self.img_size), True, dtype=torch.bool
|
| 631 |
+
)
|
| 632 |
+
cropped_mask[top : bottom + 1, left : right + 1] = False
|
| 633 |
+
channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)
|
| 634 |
+
|
| 635 |
+
masked_img = merge_img * (1 - channel3_mask.float())
|
| 636 |
+
masked_img[channel3_mask] = 0.5
|
| 637 |
+
|
| 638 |
+
return {
|
| 639 |
+
# add
|
| 640 |
+
"input": self.normalize(masked_img), # normalize to [-1, 1]
|
| 641 |
+
"light_masks": light_source_cond,
|
| 642 |
+
"xyrs": XYRs,
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
def find_circle_properties(self, mask, i, method="minEnclosingCircle"):
|
| 646 |
+
"""
|
| 647 |
+
Find the properties of the light source circle in the mask.
|
| 648 |
+
"""
|
| 649 |
+
|
| 650 |
+
_mask = (mask.numpy() * 255).astype(np.uint8)
|
| 651 |
+
_, binary_mask = cv2.threshold(_mask, 127, 255, cv2.THRESH_BINARY)
|
| 652 |
+
contours, _ = cv2.findContours(
|
| 653 |
+
binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
if len(contours) == 0:
|
| 657 |
+
return 0.0, 0.0, 0.0
|
| 658 |
+
|
| 659 |
+
largest_contour = max(contours, key=cv2.contourArea)
|
| 660 |
+
|
| 661 |
+
if method == "minEnclosingCircle":
|
| 662 |
+
(x, y), radius = cv2.minEnclosingCircle(largest_contour)
|
| 663 |
+
|
| 664 |
+
elif method == "area_based":
|
| 665 |
+
M = cv2.moments(largest_contour)
|
| 666 |
+
if M["m00"] == 0: # if the contour is too small
|
| 667 |
+
return 0.0, 0.0, 0.0
|
| 668 |
+
|
| 669 |
+
x = M["m10"] / M["m00"]
|
| 670 |
+
y = M["m01"] / M["m00"]
|
| 671 |
+
area = cv2.contourArea(largest_contour)
|
| 672 |
+
radius = np.sqrt(area / np.pi)
|
| 673 |
+
|
| 674 |
+
# # draw
|
| 675 |
+
# cv2.circle(_mask, (int(x), int(y)), int(radius), 128, 2)
|
| 676 |
+
# cv2.imwrite(f"mask_{i}.png", _mask)
|
| 677 |
+
|
| 678 |
+
return x, y, radius
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class Lightsource_3Maps_Loader(Lightsource_Regress_Loader):
|
| 682 |
+
def __init__(self, config, num_lights=4):
|
| 683 |
+
Lightsource_Regress_Loader.__init__(self, config, num_lights=num_lights)
|
| 684 |
+
|
| 685 |
+
def build_gt_maps(self, coords, radii, H, W, kappa=0.4):
|
| 686 |
+
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
|
| 687 |
+
prob_gt = torch.zeros((H, W))
|
| 688 |
+
rad_gt = torch.zeros((H, W))
|
| 689 |
+
|
| 690 |
+
eps = 1e-6
|
| 691 |
+
for x_i, y_i, r_i in zip(coords[:, 0], coords[:, 1], radii):
|
| 692 |
+
if r_i < 1.0:
|
| 693 |
+
continue
|
| 694 |
+
|
| 695 |
+
sigma = kappa * r_i
|
| 696 |
+
g = torch.exp(-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * sigma**2))
|
| 697 |
+
g_prime = torch.exp(
|
| 698 |
+
-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * (sigma / 1.414) ** 2)
|
| 699 |
+
)
|
| 700 |
+
prob_gt = torch.maximum(prob_gt, g)
|
| 701 |
+
rad_gt = torch.maximum(rad_gt, g_prime * r_i)
|
| 702 |
+
|
| 703 |
+
rad_gt = rad_gt / (prob_gt + eps)
|
| 704 |
+
return prob_gt, rad_gt
|
| 705 |
+
|
| 706 |
+
def __getitem__(self, index):
|
| 707 |
+
# load base image
|
| 708 |
+
img_path = self.data_list[index]
|
| 709 |
+
base_img = Image.open(img_path).convert("RGB")
|
| 710 |
+
|
| 711 |
+
gamma = np.random.uniform(1.8, 2.2)
|
| 712 |
+
to_tensor = transforms.ToTensor()
|
| 713 |
+
adjust_gamma = RandomGammaCorrection(gamma)
|
| 714 |
+
adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
|
| 715 |
+
color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
|
| 716 |
+
|
| 717 |
+
base_img = to_tensor(base_img)
|
| 718 |
+
base_img = adjust_gamma(base_img)
|
| 719 |
+
if self.transform_base is not None:
|
| 720 |
+
        base_img = self.transform_base(base_img)

        sigma_chi = 0.01 * np.random.chisquare(df=1)
        base_img = Normal(base_img, sigma_chi).sample()
        gain = np.random.uniform(0.5, 1.2)
        base_img = gain * base_img
        base_img = torch.clamp(base_img, min=0, max=1)

        # init flare and light imgs
        flare_imgs = []
        light_imgs = []
        position = [
            [[-224, 0], [-224, 0]],
            [[-224, 0], [0, 224]],
            [[0, 224], [-224, 0]],
            [[0, 224], [0, 224]],
        ]
        axis = random.sample(range(4), 4)
        axis[-1] = axis[0]
        flare_nums = int(
            random.random() * self.num_lights + 1
        )  # random number of flares from 1 to 4

        for fn in range(flare_nums):
            choice_dataset = random.choices(
                [i for i in range(len(self.flare_list))], self.data_ratio
            )[0]
            choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)

            flare_path = self.flare_list[choice_dataset][choice_index]
            flare_img = Image.open(flare_path).convert("RGB")
            flare_img = to_tensor(flare_img)
            flare_img = adjust_gamma(flare_img)
            flare_img = remove_background(flare_img)

            if self.light_flag:
                light_path = self.light_list[choice_dataset][choice_index]
                light_img = Image.open(light_path).convert("RGB")
                light_img = to_tensor(light_img)
                light_img = adjust_gamma(light_img)

            if self.transform_flare is not None:
                if self.light_flag:
                    flare_merge = torch.cat((flare_img, light_img), dim=0)

                    if flare_nums == 1:
                        dx = random.randint(-224, 224)
                        dy = random.randint(-224, 224)
                    else:
                        dx = random.randint(
                            position[axis[fn]][0][0], position[axis[fn]][0][1]
                        )
                        dy = random.randint(
                            position[axis[fn]][1][0], position[axis[fn]][1][1]
                        )
                    if -160 < dx < 160 and -160 < dy < 160:
                        if random.random() < 0.5:
                            dx = 160 if dx > 0 else -160
                        else:
                            dy = 160 if dy > 0 else -160

                    flare_merge = self.transform_flare(flare_merge)
                    flare_merge = TF.affine(
                        flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
                    )
                    flare_merge = TF.center_crop(
                        flare_merge, (self.img_size, self.img_size)
                    )
                else:
                    flare_img = self.transform_flare(flare_img)

            # change color
            if self.light_flag:
                flare_img, light_img = torch.split(flare_merge, 3, dim=0)
            else:
                flare_img = color_jitter(flare_img)

            flare_imgs.append(flare_img)
            if self.light_flag:
                light_img = torch.clamp(light_img, min=0, max=1)
                light_imgs.append(light_img)

        flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
        flare_img = torch.clamp(flare_img, min=0, max=1)

        # flare blur
        blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
        flare_img = blur_transform(flare_img)
        flare_img = torch.clamp(flare_img, min=0, max=1)

        merge_img = torch.clamp(flare_img + base_img, min=0, max=1)

        if self.light_flag:
            light_img = torch.sum(torch.stack(light_imgs), dim=0)
            light_img = torch.clamp(light_img, min=0, max=1)
            base_img = torch.clamp(base_img + light_img, min=0, max=1)
            flare_img = torch.clamp(flare_img - light_img, min=0, max=1)

        flare_mask = None
        if self.mask_type is None:
            return {
                "gt": adjust_gamma_reverse(base_img),
                "flare": adjust_gamma_reverse(flare_img),
                "lq": adjust_gamma_reverse(merge_img),
                "gamma": gamma,
            }

        elif self.mask_type == "light":
            one = torch.ones_like(base_img)
            zero = torch.zeros_like(base_img)
            threshold_value = 0.01

            # flare_masks_list = []
            XYRs = torch.zeros((self.num_lights, 4))
            for i in range(flare_nums):
                luminance = (
                    0.3 * light_imgs[i][0]
                    + 0.59 * light_imgs[i][1]
                    + 0.11 * light_imgs[i][2]
                )
                flare_mask = torch.where(luminance > threshold_value, one, zero)

                light_source_cond = (flare_mask.sum(dim=0) > 0).float()

                x, y, r = self.find_circle_properties(light_source_cond, i)
                XYRs[i] = torch.tensor([x, y, r, 1.0])

            gt_prob, gt_rad = self.build_gt_maps(
                XYRs[:, :2], XYRs[:, 2], self.img_size, self.img_size
            )
            gt_prob = gt_prob.unsqueeze(0)  # shape: (1, H, W)
            gt_rad = gt_rad.unsqueeze(0)
            gt_rad /= self.img_size
            gt_maps = torch.cat((gt_prob, gt_rad), dim=0)  # shape: (2, H, W)

            XYRs[:, :3] = XYRs[:, :3] / self.img_size

            luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
            flare_mask = torch.where(luminance > threshold_value, one, zero)

            light_source_cond = (flare_mask.sum(dim=0) > 0).float()

            light_source_cond = torch.repeat_interleave(
                light_source_cond[None, ...], 1, dim=0
            )

            # box = self.crop(light_source_cond[0])
            box = self.lightsource_crop(light_source_cond[0])

            # random int between 0 ~ 15
            margin = random.randint(0, 15)
            if box[0] - margin >= 0:
                box[0] -= margin
            if box[1] + margin < self.img_size:
                box[1] += margin
            if box[2] - margin >= 0:
                box[2] -= margin
            if box[3] + margin < self.img_size:
                box[3] += margin

            top, bottom, left, right = box[2], box[3], box[0], box[1]

            merge_img = adjust_gamma_reverse(merge_img)

            cropped_mask = torch.full(
                (self.img_size, self.img_size), True, dtype=torch.bool
            )
            cropped_mask[top : bottom + 1, left : right + 1] = False
            channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)

            masked_img = merge_img * (1 - channel3_mask.float())
            masked_img[channel3_mask] = 0.5

            return {
                # add
                "input": self.normalize(masked_img),  # normalize to [-1, 1]
                "light_masks": light_source_cond,
                "xyrs": gt_maps,
            }


class TestImageLoader(Dataset):
    def __init__(
        self,
        dataroot_gt,
        dataroot_input,
        dataroot_mask,
        margin=0,
        img_size=512,
        noise_matching=False,
    ):
        super(TestImageLoader, self).__init__()
        self.gt_folder = dataroot_gt
        self.input_folder = dataroot_input
        self.mask_folder = dataroot_mask
        self.paths = glod_from_folder(
            [self.input_folder, self.gt_folder, self.mask_folder],
            ["input", "gt", "mask"],
        )

        self.margin = margin
        self.img_size = img_size
        self.noise_matching = noise_matching

    def __len__(self):
        return len(self.paths["input"])

    def __getitem__(self, index):
        img_name = self.paths["input"][index].split("/")[-1]
        num = img_name.split("_")[1].split(".")[0]

        # preprocess light source mask
        light_mask = np.array(Image.open(self.paths["mask"][index]))
        tmp_light_mask = np.zeros_like(light_mask[:, :, 0])
        tmp_light_mask[light_mask[:, :, 2] > 0] = 255
        cond = (light_mask[:, :, 0] > 0) & (light_mask[:, :, 1] > 0)
        tmp_light_mask[cond] = 0
        light_mask = tmp_light_mask

        # img for controlnet input
        control_img = np.repeat(light_mask[:, :, None], 3, axis=2)

        # crop region
        box = self.lightsource_crop(light_mask)

        if box[0] - self.margin >= 0:
            box[0] -= self.margin
        if box[1] + self.margin < self.img_size:
            box[1] += self.margin
        if box[2] - self.margin >= 0:
            box[2] -= self.margin
        if box[3] + self.margin < self.img_size:
            box[3] += self.margin

        # input image to be outpainted
        input_img = np.array(Image.open(self.paths["input"][index]))
        cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
        cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
        input_img[cropped_region == 1] = 128

        # image for blip
        blip_img = input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :]

        # noise matching
        input_img_matching = None
        if self.noise_matching:
            np_src_img = input_img / 255.0
            np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
                np.float32
            )
            matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
            input_img_matching = (matched_noise * 255).astype(np.uint8)

        # mask image
        mask_img = (cropped_region * 255).astype(np.uint8)

        return {
            "blip_img": blip_img,
            "input_img": Image.fromarray(input_img),
            "input_img_matching": (
                Image.fromarray(input_img_matching)
                if input_img_matching is not None
                else Image.fromarray(input_img)
            ),
            "mask_img": Image.fromarray(mask_img),
            "control_img": Image.fromarray(control_img),
            "box": box,
            "output_name": "output_" + num + ".png",
        }

    def lightsource_crop(self, matrix):
        """Find the largest axis-aligned rectangle of 0s (the light-free region)
        in a binary mask and return it as [left, right, top, bottom]."""

        def largestRectangleArea(heights):
            heights.append(0)
            stack = [-1]
            max_area = 0
            max_rectangle = (0, 0, 0, 0)  # (area, left, right, height)
            for i in range(len(heights)):
                while heights[i] < heights[stack[-1]]:
                    h = heights[stack.pop()]
                    w = i - stack[-1] - 1
                    area = h * w
                    if area > max_area:
                        max_area = area
                        max_rectangle = (area, stack[-1] + 1, i - 1, h)
                stack.append(i)
            heights.pop()
            return max_rectangle

        max_area = 0
        max_rectangle = [0, 0, 0, 0]  # (left, right, top, bottom)
        heights = [0] * len(matrix[0])
        for row in range(len(matrix)):
            for i, val in enumerate(matrix[row]):
                heights[i] = heights[i] + 1 if val == 0 else 0

            area, left, right, height = largestRectangleArea(heights)
            if area > max_area:
                max_area = area
                max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]

        return list(max_rectangle)

    # this function is taken from https://github.com/parlance-zz/g-diffuser-bot
    def get_matched_noise(
        self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
    ):
        # helper fft routines that keep ortho normalization and auto-shift before and after fft
        def _fft2(data):
            if data.ndim > 2:  # has channels
                out_fft = np.zeros(
                    (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
                )
                for c in range(data.shape[2]):
                    c_data = data[:, :, c]
                    out_fft[:, :, c] = np.fft.fft2(
                        np.fft.fftshift(c_data), norm="ortho"
                    )
                    out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
            else:  # one channel
                out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
                out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
                out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])

            return out_fft

        def _ifft2(data):
            if data.ndim > 2:  # has channels
                out_ifft = np.zeros(
                    (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
                )
                for c in range(data.shape[2]):
                    c_data = data[:, :, c]
                    out_ifft[:, :, c] = np.fft.ifft2(
                        np.fft.fftshift(c_data), norm="ortho"
                    )
                    out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
            else:  # one channel
                out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
                out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
                out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])

            return out_ifft

        def _get_gaussian_window(width, height, std=3.14, mode=0):
            window_scale_x = float(width / min(width, height))
            window_scale_y = float(height / min(width, height))

            window = np.zeros((width, height))
            x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
            for y in range(height):
                fy = (y / height * 2.0 - 1.0) * window_scale_y
                if mode == 0:
                    window[:, y] = np.exp(-(x**2 + fy**2) * std)
                else:
                    window[:, y] = (1 / ((x**2 + 1.0) * (fy**2 + 1.0))) ** (
                        std / 3.14
                    )  # hey wait a minute that's not gaussian

            return window

        def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
            np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
            if hardness != 1.0:
                hardened = np_mask_grey[:] ** hardness
            else:
                hardened = np_mask_grey[:]
            for c in range(3):
                np_mask_rgb[:, :, c] = hardened[:]
            return np_mask_rgb

        width = _np_src_image.shape[0]
        height = _np_src_image.shape[1]
        num_channels = _np_src_image.shape[2]

        _np_src_image[:] * (1.0 - np_mask_rgb)
        np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
        img_mask = np_mask_grey > 1e-6
        ref_mask = np_mask_grey < 1e-3

        windowed_image = _np_src_image * (1.0 - _get_masked_window_rgb(np_mask_grey))
        windowed_image /= np.max(windowed_image)
        windowed_image += (
            np.average(_np_src_image) * np_mask_rgb
        )  # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color

        src_fft = _fft2(windowed_image)  # get feature statistics from masked src img
        src_dist = np.absolute(src_fft)
        src_phase = src_fft / src_dist

        # create a generator with a static seed to make outpainting deterministic / only follow global seed
        rng = np.random.default_rng(0)

        noise_window = _get_gaussian_window(
            width, height, mode=1
        )  # start with simple gaussian noise
        noise_rgb = rng.random((width, height, num_channels))
        noise_grey = np.sum(noise_rgb, axis=2) / 3.0
        noise_rgb *= color_variation  # the colorfulness of the starting noise is blended to greyscale with a parameter
        for c in range(num_channels):
            noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey

        noise_fft = _fft2(noise_rgb)
        for c in range(num_channels):
            noise_fft[:, :, c] *= noise_window
        noise_rgb = np.real(_ifft2(noise_fft))
        shaped_noise_fft = _fft2(noise_rgb)
        shaped_noise_fft[:, :, :] = (
            np.absolute(shaped_noise_fft[:, :, :]) ** 2
            * (src_dist**noise_q)
            * src_phase
        )  # perform the actual shaping

        brightness_variation = 0.0  # color_variation # todo: temporarily tying brightness variation to color variation for now
        contrast_adjusted_np_src = (
            _np_src_image[:] * (brightness_variation + 1.0) - brightness_variation * 2.0
        )

        # scikit-image is used for histogram matching, very convenient!
        shaped_noise = np.real(_ifft2(shaped_noise_fft))
        shaped_noise -= np.min(shaped_noise)
        shaped_noise /= np.max(shaped_noise)
        shaped_noise[img_mask, :] = skimage.exposure.match_histograms(
            shaped_noise[img_mask, :] ** 1.0,
            contrast_adjusted_np_src[ref_mask, :],
            channel_axis=1,
        )
        shaped_noise = (
            _np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
        )

        matched_noise = shaped_noise[:]

        return np.clip(matched_noise, 0.0, 1.0)


class CustomImageLoader(Dataset):
    def __init__(
        self, dataroot_input, left_outpaint, right_outpaint, up_outpaint, down_outpaint
    ):
        self.dataroot_input = dataroot_input
        self.left_outpaint = left_outpaint
        self.right_outpaint = right_outpaint
        self.up_outpaint = up_outpaint
        self.down_outpaint = down_outpaint

        self.H = 512 - (up_outpaint + down_outpaint)
        self.W = 512 - (left_outpaint + right_outpaint)
        self.img_size = 512

        self.img_lists = [
            os.path.join(dataroot_input, f)
            for f in os.listdir(dataroot_input)
            if f.endswith(".png") or f.endswith(".jpg")
        ]

    def __len__(self):
        return len(self.img_lists)

    def __getitem__(self, index):
        img_name = self.img_lists[index].split("/")[-1]

        # crop region
        box = [
            self.left_outpaint,
            511 - self.right_outpaint,
            self.up_outpaint,
            511 - self.down_outpaint,
        ]  # [left, right, top, bottom]

        # box = self.lightsource_crop(light_mask)
        # if box[0] - self.margin >= 0:
        #     box[0] -= self.margin
        # if box[1] + self.margin < self.img_size:
        #     box[1] += self.margin
        # if box[2] - self.margin >= 0:
        #     box[2] -= self.margin
        # if box[3] + self.margin < self.img_size:
        #     box[3] += self.margin

        # input image to be outpainted
        input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        paste_img = np.array(
            Image.open(self.img_lists[index]).resize((self.W, self.H), Image.LANCZOS)
        )
        input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
        cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
        cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
        input_img[cropped_region == 1] = 128

        # image for blip
        blip_img = np.array(Image.open(self.img_lists[index]))

        # # noise matching
        # input_img_matching = None
        # if self.noise_matching:
        #     np_src_img = input_img / 255.0
        #     np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
        #         np.float32
        #     )
        #     matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
        #     input_img_matching = (matched_noise * 255).astype(np.uint8)

        # mask image
        mask_img = (cropped_region * 255).astype(np.uint8)

        return {
            "blip_img": blip_img,
            "input_img": Image.fromarray(input_img),
            # "input_img": (
            #     Image.fromarray(input_img_matching)
            #     if input_img_matching is not None
            #     else Image.fromarray(input_img)
            # ),
            "mask_img": Image.fromarray(mask_img),
            "box": box,
            "output_name": img_name,
        }


class HFCustomImageLoader(Dataset):
    def __init__(
        self, img_data, left_outpaint=64, right_outpaint=64, up_outpaint=64, down_outpaint=64
    ):
        self.left_outpaint = left_outpaint
        self.right_outpaint = right_outpaint
        self.up_outpaint = up_outpaint
        self.down_outpaint = down_outpaint

        self.H = 512 - (up_outpaint + down_outpaint)
        self.W = 512 - (left_outpaint + right_outpaint)
        self.img_size = 512

        self.img_lists = [img_data]

    def __len__(self):
        return len(self.img_lists)

    def __getitem__(self, index):
        # img_name = self.img_lists[index].split("/")[-1]

        # crop region
        box = [
            self.left_outpaint,
            511 - self.right_outpaint,
            self.up_outpaint,
            511 - self.down_outpaint,
        ]  # [left, right, top, bottom]

        # input image to be outpainted
        input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        paste_img = np.array(self.img_lists[index].resize((self.W, self.H), Image.LANCZOS))
        input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
        cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
        cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
        input_img[cropped_region == 1] = 128

        # image for blip
        blip_img = np.array(self.img_lists[index])

        # # noise matching
        # input_img_matching = None
        # if self.noise_matching:
        #     np_src_img = input_img / 255.0
        #     np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
        #         np.float32
        #     )
        #     matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
        #     input_img_matching = (matched_noise * 255).astype(np.uint8)

        # mask image
        mask_img = (cropped_region * 255).astype(np.uint8)

        return {
            "blip_img": blip_img,
            "input_img": Image.fromarray(input_img),
            "mask_img": Image.fromarray(mask_img),
            "box": box,
        }


if __name__ == "__main__":
    pass
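Note: `lightsource_crop` above drives the outpainting crop by scanning the mask row by row, keeping a histogram of consecutive zero pixels per column, and taking the largest rectangle found by the stack-based histogram routine. A minimal sketch of its behavior on a toy mask (not part of this commit; the toy values are made up for illustration):

import numpy as np

mask = np.zeros((8, 8), dtype=np.uint8)
mask[0:3, 5:8] = 1  # pretend a light source sits in the top-right corner

loader = TestImageLoader.__new__(TestImageLoader)  # the helper uses no instance state
left, right, top, bottom = loader.lightsource_crop(mask)
print(left, right, top, bottom)  # -> 0 4 0 7: the largest light-free rectangle
# everything outside this box is later filled with 128 and regenerated by the
# outpainting pipeline, so the light source ends up in the generated region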
utils/loss.py
ADDED
@@ -0,0 +1,80 @@
import torch
import torch.nn.functional as F
from torch import nn
from scipy.optimize import linear_sum_assignment


class uncertainty_light_pos_loss(nn.Module):
    def __init__(self):
        super(uncertainty_light_pos_loss, self).__init__()
        self.log_var_xyr = nn.Parameter(torch.tensor(1.0, requires_grad=True))
        self.log_var_p = nn.Parameter(torch.tensor(1.0, requires_grad=True))

    def forward(self, logits, targets):
        B, N, P = logits.shape  # (B, 4, 4)

        position_loss = 0
        confidence_loss = 0

        w_xyr = 0.5 / (self.log_var_xyr**2)  # uncertainty weight for position loss
        w_p = 0.5 / (self.log_var_p**2)  # uncertainty weight for confidence loss
        weights = torch.tensor([1, 1, 2], device=logits.device)  # weights for x, y, r

        for b in range(B):
            pred_xyr = logits[b, :, :3]  # (N, 3)
            pred_p = logits[b, :, 3]  # (N,)

            gt_xyr = targets[b, :, :3]  # (N, 3)
            gt_p = targets[b, :, 3]  # (N,)

            cost_matrix = torch.cdist(gt_xyr, pred_xyr, p=2)  # (N, N)

            with torch.no_grad():
                row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().numpy())

            matched_pred_xyr = pred_xyr[col_ind]
            matched_gt_xyr = gt_xyr[row_ind]
            matched_pred_p = pred_p[col_ind]
            matched_gt_p = gt_p[row_ind]

            valid_mask = matched_gt_p > 0
            valid_cnt = valid_mask.sum().clamp(min=1)

            xyr_loss = (
                F.smooth_l1_loss(
                    matched_pred_xyr[valid_mask],
                    matched_gt_xyr[valid_mask],
                    reduction="none",
                )
                * weights
            ).sum()

            p_loss = F.binary_cross_entropy(
                matched_pred_p, matched_gt_p, reduction="mean"
            )

            position_loss += xyr_loss / valid_cnt
            confidence_loss += p_loss

        position_loss = w_xyr * (position_loss / B) + torch.log(1 + self.log_var_xyr**2)
        confidence_loss = w_p * (confidence_loss / B) + torch.log(1 + self.log_var_p**2)

        return position_loss, confidence_loss


class unet_3maps_loss(nn.Module):
    def __init__(self):
        super(unet_3maps_loss, self).__init__()

    def forward(self, pred_prob, pred_rad, prob_gt, rad_gt):
        focal = nn.BCELoss()
        L_prob = focal(pred_prob, prob_gt)

        pos_mask = prob_gt > 0.5
        L_rad = (
            nn.functional.smooth_l1_loss(pred_rad[pos_mask], rad_gt[pos_mask])
            if pos_mask.any()
            else pred_rad.sum() * 0
        )

        return L_prob + 10.0 * L_rad, L_prob, L_rad
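Note: a quick sketch of how `uncertainty_light_pos_loss` is meant to be driven (not from the repo; the shapes follow the `(B, 4, 4)` comment in `forward`, and the dummy values are illustrative only). Both predictions and targets carry `(x, y, r, p)` per candidate light, with `p` already in `[0, 1]` so the BCE term is valid:

import torch

criterion = uncertainty_light_pos_loss()
pred = torch.rand(2, 4, 4)                      # (B, N, 4): candidate (x, y, r, p), p in [0, 1]
gt = torch.zeros(2, 4, 4)
gt[:, 0] = torch.tensor([0.5, 0.5, 0.1, 1.0])   # one annotated light per sample, rest padded with p = 0
pos_loss, conf_loss = criterion(pred, gt)       # Hungarian matching pairs predictions with GT lights
total = pos_loss + conf_loss                    # both terms already carry the learned log-variance weights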
utils/utils.py
ADDED
@@ -0,0 +1,311 @@
import os
import cv2
import numpy as np
import skimage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from skimage.draw import disk
from skimage import morphology
from collections import OrderedDict


def load_mfdnet_checkpoint(model, weights):
    checkpoint = torch.load(weights, map_location=lambda storage, loc: storage.cuda(0))
    new_state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if key.startswith("module"):
            name = key[7:]
        else:
            name = key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict)


def adjust_gamma(image: torch.Tensor, gamma):
    # image is in shape of [B,C,H,W] and gamma is in shape [B]
    gamma = gamma.float().cuda()
    gamma_tensor = torch.ones_like(image)
    gamma_tensor = gamma.view(-1, 1, 1, 1) * gamma_tensor
    image = torch.pow(image, gamma_tensor)
    out = torch.clamp(image, 0.0, 1.0)
    return out


def adjust_gamma_reverse(image: torch.Tensor, gamma):
    # gamma=torch.Tensor([gamma]).cuda()
    gamma = 1 / gamma.float().cuda()
    gamma_tensor = torch.ones_like(image)
    gamma_tensor = gamma.view(-1, 1, 1, 1) * gamma_tensor
    image = torch.pow(image, gamma_tensor)
    out = torch.clamp(image, 0.0, 1.0)
    return out


def predict_flare_from_6_channel(input_tensor, gamma):
    # the input is a tensor in [B,C,H,W], the C here is 6

    deflare_img = input_tensor[:, :3, :, :]
    flare_img_predicted = input_tensor[:, 3:, :, :]

    merge_img_predicted_linear = adjust_gamma(deflare_img, gamma) + adjust_gamma(
        flare_img_predicted, gamma
    )
    merge_img_predicted = adjust_gamma_reverse(
        torch.clamp(merge_img_predicted_linear, 1e-7, 1.0), gamma
    )
    return deflare_img, flare_img_predicted, merge_img_predicted


def predict_flare_from_3_channel(
    input_tensor, flare_mask, base_img, flare_img, merge_img, gamma
):
    # the input is a tensor in [B,C,H,W], the C here is 3

    input_tensor_linear = adjust_gamma(input_tensor, gamma)
    merge_tensor_linear = adjust_gamma(merge_img, gamma)
    flare_img_predicted = adjust_gamma_reverse(
        torch.clamp(merge_tensor_linear - input_tensor_linear, 1e-7, 1.0), gamma
    )

    masked_deflare_img = input_tensor * (1 - flare_mask) + base_img * flare_mask
    masked_flare_img_predicted = (
        flare_img_predicted * (1 - flare_mask) + flare_img * flare_mask
    )

    return masked_deflare_img, masked_flare_img_predicted


def get_highlight_mask(image, threshold=0.99, luminance_mode=False):
    """Get the area close to the exposure
    Args:
        image: the image tensor in [B,C,H,W]. For inference, B is set as 1.
        threshold: the threshold of luminance/greyscale of exposure region
        luminance_mode: use luminance or greyscale
    Return:
        Binary image in [B,H,W]
    """
    if luminance_mode:
        # 3 channels in RGB
        luminance = (
            0.2126 * image[:, 0, :, :]
            + 0.7152 * image[:, 1, :, :]
            + 0.0722 * image[:, 2, :, :]
        )
        binary_mask = luminance > threshold
    else:
        binary_mask = image.mean(dim=1, keepdim=True) > threshold
    binary_mask = binary_mask.to(image.dtype)
    return binary_mask


def refine_mask(mask, morph_size=0.01):
    """Refines a mask by applying morphological operations.
    Args:
        mask: A float array of shape [H, W]
        morph_size: Size of the morphological kernel relative to the long side of
            the image.

    Returns:
        Refined mask of shape [H, W].
    """
    mask_size = max(np.shape(mask))
    kernel_radius = 0.5 * morph_size * mask_size
    kernel = morphology.disk(np.ceil(kernel_radius))
    opened = morphology.binary_opening(mask, kernel)
    return opened


def _create_disk_kernel(kernel_size):
    _EPS = 1e-7
    x = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(x, x)
    rr = np.sqrt(xx**2 + yy**2)
    kernel = np.float32(rr <= np.max(x)) + _EPS
    kernel = kernel / np.sum(kernel)
    return kernel


def blend_light_source(input_scene, pred_scene, threshold=0.99, luminance_mode=False):
    binary_mask = (
        get_highlight_mask(
            input_scene, threshold=threshold, luminance_mode=luminance_mode
        )
        > 0.5
    ).to("cpu", torch.bool)
    binary_mask = binary_mask.squeeze()  # (h, w)
    binary_mask = binary_mask.numpy()
    binary_mask = refine_mask(binary_mask)

    labeled = skimage.measure.label(binary_mask)
    properties = skimage.measure.regionprops(labeled)
    max_diameter = 0
    for p in properties:
        # The diameter of a circle with the same area as the region.
        max_diameter = max(max_diameter, p["equivalent_diameter"])

    mask = np.float32(binary_mask)
    kernel_size = round(1.5 * max_diameter)  # default is 1.5
    if kernel_size > 0:
        kernel = _create_disk_kernel(kernel_size)
        mask = cv2.filter2D(mask, -1, kernel)
        mask = np.clip(mask * 3.0, 0.0, 1.0)
        mask_rgb = np.stack([mask] * 3, axis=0)

        mask_rgb = torch.from_numpy(mask_rgb).to(input_scene.device, torch.float32)
        blend = input_scene * mask_rgb + pred_scene * (1 - mask_rgb)
    else:
        blend = pred_scene
    return blend


def blend_with_alpha(result, input_img, box, blur_size=31):
    """
    Apply alpha blending to paste the specified box region from input_img onto the result image
    to reduce boundary artifacts and make the blending more natural.

    Args:
        result (np.array): inpainting generated image
        input_img (np.array): original image
        box (tuple): (x_min, x_max, y_min, y_max) representing the paste-back region from the original image
        blur_size (int): blur range for the mask, larger values create smoother transitions (recommended 15~50)

    Returns:
        np.array: image after alpha blending
    """

    x_min, x_max, y_min, y_max = box

    # alpha mask
    mask = np.zeros_like(result, dtype=np.float32)
    mask[y_min : y_max + 1, x_min : x_max + 1] = 1.0

    # gaussian blur
    mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)

    # alpha blending
    blended = (mask * input_img + (1 - mask) * result).astype(np.uint8)

    return blended


def IoU(pred, target):
    assert pred.shape == target.shape, "Prediction and target must have the same shape."

    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()

    if union == 0:
        return 1.0 if intersection == 0 else 0.0

    return intersection / union


def mean_IoU(y_true, y_pred, num_classes):
    """
    Calculate the mean Intersection over Union (mIoU) score.

    Args:
        y_true (np.ndarray): Ground truth labels (integer class values).
        y_pred (np.ndarray): Predicted labels (integer class values).
        num_classes (int): Number of classes.

    Returns:
        float: The mean IoU score across all classes.
    """
    iou_scores = []

    for cls in range(num_classes):
        # Create binary masks for the current class
        true_mask = y_true == cls
        pred_mask = y_pred == cls

        # Calculate intersection and union
        intersection = np.logical_and(true_mask, pred_mask)
        union = np.logical_or(true_mask, pred_mask)

        # Compute IoU for the current class
        if np.sum(union) == 0:
            # Handle edge case: no samples for this class
            iou_scores.append(np.nan)
        else:
            iou_scores.append(np.sum(intersection) / np.sum(union))

    # Calculate mean IoU, ignoring NaN values (classes without samples)
    mean_iou = np.nanmean(iou_scores)
    return mean_iou


def RGB2YCbCr(img):
    img = img * 255.0
    r, g, b = torch.split(img, 1, dim=0)
    y = torch.zeros_like(r)
    cb = torch.zeros_like(r)
    cr = torch.zeros_like(r)

    y = 0.257 * r + 0.504 * g + 0.098 * b + 16
    y = y / 255.0

    cb = -0.148 * r - 0.291 * g + 0.439 * b + 128
    cb = cb / 255.0

    cr = 0.439 * r - 0.368 * g - 0.071 * b + 128
    cr = cr / 255.0

    img = torch.cat([y, y, y], dim=0)
    return img


def extract_peaks(prob_map, thr=0.5, pool=7):
    """
    prob_map: (H, W) after sigmoid
    return: tensor of peak coordinates [K, 2] (x, y)
    """
    # binary mask
    pos = prob_map > thr

    # non-maximum suppression
    nms = F.max_pool2d(
        prob_map.unsqueeze(0).unsqueeze(0),
        kernel_size=pool,
        stride=1,
        padding=pool // 2,
    )
    peaks = (prob_map == nms.squeeze()) & pos
    ys, xs = torch.nonzero(peaks, as_tuple=True)
    return torch.stack([xs, ys], dim=1)  # (K, 2)


def pick_radius(radius_map, centers, ksize=3):
    """
    radius_map: (H, W) ∈ [0, 1]
    centers: (K, 2) x,y
    return: (K,) radii in pixel
    """
    # H, W = radius_map.shape
    pad = ksize // 2
    padded = F.pad(
        radius_map.unsqueeze(0).unsqueeze(0), (pad, pad, pad, pad), mode="reflect"
    )

    radii = []
    for x, y in centers:
        patch = padded[..., y : y + ksize, x : x + ksize]
        radii.append(patch.mean())  # 3×3 mean
    return torch.stack(radii)


def draw_mask(centers, radii, H, W):
    """
    centers: (K, 2) (x, y)
    radii: (K,)
    return: (H, W) uint8 mask
    """
    radii *= 256
    mask = np.zeros((H, W), dtype=np.float32)
    for (x, y), r in zip(centers, radii):
        rr, cc = disk((y.item(), x.item()), r.item(), shape=mask.shape)
        mask[rr, cc] = 1
    return mask
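Note: the last three helpers are designed to be chained at inference time: `extract_peaks` finds local maxima in the predicted probability map, `pick_radius` reads the normalized radius around each peak, and `draw_mask` rasterizes the detected lights into a binary mask. A rough sketch (not from the commit) under the assumption of a 256x256 map, which matches the factor `draw_mask` uses to rescale the normalized radii:

import torch

H = W = 256
prob_map = torch.rand(H, W)      # stand-in for the sigmoid output of the light-source head
radius_map = torch.rand(H, W)    # stand-in for the normalized radius map in [0, 1]

centers = extract_peaks(prob_map, thr=0.9, pool=7)  # (K, 2) peak coordinates as (x, y)
if len(centers) > 0:
    radii = pick_radius(radius_map, centers)        # (K,) mean radius in a 3x3 window around each peak
    light_mask = draw_mask(centers, radii, H, W)    # (H, W) float mask with one filled disk per light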
weights/light_outpaint_lora/pytorch_lora_weights.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04aeb7148ae4d8c59f0d0260ee813c2fe41a8392d826c4941dfda9ed7cf7090d
size 3358448
weights/light_regress/model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c4e2ac2d23180814361ec04bcb22cc92adb761fb5ccc761b5c3874a297fed18
size 85314151
weights/net_g_last.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75f0fc77ab43703c7a9c7876621f8a651d6ce3a0cfb7c6e2377b3c8e2331b0e2
size 82605273