IanNathaniel committed on
Commit 8975307
1 Parent(s): a2de11c

Upload Myloss.py

Files changed (1)
  1. Myloss.py +157 -0
Myloss.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16


class L_color(nn.Module):
    """Color constancy loss: penalizes imbalance between the mean
    R, G and B intensities of an image."""

    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x):
        # Per-channel spatial means, shape (b, 3, 1, 1).
        mean_rgb = torch.mean(x, [2, 3], keepdim=True)
        mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr - mg, 2)
        Drb = torch.pow(mr - mb, 2)
        Dgb = torch.pow(mb - mg, 2)
        k = torch.pow(torch.pow(Drg, 2) + torch.pow(Drb, 2) + torch.pow(Dgb, 2), 0.5)
        return k


class L_spa(nn.Module):
    """Spatial consistency loss: compares local gradients of the input
    and the enhanced image over pooled 4x4 regions."""

    def __init__(self):
        super(L_spa, self).__init__()
        # Fixed 3x3 difference kernels, one per neighbour direction.
        kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).unsqueeze(0).unsqueeze(0)
        # Buffers (not parameters): fixed weights that follow the module's
        # device instead of being hard-wired to CUDA.
        self.register_buffer('weight_left', kernel_left)
        self.register_buffer('weight_right', kernel_right)
        self.register_buffer('weight_up', kernel_up)
        self.register_buffer('weight_down', kernel_down)
        self.pool = nn.AvgPool2d(4)

    def forward(self, org, enhance):
        # Channel-averaged luminance, pooled into 4x4 regions.
        org_mean = torch.mean(org, 1, keepdim=True)
        enhance_mean = torch.mean(enhance, 1, keepdim=True)
        org_pool = self.pool(org_mean)
        enhance_pool = self.pool(enhance_mean)

        # Region weighting terms (computed but unused in the returned loss).
        weight_diff = torch.max(
            1 + 10000 * torch.min(org_pool - 0.3, torch.zeros_like(org_pool)),
            torch.full_like(org_pool, 0.5))
        E_1 = torch.mul(torch.sign(enhance_pool - 0.5), enhance_pool - org_pool)

        # Directional gradients of both pooled maps.
        D_org_left = F.conv2d(org_pool, self.weight_left, padding=1)
        D_org_right = F.conv2d(org_pool, self.weight_right, padding=1)
        D_org_up = F.conv2d(org_pool, self.weight_up, padding=1)
        D_org_down = F.conv2d(org_pool, self.weight_down, padding=1)

        D_enhance_left = F.conv2d(enhance_pool, self.weight_left, padding=1)
        D_enhance_right = F.conv2d(enhance_pool, self.weight_right, padding=1)
        D_enhance_up = F.conv2d(enhance_pool, self.weight_up, padding=1)
        D_enhance_down = F.conv2d(enhance_pool, self.weight_down, padding=1)

        # Squared difference of corresponding gradients.
        D_left = torch.pow(D_org_left - D_enhance_left, 2)
        D_right = torch.pow(D_org_right - D_enhance_right, 2)
        D_up = torch.pow(D_org_up - D_enhance_up, 2)
        D_down = torch.pow(D_org_down - D_enhance_down, 2)
        E = D_left + D_right + D_up + D_down
        return E


class L_exp(nn.Module):
    """Exposure control loss: pulls the mean intensity of each
    patch_size x patch_size region towards the target mean_val."""

    def __init__(self, patch_size, mean_val):
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        # Channel-averaged intensity, pooled per patch.
        x = torch.mean(x, 1, keepdim=True)
        mean = self.pool(x)
        d = torch.mean(torch.pow(mean - self.mean_val, 2))
        return d


class L_TV(nn.Module):
    """Total variation loss: penalizes squared differences between
    neighbouring pixels to encourage spatial smoothness."""

    def __init__(self, TVLoss_weight=1):
        super(L_TV, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        # Number of vertical and horizontal neighbour pairs.
        count_h = (h_x - 1) * w_x
        count_w = h_x * (w_x - 1)
        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()
        return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size


class Sa_Loss(nn.Module):
    """Saturation loss: mean per-pixel distance of each colour channel
    from the image's mean channel values."""

    def __init__(self):
        super(Sa_Loss, self).__init__()

    def forward(self, x):
        r, g, b = torch.split(x, 1, dim=1)
        mean_rgb = torch.mean(x, [2, 3], keepdim=True)
        mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
        Dr = r - mr
        Dg = g - mg
        Db = b - mb
        k = torch.pow(torch.pow(Dr, 2) + torch.pow(Db, 2) + torch.pow(Dg, 2), 0.5)
        return torch.mean(k)


class perception_loss(nn.Module):
    """Perceptual feature extractor: runs a frozen, pretrained VGG-16
    and returns its relu4_3 activations."""

    def __init__(self):
        super(perception_loss, self).__init__()
        # Newer torchvision versions prefer vgg16(weights=...), but
        # pretrained=True is kept for compatibility with older releases.
        features = vgg16(pretrained=True).features
        self.to_relu_1_2 = nn.Sequential()
        self.to_relu_2_2 = nn.Sequential()
        self.to_relu_3_3 = nn.Sequential()
        self.to_relu_4_3 = nn.Sequential()

        # Slice the VGG-16 feature stack at the relu1_2, relu2_2,
        # relu3_3 and relu4_3 boundaries.
        for x in range(4):
            self.to_relu_1_2.add_module(str(x), features[x])
        for x in range(4, 9):
            self.to_relu_2_2.add_module(str(x), features[x])
        for x in range(9, 16):
            self.to_relu_3_3.add_module(str(x), features[x])
        for x in range(16, 23):
            self.to_relu_4_3.add_module(str(x), features[x])

        # We don't need gradients for VGG, just its features.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        h = self.to_relu_1_2(x)
        h = self.to_relu_2_2(h)
        h = self.to_relu_3_3(h)
        h = self.to_relu_4_3(h)
        # The intermediate activations could be returned as a tuple;
        # only relu4_3 is returned here.
        return h
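
The losses above are typically combined into a single training objective. The snippet below is a minimal usage sketch, not part of the commit: the patch size, exposure target and loss weights are illustrative assumptions, and the random tensors stand in for a network's input, output and curve map.

# Usage sketch (illustrative; the constructor values and weights are assumptions).
import torch
from Myloss import L_color, L_spa, L_exp, L_TV

color_loss = L_color()
spa_loss = L_spa()
exp_loss = L_exp(patch_size=16, mean_val=0.6)  # assumed patch size and exposure target
tv_loss = L_TV()

original = torch.rand(2, 3, 256, 256)                      # stand-in input batch
enhanced = torch.rand(2, 3, 256, 256, requires_grad=True)  # stand-in network output
A = torch.rand(2, 3, 256, 256, requires_grad=True)         # stand-in curve map

# Assumed weighting; the per-pixel losses are reduced with torch.mean.
loss = (200 * tv_loss(A)
        + torch.mean(spa_loss(original, enhanced))
        + 5 * torch.mean(color_loss(enhanced))
        + 10 * exp_loss(enhanced))
loss.backward()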