Spaces:
Runtime error
Runtime error
BraveLizzy
commited on
Commit
•
7af68d5
1
Parent(s):
cc34264
Upload Underwater.py
Browse files- model/Underwater.py +101 -0
model/Underwater.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
> Network architecture of FUnIE-GAN model
|
3 |
+
* Paper: arxiv.org/pdf/1903.09766.pdf
|
4 |
+
> Maintainer: https://github.com/xahidbuffon
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class UNetDown(nn.Module):
|
12 |
+
def __init__(self, in_size, out_size, bn=True):
|
13 |
+
super(UNetDown, self).__init__()
|
14 |
+
layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
|
15 |
+
if bn: layers.append(nn.BatchNorm2d(out_size, momentum=0.8))
|
16 |
+
layers.append(nn.LeakyReLU(0.2))
|
17 |
+
self.model = nn.Sequential(*layers)
|
18 |
+
|
19 |
+
def forward(self, x):
|
20 |
+
return self.model(x)
|
21 |
+
|
22 |
+
|
23 |
+
class UNetUp(nn.Module):
|
24 |
+
def __init__(self, in_size, out_size):
|
25 |
+
super(UNetUp, self).__init__()
|
26 |
+
layers = [
|
27 |
+
nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
|
28 |
+
nn.BatchNorm2d(out_size, momentum=0.8),
|
29 |
+
nn.ReLU(inplace=True),
|
30 |
+
]
|
31 |
+
self.model = nn.Sequential(*layers)
|
32 |
+
|
33 |
+
def forward(self, x, skip_input):
|
34 |
+
x = self.model(x)
|
35 |
+
x = torch.cat((x, skip_input), 1)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
class GeneratorFunieGAN(nn.Module):
|
40 |
+
""" A 5-layer UNet-based generator as described in the paper
|
41 |
+
"""
|
42 |
+
def __init__(self, in_channels=3, out_channels=3):
|
43 |
+
super(GeneratorFunieGAN, self).__init__()
|
44 |
+
# encoding layers
|
45 |
+
self.down1 = UNetDown(in_channels, 32, bn=False)
|
46 |
+
self.down2 = UNetDown(32, 128)
|
47 |
+
self.down3 = UNetDown(128, 256)
|
48 |
+
self.down4 = UNetDown(256, 256)
|
49 |
+
self.down5 = UNetDown(256, 256, bn=False)
|
50 |
+
# decoding layers
|
51 |
+
self.up1 = UNetUp(256, 256)
|
52 |
+
self.up2 = UNetUp(512, 256)
|
53 |
+
self.up3 = UNetUp(512, 128)
|
54 |
+
self.up4 = UNetUp(256, 32)
|
55 |
+
self.final = nn.Sequential(
|
56 |
+
nn.Upsample(scale_factor=2),
|
57 |
+
nn.ZeroPad2d((1, 0, 1, 0)),
|
58 |
+
nn.Conv2d(64, out_channels, 4, padding=1),
|
59 |
+
nn.Tanh(),
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
d1 = self.down1(x)
|
64 |
+
d2 = self.down2(d1)
|
65 |
+
d3 = self.down3(d2)
|
66 |
+
d4 = self.down4(d3)
|
67 |
+
d5 = self.down5(d4)
|
68 |
+
u1 = self.up1(d5, d4)
|
69 |
+
u2 = self.up2(u1, d3)
|
70 |
+
u3 = self.up3(u2, d2)
|
71 |
+
u45 = self.up4(u3, d1)
|
72 |
+
return self.final(u45)
|
73 |
+
|
74 |
+
|
75 |
+
class DiscriminatorFunieGAN(nn.Module):
|
76 |
+
""" A 4-layer Markovian discriminator as described in the paper
|
77 |
+
"""
|
78 |
+
def __init__(self, in_channels=3):
|
79 |
+
super(DiscriminatorFunieGAN, self).__init__()
|
80 |
+
|
81 |
+
def discriminator_block(in_filters, out_filters, bn=True):
|
82 |
+
#Returns downsampling layers of each discriminator block
|
83 |
+
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
|
84 |
+
if bn: layers.append(nn.BatchNorm2d(out_filters, momentum=0.8))
|
85 |
+
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
86 |
+
return layers
|
87 |
+
|
88 |
+
self.model = nn.Sequential(
|
89 |
+
*discriminator_block(in_channels*2, 32, bn=False),
|
90 |
+
*discriminator_block(32, 64),
|
91 |
+
*discriminator_block(64, 128),
|
92 |
+
*discriminator_block(128, 256),
|
93 |
+
nn.ZeroPad2d((1, 0, 1, 0)),
|
94 |
+
nn.Conv2d(256, 1, 4, padding=1, bias=False)
|
95 |
+
)
|
96 |
+
|
97 |
+
def forward(self, img_A, img_B):
|
98 |
+
# Concatenate image and condition image by channels to produce input
|
99 |
+
img_input = torch.cat((img_A, img_B), 1)
|
100 |
+
return self.model(img_input)
|
101 |
+
|