lunde commited on
Commit
0691d6d
1 Parent(s): e130db9

Adding first version of app

Browse files
__pycache__/model.cpython-39.pyc ADDED
Binary file (2.34 kB). View file
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim
3
+ import model
4
+ import numpy as np
5
+ from PIL import Image
6
+ import streamlit as st
7
+ from torchvision import transforms
8
+
9
+ scale_factor = 1
10
+
11
+ @st.cache
12
+ def load_model() -> torch.nn.Module:
13
+ DCE_net = model.enhance_net_nopool(scale_factor)
14
+ DCE_net.load_state_dict(torch.load("lowlight-dce-snapshot.pth", map_location=torch.device('cpu')))
15
+
16
+ return DCE_net
17
+
18
+ def fix_lowlight(image: Image.Image) -> Image.Image:
19
+ DCE_net = load_model()
20
+ data_lowlight = np.asarray(image) / 255.0
21
+
22
+ data_lowlight = torch.from_numpy(data_lowlight).float()
23
+
24
+ h = (data_lowlight.shape[0] // scale_factor) * scale_factor
25
+ w = (data_lowlight.shape[1] // scale_factor) * scale_factor
26
+ data_lowlight = data_lowlight[0:h, 0:w, :]
27
+ data_lowlight = data_lowlight.permute(2, 0, 1)
28
+ data_lowlight = data_lowlight.unsqueeze(0)
29
+
30
+ enhanced_image, _ = DCE_net(data_lowlight)
31
+ im = transforms.ToPILImage()(enhanced_image[0]).convert("RGB")
32
+
33
+ return im
34
+
35
+
36
+ def main():
37
+ st.title("Lowlight Enhancement")
38
+ st.write("This is a simple lowlight enhancement app with great performance and does not require paired images to train.")
39
+ uploaded_file = st.file_uploader("Lowlight Image")
40
+ if uploaded_file:
41
+ data_lowlight = Image.open(uploaded_file)
42
+
43
+ fixed_img = fix_lowlight(data_lowlight)
44
+
45
+ st.image(fixed_img, caption="Enhanced Image", use_column_width=True)
46
+
47
+ main()
loss.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision.models.vgg import vgg16
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+
8
+ class L_color(nn.Module):
9
+ def __init__(self):
10
+ super(L_color, self).__init__()
11
+
12
+ def forward(self, x ):
13
+ b,c,h,w = x.shape
14
+
15
+ mean_rgb = torch.mean(x, [2, 3], keepdim=True)
16
+ mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
17
+ Drg = torch.pow(mr-mg, 4)
18
+ Drb = torch.pow(mr-mb, 4)
19
+ Dgb = torch.pow(mb-mg, 4)
20
+
21
+ return torch.sqrt(Drg + Drb + Dgb)
22
+
23
+
24
+ class L_spa(nn.Module):
25
+ def __init__(self):
26
+ super(L_spa, self).__init__()
27
+ kernel_left = torch.tensor( [[0,0,0], [-1,1,0], [0,0,0]], dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
28
+ kernel_right = torch.tensor( [[0,0,0], [0,1,-1], [0,0,0]], dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
29
+ kernel_up = torch.tensor( [[0,-1,0], [0,1,0], [0,0,0]], dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
30
+ kernel_down = torch.tensor( [[0,0,0], [0,1,0], [0,-1,0]], dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
31
+
32
+ self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
33
+ self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
34
+ self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
35
+ self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
36
+ self.pool = nn.AvgPool2d(4)
37
+
38
+ def forward(self, org , enhance ):
39
+ b,c,h,w = org.shape
40
+
41
+ org_mean = torch.mean(org, 1, keepdim=True)
42
+ enhance_mean = torch.mean(enhance, 1, keepdim=True)
43
+
44
+ org_pool = self.pool(org_mean)
45
+ enhance_pool = self.pool(enhance_mean)
46
+
47
+ D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)
48
+ D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)
49
+ D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)
50
+ D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)
51
+
52
+ D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)
53
+ D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)
54
+ D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)
55
+ D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)
56
+
57
+ D_left = torch.pow(D_org_letf - D_enhance_letf,2)
58
+ D_right = torch.pow(D_org_right - D_enhance_right,2)
59
+ D_up = torch.pow(D_org_up - D_enhance_up,2)
60
+ D_down = torch.pow(D_org_down - D_enhance_down,2)
61
+ E = (D_left + D_right + D_up +D_down)
62
+ # E = 25*(D_left + D_right + D_up +D_down)
63
+
64
+ return E
65
+
66
+ class L_exp(nn.Module):
67
+ def __init__(self,patch_size):
68
+ super(L_exp, self).__init__()
69
+ self.pool = nn.AvgPool2d(patch_size)
70
+ # self.mean_val = mean_val
71
+ def forward(self, x, mean_val):
72
+ b,c,h,w = x.shape
73
+ x = torch.mean(x,1,keepdim=True)
74
+ mean = self.pool(x)
75
+
76
+ d = torch.mean(torch.pow(mean - torch.tensor([mean_val], dtype=torch.float, device=device).to(), 2))
77
+ return d
78
+
79
+ class L_TV(nn.Module):
80
+ def __init__(self,TVLoss_weight=1):
81
+ super(L_TV,self).__init__()
82
+ self.TVLoss_weight = TVLoss_weight
83
+
84
+ def forward(self,x):
85
+ batch_size = x.size()[0]
86
+ h_x = x.size()[2]
87
+ w_x = x.size()[3]
88
+ count_h = (x.size()[2]-1) * x.size()[3]
89
+ count_w = x.size()[2] * (x.size()[3] - 1)
90
+ h_tv = torch.pow((x[:,:,1:,:] - x[:,:,:h_x-1,:]),2).sum()
91
+ w_tv = torch.pow((x[:,:,:,1:] - x[:,:,:,:w_x-1]),2).sum()
92
+ return self.TVLoss_weight*2*(h_tv / count_h + w_tv / count_w) / batch_size
93
+
94
+ class Sa_Loss(nn.Module):
95
+ def __init__(self):
96
+ super(Sa_Loss, self).__init__()
97
+
98
+ def forward(self, x ):
99
+ b,c,h,w = x.shape
100
+ # x_de = x.cpu().detach().numpy()
101
+ r,g,b = torch.split(x , 1, dim=1)
102
+ mean_rgb = torch.mean(x,[2,3],keepdim=True)
103
+ mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
104
+ Dr = r-mr
105
+ Dg = g-mg
106
+ Db = b-mb
107
+ k =torch.pow( torch.pow(Dr,2) + torch.pow(Db,2) + torch.pow(Dg,2),0.5)
108
+
109
+ k = torch.mean(k)
110
+ return k
111
+
112
+ class perception_loss(nn.Module):
113
+ def __init__(self):
114
+ super(perception_loss, self).__init__()
115
+ features = vgg16(pretrained=True).features
116
+ self.to_relu_1_2 = nn.Sequential()
117
+ self.to_relu_2_2 = nn.Sequential()
118
+ self.to_relu_3_3 = nn.Sequential()
119
+ self.to_relu_4_3 = nn.Sequential()
120
+
121
+ for x in range(4):
122
+ self.to_relu_1_2.add_module(str(x), features[x])
123
+ for x in range(4, 9):
124
+ self.to_relu_2_2.add_module(str(x), features[x])
125
+ for x in range(9, 16):
126
+ self.to_relu_3_3.add_module(str(x), features[x])
127
+ for x in range(16, 23):
128
+ self.to_relu_4_3.add_module(str(x), features[x])
129
+
130
+ for param in self.parameters():
131
+ param.requires_grad = False
132
+
133
+ def forward(self, x):
134
+ h = self.to_relu_1_2(x)
135
+ h = self.to_relu_2_2(h)
136
+ h = self.to_relu_3_3(h)
137
+ h = self.to_relu_4_3(h)
138
+
139
+ return h
lowlight-dce-snapshot.pth ADDED
Binary file (51.2 kB). View file
model.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class CSDN_Tem(nn.Module):
7
+ def __init__(self, in_ch, out_ch):
8
+ super(CSDN_Tem, self).__init__()
9
+ self.depth_conv = nn.Conv2d(
10
+ in_channels=in_ch,
11
+ out_channels=in_ch,
12
+ kernel_size=3,
13
+ padding=1,
14
+ groups=in_ch
15
+ )
16
+ self.point_conv = nn.Conv2d(
17
+ in_channels=in_ch,
18
+ out_channels=out_ch,
19
+ kernel_size=1
20
+ )
21
+
22
+ def forward(self, input):
23
+ out = self.depth_conv(input)
24
+ out = self.point_conv(out)
25
+ return out
26
+
27
+ class enhance_net_nopool(nn.Module):
28
+ def __init__(self,scale_factor):
29
+ super(enhance_net_nopool, self).__init__()
30
+
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.scale_factor = scale_factor
33
+ self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)
34
+ number_f = 32
35
+
36
+ # zerodce DWC + p-shared
37
+ self.e_conv1 = CSDN_Tem(3, number_f)
38
+ self.e_conv2 = CSDN_Tem(number_f, number_f)
39
+ self.e_conv3 = CSDN_Tem(number_f, number_f)
40
+ self.e_conv4 = CSDN_Tem(number_f, number_f)
41
+ self.e_conv5 = CSDN_Tem(number_f * 2, number_f)
42
+ self.e_conv6 = CSDN_Tem(number_f * 2, number_f)
43
+ self.e_conv7 = CSDN_Tem(number_f * 2, 3)
44
+
45
+ def enhance(self, x, x_r):
46
+ for _ in range(8): x = x + x_r * (torch.pow(x, 2) - x)
47
+
48
+ return x
49
+
50
+ def forward(self, x):
51
+ x_down = x if self.scale_factor==1 else F.interpolate(x, scale_factor = 1 / self.scale_factor, mode='bilinear')
52
+
53
+ x1 = self.relu(self.e_conv1(x_down))
54
+ x2 = self.relu(self.e_conv2(x1))
55
+ x3 = self.relu(self.e_conv3(x2))
56
+ x4 = self.relu(self.e_conv4(x3))
57
+ x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
58
+ x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
59
+ x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
60
+
61
+ x_r = x_r if self.scale_factor==1 else self.upsample(x_r)
62
+ enhance_image = self.enhance(x, x_r)
63
+
64
+ return enhance_image, x_r
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
1
+ torch==1.9.0
2
+ torchvision==0.2.2