IanNathaniel commited on
Commit
a2de11c
1 Parent(s): 80e7256

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +59 -0
model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ #import pytorch_colors as colors
6
+ import numpy as np
7
+
8
+ class enhance_net_nopool(nn.Module):
9
+
10
+ def __init__(self):
11
+ super(enhance_net_nopool, self).__init__()
12
+
13
+ self.relu = nn.ReLU(inplace=True)
14
+
15
+ number_f = 32
16
+ self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True)
17
+ self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
18
+ self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
19
+ self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
20
+ self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
21
+ self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
22
+ self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True)
23
+
24
+ self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
25
+ self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
26
+
27
+
28
+
29
+ def forward(self, x):
30
+
31
+ x1 = self.relu(self.e_conv1(x))
32
+ # p1 = self.maxpool(x1)
33
+ x2 = self.relu(self.e_conv2(x1))
34
+ # p2 = self.maxpool(x2)
35
+ x3 = self.relu(self.e_conv3(x2))
36
+ # p3 = self.maxpool(x3)
37
+ x4 = self.relu(self.e_conv4(x3))
38
+
39
+ x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
40
+ # x5 = self.upsample(x5)
41
+ x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
42
+
43
+ x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
44
+ r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)
45
+
46
+
47
+ x = x + r1*(torch.pow(x,2)-x)
48
+ x = x + r2*(torch.pow(x,2)-x)
49
+ x = x + r3*(torch.pow(x,2)-x)
50
+ enhance_image_1 = x + r4*(torch.pow(x,2)-x)
51
+ x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
52
+ x = x + r6*(torch.pow(x,2)-x)
53
+ x = x + r7*(torch.pow(x,2)-x)
54
+ enhance_image = x + r8*(torch.pow(x,2)-x)
55
+ r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
56
+ return enhance_image_1,enhance_image,r
57
+
58
+
59
+