BraveLizzy commited on
Commit
7af68d5
1 Parent(s): cc34264

Upload Underwater.py

Browse files
Files changed (1) hide show
  1. model/Underwater.py +101 -0
model/Underwater.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ > Network architecture of FUnIE-GAN model
3
+ * Paper: arxiv.org/pdf/1903.09766.pdf
4
+ > Maintainer: https://github.com/xahidbuffon
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class UNetDown(nn.Module):
12
+ def __init__(self, in_size, out_size, bn=True):
13
+ super(UNetDown, self).__init__()
14
+ layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
15
+ if bn: layers.append(nn.BatchNorm2d(out_size, momentum=0.8))
16
+ layers.append(nn.LeakyReLU(0.2))
17
+ self.model = nn.Sequential(*layers)
18
+
19
+ def forward(self, x):
20
+ return self.model(x)
21
+
22
+
23
+ class UNetUp(nn.Module):
24
+ def __init__(self, in_size, out_size):
25
+ super(UNetUp, self).__init__()
26
+ layers = [
27
+ nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
28
+ nn.BatchNorm2d(out_size, momentum=0.8),
29
+ nn.ReLU(inplace=True),
30
+ ]
31
+ self.model = nn.Sequential(*layers)
32
+
33
+ def forward(self, x, skip_input):
34
+ x = self.model(x)
35
+ x = torch.cat((x, skip_input), 1)
36
+ return x
37
+
38
+
39
+ class GeneratorFunieGAN(nn.Module):
40
+ """ A 5-layer UNet-based generator as described in the paper
41
+ """
42
+ def __init__(self, in_channels=3, out_channels=3):
43
+ super(GeneratorFunieGAN, self).__init__()
44
+ # encoding layers
45
+ self.down1 = UNetDown(in_channels, 32, bn=False)
46
+ self.down2 = UNetDown(32, 128)
47
+ self.down3 = UNetDown(128, 256)
48
+ self.down4 = UNetDown(256, 256)
49
+ self.down5 = UNetDown(256, 256, bn=False)
50
+ # decoding layers
51
+ self.up1 = UNetUp(256, 256)
52
+ self.up2 = UNetUp(512, 256)
53
+ self.up3 = UNetUp(512, 128)
54
+ self.up4 = UNetUp(256, 32)
55
+ self.final = nn.Sequential(
56
+ nn.Upsample(scale_factor=2),
57
+ nn.ZeroPad2d((1, 0, 1, 0)),
58
+ nn.Conv2d(64, out_channels, 4, padding=1),
59
+ nn.Tanh(),
60
+ )
61
+
62
+ def forward(self, x):
63
+ d1 = self.down1(x)
64
+ d2 = self.down2(d1)
65
+ d3 = self.down3(d2)
66
+ d4 = self.down4(d3)
67
+ d5 = self.down5(d4)
68
+ u1 = self.up1(d5, d4)
69
+ u2 = self.up2(u1, d3)
70
+ u3 = self.up3(u2, d2)
71
+ u45 = self.up4(u3, d1)
72
+ return self.final(u45)
73
+
74
+
75
+ class DiscriminatorFunieGAN(nn.Module):
76
+ """ A 4-layer Markovian discriminator as described in the paper
77
+ """
78
+ def __init__(self, in_channels=3):
79
+ super(DiscriminatorFunieGAN, self).__init__()
80
+
81
+ def discriminator_block(in_filters, out_filters, bn=True):
82
+ #Returns downsampling layers of each discriminator block
83
+ layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
84
+ if bn: layers.append(nn.BatchNorm2d(out_filters, momentum=0.8))
85
+ layers.append(nn.LeakyReLU(0.2, inplace=True))
86
+ return layers
87
+
88
+ self.model = nn.Sequential(
89
+ *discriminator_block(in_channels*2, 32, bn=False),
90
+ *discriminator_block(32, 64),
91
+ *discriminator_block(64, 128),
92
+ *discriminator_block(128, 256),
93
+ nn.ZeroPad2d((1, 0, 1, 0)),
94
+ nn.Conv2d(256, 1, 4, padding=1, bias=False)
95
+ )
96
+
97
+ def forward(self, img_A, img_B):
98
+ # Concatenate image and condition image by channels to produce input
99
+ img_input = torch.cat((img_A, img_B), 1)
100
+ return self.model(img_input)
101
+