chenzhicun committed on
Commit
ec08fea
1 Parent(s): b0cf94d

Initialize web demo.

IdentityLUT33.txt ADDED
The diff for this file is too large to render. See raw diff
IdentityLUT64.txt ADDED
The diff for this file is too large to render. See raw diff
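The two identity LUT files are too large to render here, but their layout can be read off Generator3DLUT_identity in models/models_x.py below: dim³ lines of three space-separated floats, where line n = i*dim² + j*dim + k holds the output R, G, B values of lattice cell (i, j, k). A minimal sketch of how such a file could be regenerated — the exact channel ordering of the identity mapping is an assumption, inferred from the grid_sample convention used by Tritri:

import numpy as np

def write_identity_lut(path="IdentityLUT33.txt", dim=33):
    # One line per lattice cell, indexed n = i*dim*dim + j*dim + k,
    # each line holding three floats in [0, 1] (assumed identity mapping).
    with open(path, "w") as f:
        for i in range(dim):
            for j in range(dim):
                for k in range(dim):
                    r = k / (dim - 1.0)  # assumed: k varies fastest and maps to R
                    g = j / (dim - 1.0)
                    b = i / (dim - 1.0)
                    f.write("%f %f %f\n" % (r, g, b))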
app.py ADDED
@@ -0,0 +1,119 @@
+ import gradio as gr
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+ from models.models_x import *
+ import torchvision_x_functional as TF_x
+ import torchvision.transforms.functional as TF
+ from torchvision import transforms
+ import cv2
+ from timm.models.hub import download_cached_file
+
+
+ cuda = True if torch.cuda.is_available() else False
+ Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
+ trans = transforms.ToTensor()
+
+
+ LUT0 = Generator3DLUT_identity()
+ LUT1 = Generator3DLUT_zero()
+ LUT2 = Generator3DLUT_zero()
+ classifier = Classifier()
+ trilinear_ = Tritri()
+ if cuda:
+     LUT0 = LUT0.cuda()
+     LUT1 = LUT1.cuda()
+     LUT2 = LUT2.cuda()
+     classifier = classifier.cuda()
+
+ # Load pretrained models
+ cache = download_cached_file('https://drive.google.com/uc?export=download&id=1tzeECo1m4MBqvfLv4H4SQ7by4YMEP17H',
+                              check_hash=False, progress=True)
+ LUTs = torch.load(cache, map_location=torch.device('cpu'))
+ LUT0.load_state_dict(LUTs["0"])
+ LUT1.load_state_dict(LUTs["1"])
+ LUT2.load_state_dict(LUTs["2"])
+ LUT0.eval()
+ LUT1.eval()
+ LUT2.eval()
+
+ cache = download_cached_file('https://drive.google.com/uc?export=download&id=1rQ_p3NMRFxZ52MOYj0jPewYtD3JQTJGi',
+                              check_hash=False, progress=True)
+ classifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
+ classifier.eval()
+
+
+ XLUT0 = Generator3DLUT_identity()
+ XLUT1 = Generator3DLUT_zero()
+ XLUT2 = Generator3DLUT_zero()
+ Xclassifier = Classifier()
+ Xtrilinear_ = Tritri()
+ if cuda:
+     XLUT0 = XLUT0.cuda()
+     XLUT1 = XLUT1.cuda()
+     XLUT2 = XLUT2.cuda()
+     Xclassifier = Xclassifier.cuda()
+
+ # Load pretrained models
+ cache = download_cached_file('https://drive.google.com/uc?export=download&id=1ossTzgbgpZL4Jy5uhiRJDGfCWw9vOv0c',
+                              check_hash=False, progress=True)
+ XLUTs = torch.load(cache, map_location=torch.device('cpu'))
+ XLUT0.load_state_dict(XLUTs["0"])
+ XLUT1.load_state_dict(XLUTs["1"])
+ XLUT2.load_state_dict(XLUTs["2"])
+ XLUT0.eval()
+ XLUT1.eval()
+ XLUT2.eval()
+
+ cache = download_cached_file('https://drive.google.com/uc?export=download&id=1279CoaqQZK-eK83283MERoRxtRbIgRew',
+                              check_hash=False, progress=True)
+ Xclassifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
+ Xclassifier.eval()
+
+
+ def generate_LUT(img):
+     pred = classifier(img).squeeze()
+
+     LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT  # + pred[3] * LUT3.LUT + pred[4] * LUT4.LUT
+
+     return LUT
+
+ def generate_XLUT(img):
+     pred = Xclassifier(img).squeeze()
+
+     XLUT = pred[0] * XLUT0.LUT + pred[1] * XLUT1.LUT + pred[2] * XLUT2.LUT  # + pred[3] * LUT3.LUT + pred[4] * LUT4.LUT
+
+     return XLUT
+
+
+ def inference(ori_image, models_n):
+     with torch.no_grad():
+         if models_n == 'sRGB':
+             # img = Image.open(ori_image)
+             # img = TF.to_tensor(img).type(Tensor)
+             img = trans(ori_image)
+             img = img.unsqueeze(0)
+             LUT = generate_LUT(img)
+             result = trilinear_(LUT, img)
+             result = result.permute(0, 3, 1, 2)
+             ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
+             im = Image.fromarray(ndarr)
+         elif models_n == 'XYZ':
+             img = trans(ori_image)
+             img = img.unsqueeze(0)
+             XLUT = generate_XLUT(img)
+             result = Xtrilinear_(XLUT, img)
+             result = result.permute(0, 3, 1, 2)
+             ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
+             im = Image.fromarray(ndarr)
+     return im
+
+
+ inputs = [gr.inputs.Image(type='pil', label='待增强图片'),
+           gr.inputs.Radio(choices=['sRGB', 'XYZ'], type="value", default="sRGB", label="图片色彩空间")]
+ outputs = [gr.outputs.Image(type='pil', label='增强后图片')]
+
+ title = '基于LUT的图像增强演示'
+
+ gr.Interface(inference, inputs, outputs, title=title, allow_flagging='never',
+              examples=[['./examples/example.jpg', 'sRGB']]).launch(enable_queue=True)
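Outside the Gradio UI, the enhancement path above amounts to: the classifier predicts three weights, the weighted sum of the three learned LUTs is applied to the image through Tritri, and the NHWC result is converted back to a uint8 PIL image. A minimal sketch of that path, reusing the objects defined above (note that app.py launches the interface at import time, so this would be appended to the same script rather than run via importing app; it also assumes the checkpoints download successfully):

from PIL import Image
import torch

img = trans(Image.open('./examples/example.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    lut = generate_LUT(img)                          # weighted sum of LUT0/LUT1/LUT2
    out = trilinear_(lut, img).permute(0, 3, 1, 2)   # Tritri returns NHWC; back to NCHW
ndarr = out.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
Image.fromarray(ndarr).save('enhanced.jpg')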
examples/example.jpg ADDED
models/models_x.py ADDED
@@ -0,0 +1,329 @@
+ from doctest import OutputChecker
+ from turtle import forward
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ from torch.autograd import Variable
+ import torch
+ import numpy as np
+ import math
+
+ from models.trilinear_test import bing_lut_trilinearInterplt,Tritri
+
+ from re import I
+ import time
+ from PIL import Image
+ ###########################################
+ # use this module for pytorch 1.x, together with trilinear_cpp
+ ###########################################
+
+
+ def weights_init_normal_classifier(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         torch.nn.init.xavier_normal_(m.weight.data)
+
+     elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1:
+         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+         torch.nn.init.constant_(m.bias.data, 0.0)
+
+ class resnet18_224(nn.Module):
+
+     def __init__(self, out_dim=5, aug_test=False):
+         super(resnet18_224, self).__init__()
+
+         self.aug_test = aug_test
+         net = models.resnet18(pretrained=True)
+         # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
+         # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
+
+         self.upsample = nn.Upsample(size=(224,224),mode='bilinear')
+         net.fc = nn.Linear(512, out_dim)
+         self.model = net
+
+
+     def forward(self, x):
+
+         x = self.upsample(x)
+         if self.aug_test:
+             # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0)
+             x = torch.cat((x, torch.flip(x, [3])), 0)
+         f = self.model(x)
+
+         return f
+
+ ##############################
+ #        Discriminator
+ ##############################
+
+
+ def discriminator_block(in_filters, out_filters, normalization=False):
+     """Returns downsampling layers of each discriminator block"""
+     layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
+     layers.append(nn.LeakyReLU(0.2))
+     if normalization:
+         layers.append(nn.InstanceNorm2d(out_filters, affine=True))
+         #layers.append(nn.BatchNorm2d(out_filters))
+
+     return layers
+
+ class Discriminator(nn.Module):
+     def __init__(self, in_channels=3):
+         super(Discriminator, self).__init__()
+
+         self.model = nn.Sequential(
+             nn.Upsample(size=(256,256),mode='bilinear'),
+             nn.Conv2d(3, 16, 3, stride=2, padding=1),
+             nn.LeakyReLU(0.2),
+             nn.InstanceNorm2d(16, affine=True),
+             *discriminator_block(16, 32),
+             *discriminator_block(32, 64),
+             *discriminator_block(64, 128),
+             *discriminator_block(128, 128),
+             #*discriminator_block(128, 128),
+             nn.Conv2d(128, 1, 8, padding=0)
+         )
+
+     def forward(self, img_input):
+         return self.model(img_input)
+
+ class Classifier(nn.Module):
+     def __init__(self, in_channels=3):
+         super(Classifier, self).__init__()
+
+         self.model = nn.Sequential(
+             # nn.Downsample(size=(256,256),mode='bilinear'),
+             nn.Upsample(size=(256,256),mode='bilinear'),  # original
+
+             nn.Conv2d(3, 16, 3, stride=2, padding=1),
+             nn.LeakyReLU(0.2),
+             nn.InstanceNorm2d(16, affine=True),
+             *discriminator_block(16, 32, normalization=True),
+             *discriminator_block(32, 64, normalization=True),
+             *discriminator_block(64, 128, normalization=True),
+             *discriminator_block(128, 128),
+             #*discriminator_block(128, 128, normalization=True),
+             nn.Dropout(p=0.5),
+             nn.Conv2d(128, 3, 8, padding=0),
+         )
+
+
+     def forward(self, img_input):
+         return self.model(img_input)
+
+
+ class Classifier_unpaired(nn.Module):
+     def __init__(self, in_channels=3):
+         super(Classifier_unpaired, self).__init__()
+
+         self.model = nn.Sequential(
+             nn.Upsample(size=(256,256),mode='bilinear'),
+             nn.Conv2d(3, 16, 3, stride=2, padding=1),
+             nn.LeakyReLU(0.2),
+             nn.InstanceNorm2d(16, affine=True),
+             *discriminator_block(16, 32),
+             *discriminator_block(32, 64),
+             *discriminator_block(64, 128),
+             *discriminator_block(128, 128),
+             #*discriminator_block(128, 128),
+             nn.Conv2d(128, 3, 8, padding=0),
+         )
+
+     def forward(self, img_input):
+         return self.model(img_input)
+
+
+ class Generator3DLUT_identity(nn.Module):
+     def __init__(self, dim=33):
+         super(Generator3DLUT_identity, self).__init__()
+         if dim == 33:
+             file = open("IdentityLUT33.txt", 'r')
+         elif dim == 64:
+             file = open("IdentityLUT64.txt", 'r')
+         lines = file.readlines()
+         buffer = np.zeros((3,dim,dim,dim), dtype=np.float32)
+
+         for i in range(0,dim):
+             for j in range(0,dim):
+                 for k in range(0,dim):
+                     n = i * dim*dim + j * dim + k
+                     x = lines[n].split()
+                     buffer[0,i,j,k] = float(x[0])
+                     buffer[1,i,j,k] = float(x[1])
+                     buffer[2,i,j,k] = float(x[2])
+         self.LUT = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
+         self.TrilinearInterpolation = Tritri()
+         # self.trilinearItp = bing_lut_trilinearInterplt()
+
+
+     def forward(self, x):
+         _, output = self.TrilinearInterpolation(self.LUT, x)
+         # output = self.trilinearItp(self.LUT,x)
+
+         #self.LUT, output = self.TrilinearInterpolation(self.LUT, x)
+         return output
+
+ class Generator3DLUT_zero(nn.Module):
+     def __init__(self, dim=33):
+         super(Generator3DLUT_zero, self).__init__()
+
+         self.LUT = torch.zeros(3,dim,dim,dim, dtype=torch.float)
+         self.LUT = nn.Parameter(torch.tensor(self.LUT))
+         self.TrilinearInterpolation = Tritri()
+         # self.trilinearItp = bing_lut_trilinearInterplt()
+
+     def forward(self, x):
+         _, output = self.TrilinearInterpolation(self.LUT, x)
+         # output = self.trilinearItp(self.LUT,x)
+
+         return output
+
+ class LUT_all(nn.Module):
+     def __init__(self,
+                  path_LUT="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/LUTs_399.pth",
+                  path_classifier="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/classifier_399.pth") -> None:
+         super(LUT_all,self).__init__()
+         self.classifier=Classifier()
+         self.classifier.load_state_dict(torch.load(path_classifier))
+
+         self.LUT0 = Generator3DLUT_identity()
+         self.LUT1 = Generator3DLUT_zero()
+         self.LUT2 = Generator3DLUT_zero()
+         LUTs = torch.load(path_LUT)
+         self.LUT0.load_state_dict(LUTs["0"])
+         self.LUT1.load_state_dict(LUTs["1"])
+         self.LUT2.load_state_dict(LUTs["2"])
+         # self.trilinear_ = TrilinearInterpolation()
+         # self.trilinear_ = bing_lut_trilinearInterplt()
+         self.trilinear_=Tritri()
+
+     def forward(self,img):
+         pred = self.classifier(img).squeeze()
+
+         # numpy's squeeze removes axes of size 1 and returns an np.ndarray
+         # LUT = pred[0] * self.LUT0.LUT
+         LUT = pred[0] * self.LUT0.LUT + pred[1] * self.LUT1.LUT + pred[2] * self.LUT2.LUT
+         output = self.trilinear_(LUT, img)
+         # _,output = self.trilinear_(LUT, img)
+         return output
+         # return LUT
+
+
+
+ # class TrilinearInterpolationFunction(torch.autograd.Function):
+ #     @staticmethod
+ #     def forward(ctx, lut, x):
+
+ #         x = x.contiguous()
+
+ #         output = x.new(x.size())
+ #         dim = lut.size()[-1]
+ #         shift = dim ** 3
+ #         binsize = 1.000001 / (dim-1)
+ #         W = x.size(2)
+ #         H = x.size(3)
+ #         batch = x.size(0)
+ #         # the 'trilinear' package is the original authors' own extension
+ #         assert 1 == trilinear.forward(lut,
+ #                                       x,
+ #                                       output,
+ #                                       dim,
+ #                                       shift,
+ #                                       binsize,
+ #                                       W,
+ #                                       H,
+ #                                       batch)
+
+ #         int_package = torch.IntTensor([dim, shift, W, H, batch])
+ #         float_package = torch.FloatTensor([binsize])
+ #         variables = [lut, x, int_package, float_package]
+
+ #         ctx.save_for_backward(*variables)
+
+ #         return lut, output
+
+ #     @staticmethod
+ #     def backward(ctx, lut_grad, x_grad):
+
+ #         lut, x, int_package, float_package = ctx.saved_variables
+ #         dim, shift, W, H, batch = int_package
+ #         dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
+ #         binsize = float(float_package[0])
+
+ #         assert 1 == trilinear.backward(x,
+ #                                        x_grad,
+ #                                        lut_grad,
+ #                                        dim,
+ #                                        shift,
+ #                                        binsize,
+ #                                        W,
+ #                                        H,
+ #                                        batch)
+ #         return lut_grad, x_grad
+
+
+ # class TrilinearInterpolation(torch.nn.Module):
+ #     def __init__(self):
+ #         super(TrilinearInterpolation, self).__init__()
+
+ #     def forward(self, lut, x):
+ #         return TrilinearInterpolationFunction.apply(lut, x)
+
+
+ class TV_3D(nn.Module):
+     def __init__(self, dim=33):
+         super(TV_3D,self).__init__()
+
+         self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float)
+         self.weight_r[:,:,:,(0,dim-2)] *= 2.0
+         self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float)
+         self.weight_g[:,:,(0,dim-2),:] *= 2.0
+         self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float)
+         self.weight_b[:,(0,dim-2),:,:] *= 2.0
+         self.relu = torch.nn.ReLU()
+
+     def forward(self, LUT):
+
+         dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:]
+         dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:]
+         dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:]
+         tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))
+
+         mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))
+
+         return tv, mn
+
+
+ ##new by bing##
+ if __name__=='__main__':
+     def img_process_256(img):
+         # Convert a PIL image (mode=RGB, size=3840x2160, 3 channels) to a tensor of shape [N,C,H,W] (i.e. [1,3,256,256])
+         img=img.resize((256,256))
+         trans=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
+         img = trans(img)
+         img = torch.unsqueeze(img,0)  # add a batch dimension
+         print("img",img.size())
+         # # convert from HWC to NCHW, N=1
+         # img=np.array(img)
+         return img
+
+     def img_process_4k(img):
+         # Convert a PIL image (mode=RGB, size=3840x2160, 3 channels) to a tensor of shape [N,C,H,W]
+         trans=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
+         img = trans(img)
+         img = torch.unsqueeze(img,0)  # add a batch dimension
+         print("img",img.size())
+         # # convert from HWC to NCHW, N=1
+         # img=np.array(img)
+         return img
+
+
+     img_ori=Image.open("/home/elle/bing/proj/code/download-4k-img/picture/%s" % ("X4_Animal2_BIC_g_03.png"))
+     img=img_process_256(img_ori)
+     img_4k=img_process_4k(img_ori)
+     model=LUT_all()
+
+     out=model(img_4k)
+     print(out)
+
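TV_3D above is the smoothness/monotonicity regulariser used when the LUTs are trained: tv penalises squared differences between neighbouring lattice cells, mn penalises non-monotonic differences. A minimal sketch of evaluating it on one of the LUT generators (illustrative only; the training loop is not part of this commit, and the loss weights are an assumption read off the checkpoint folder name "sm_1e-4_mn_10"):

tv_reg = TV_3D(dim=33)
lut = Generator3DLUT_identity()      # reads IdentityLUT33.txt
tv, mn = tv_reg(lut)                 # both are scalar tensors
loss_reg = 1e-4 * tv + 10.0 * mn     # assumed weights, per the folder naming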
models/trilinear_test.py ADDED
@@ -0,0 +1,608 @@
+ from re import A
+ import time
+ from turtle import width
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ ##new####
+ # https://github.com/tedyhabtegebrial/PyTorch-Trilinear-Interpolation
+ class TrilinearIntepolation(nn.Module):
+     """TrilinearIntepolation in PyTorch."""
+
+     def __init__(self):
+         super(TrilinearIntepolation, self).__init__()
+
+     def sample_at_integer_locs(self, input_feats, index_tensor):
+         assert input_feats.ndimension()==5, 'input_feats should be of shape [Batch,F,D,Height,Width]'
+         assert index_tensor.ndimension()==4, 'index_tensor should be of shape [Batch,Height,Width,3]'
+         # first sample pixel locations using nearest neighbour interpolation
+         batch_size, num_chans, num_d, height, width = input_feats.shape
+         grid_height, grid_width = index_tensor.shape[1],index_tensor.shape[2]
+
+         xy_grid = index_tensor[..., 0:2]
+         # 0:2 includes index 0 but not 2, so this takes dims 0 and 1 of the last axis
+         xy_grid[..., 0] = xy_grid[..., 0] - ((width-1.0)/2.0)
+         xy_grid[..., 0] = xy_grid[..., 0] / ((width-1.0)/2.0)
+         xy_grid[..., 1] = xy_grid[..., 1] - ((height-1.0)/2.0)
+         xy_grid[..., 1] = xy_grid[..., 1] / ((height-1.0)/2.0)
+         xy_grid = torch.clamp(xy_grid, min=-1.0, max=1.0)
+         # clamp limits the minimum and maximum value of each element
+         sampled_in_2d = F.grid_sample(input=input_feats.view(batch_size, num_chans*num_d, height, width),
+                                       grid=xy_grid, mode='nearest').view(batch_size, num_chans, num_d, grid_height, grid_width)
+         # grid_sample (bilinear sampling): https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html?highlight=grid_sample#torch.nn.functional.grid_sample
+         # view: https://blog.csdn.net/york1996/article/details/81949843
+         z_grid = index_tensor[..., 2].view(batch_size, 1, 1, grid_height, grid_width)
+         z_grid = z_grid.long().clamp(min=0, max=num_d-1)
+         # .long() casts the tensor to int64
+         z_grid = z_grid.expand(batch_size,num_chans, 1, grid_height, grid_width)
+         # expand broadcasts size-1 dimensions of the original tensor: https://blog.csdn.net/weixin_42782150/article/details/108615706
+         # here expand is used to broadcast dim=1 up to num_chans
+         sampled_in_3d = sampled_in_2d.gather(2, z_grid).squeeze(2)
+         return sampled_in_3d
+
+
+     def forward(self, input_feats, sampling_grid):
+         assert input_feats.ndimension()==5, 'input_feats should be of shape [B,F,D,H,W]'
+         assert sampling_grid.ndimension()==4, 'sampling_grid should be of shape [B,H,W,3]'
+         batch_size, num_chans, num_d, height, width = input_feats.shape
+         grid_height, grid_width = sampling_grid.shape[1],sampling_grid.shape[2]
+         # make sure sampling grid lies between -1, 1
+         sampling_grid = torch.clamp(sampling_grid, min=-1.0, max=1.0)
+         # map to 0,1
+         sampling_grid = (sampling_grid+1)/2.0
+         # Scale grid to floating point pixel locations
+         scaling_factor = torch.FloatTensor([width-1.0, height-1.0, num_d-1.0]).to(input_feats.device).view(1, 1, 1, 3)
+         sampling_grid = scaling_factor*sampling_grid
+         # Now sampling grid is between [0, w-1; 0,h-1; 0,d-1]
+         x, y, z = torch.split(sampling_grid, split_size_or_sections=1, dim=3)
+         # (x,y,z) are the input floating-point coordinates (here, each pixel's RGB value)
+         # (x0,y0,z0) are those coordinates floored
+         # split the size-3 dimension of sampling_grid into chunks of size 1
+         x_0, y_0, z_0 = torch.split(sampling_grid.floor(), split_size_or_sections=1, dim=3)
+         x_1, y_1, z_1 = x_0+1.0, y_0+1.0, z_0+1.0
+         u, v, w = x-x_0, y-y_0, z-z_0
+         print("v:",x_0,y_0,z_0)
+         print("s:",x_0.size(),y_0.size(),z_0.size())
+         print("size,cat",torch.cat([x_0, y_0, z_0],dim=3).size())
+         u, v, w = map(lambda x:x.view(batch_size, 1, grid_height, grid_width).expand(
+             batch_size, num_chans, grid_height, grid_width), [u, v, w])
+         c_000 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_0, z_0], dim=3))
+         # torch.cat concatenates a sequence of tensors along the given dimension
+         c_001 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_0, z_1], dim=3))
+         c_010 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_1, z_0], dim=3))
+         c_011 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_1, z_1], dim=3))
+         c_100 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_0, z_0], dim=3))
+         c_101 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_0, z_1], dim=3))
+         c_110 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_1, z_0], dim=3))
+         c_111 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_1, z_1], dim=3))
+         c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
+                 (1.0-u)*(1.0-v)*(w)*c_001 + \
+                 (1.0-u)*(v)*(1.0-w)*c_010 + \
+                 (1.0-u)*(v)*(w)*c_011 + \
+                 (u)*(1.0-v)*(1.0-w)*c_100 + \
+                 (u)*(1.0-v)*(w)*c_101 + \
+                 (u)*(v)*(1.0-w)*c_110 + \
+                 (u)*(v)*(w)*c_111
+         return c_xyz
+ # class bing_lut_trilinearInterplt(nn.Module):
+
+ #     def __init__(self):
+ #         super(bing_lut_trilinearInterplt, self).__init__()
+
+ #     def test(self,LUT,img_input):
+ #         # batch_size, num_chans, height, width = img_input.shape
+ #         # grid_height, grid_width = LUT.shape[1],LUT.shape[2]
+ #         grid_in=img_input.transpose(1,2).transpose(2,3)
+ #         # original img_input is NCHW; change to NHWC
+ #         xy_grid=grid_in[...,0:2]
+ #         yz_grid=grid_in[...,1:3]
+ #         # take only channels 0 and 1 of the 3 channels (0:2 excludes 2)
+ #         input_LUT=LUT[:,:,0,:]
+ #         input_LUT_ori=input_LUT.squeeze(2)
+ #         # LUT [33,33,33,3] -> [33,33,3]; the dim=2 data is discarded
+ #         input_LUT=input_LUT_ori[...,0:2]
+ #         input_LUT2=input_LUT_ori[...,1:]
+ #         print("input_LUT2.size()",input_LUT2.size())
+ #         # LUT[33,33,2]
+ #         input_LUT=input_LUT.transpose(1,2).transpose(0,1)
+ #         input_LUT2=input_LUT2.transpose(1,2).transpose(0,1)
+ #         # LUT[2,33,33]
+ #         input_LUT=input_LUT.unsqueeze(0)
+ #         input_LUT2=input_LUT2.unsqueeze(0)
+ #         print(input_LUT.size())
+ #         print(input_LUT2.size())
+ #         print(grid_in.size())
+ #         sampled_in_2d = F.grid_sample(input=input_LUT,grid=xy_grid, mode='nearest')
+ #         # .view(batch_size, num_chans, num_d, grid_height, grid_width)
+ #         sampled_in_2d_2 = F.grid_sample(input=input_LUT2,grid=yz_grid, mode='nearest')
+ #         # .view(batch_size, num_chans, num_d, grid_height, grid_width)
+
+ #         # print("sampled_in_2d.size()",sampled_in_2d.size())
+ #         # print("sampled_in_2d.size()",sampled_in_2d_2.size())
+ #         # # [1,2,2160,3840]
+ #         # print("ss")
+ #         # print(sampled_in_2d.size())
+ #         # print(sampled_in_2d_2.size())
+ #         res=torch.cat([sampled_in_2d,sampled_in_2d_2[:,1:,:,:]],dim=1)
+ #         print(res.size())
+ #         return res
+ #         # z_grid = grid_in[..., 2]
+ #         # print(z_grid.size())
+ #         # # [1,2160,3840]
+ #         # print("sss")
+
+
+
+ #     def gen_Cout_ijk(self,LUT,x_i,y_i,z_i):
+ #         # def gen_Cout_ijk(LUT,x_i,y_i,z_i,channel=3):
+ #         # LUT size [3,33,33,33]
+ #         # x_i,y_i,z_i size [1,1,2160,3840]
+ #         # N=batch_size
+ #         # img_input.size()=[1,3,2160,3840]\
+ #         # LUT.size()=[3,33,33,33]
+ #         # assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
+ #         channel=3
+ #         batch_size,_,height,width=x_i.size()
+ #         print(batch_size,height,width)
+ #         output=torch.zeros([batch_size,channel,height,width])
+ #         # set the output size to [1,3,2160,3840]
+ #         if batch_size==1:
+ #             # x_i=x_i.view(height*width)
+ #             # y_i=y_i.view(height*width)
+ #             # z_i=z_i.view(height*width)
+ #             x_i=x_i.view(height*width).long()
+ #             y_i=y_i.view(height*width).long()
+ #             z_i=z_i.view(height*width).long()
+ #             # x_i=x_i.view(1, height*width)
+ #             # y_i=y_i.view(1, height*width)
+ #             # z_i=z_i.view(1, height*width)
+ #             # 2-D tensor, [1, 2160*3840]
+ #             # xyz_i=torch.cat([x_i,y_i,z_i],dim=0)
+ #             # # xyz_i 2-D tensor, [3, 2160*3840]
+
+ #             # print("xyz_i.size()",xyz_i.size())
+ #         else:
+ #             print("error:batch size must be 1")
+ #         for i in range(height*width):
+ #             h_index=int(i/width)
+ #             w_index=int(i%width)
+ #             # print(h_index)
+ #             # print(w_index)
+ #             # print(x_i.size())
+ #             # print(batch_size)
+ #             # print(output.size())
+ #             # print(output[0,0,h_index,w_index])
+ #             if(i%10000==0):
+ #                 print(i)
+ #             output[batch_size-1,0,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],0]
+ #             output[batch_size-1,1,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],1]
+ #             output[batch_size-1,2,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],2]
+
+ #         # x_i=x_i.view(batch_size,height*width)
+ #         # y_i=y_i.view(batch_size,height*width)
+ #         # z_i=z_i.view(batch_size,height*width)
+ #         # 1,2160*3840
+
+
+ #         return output
+
+
+ #     def forward(self, LUT, img_input):
+ #         assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]'
+ #         # N=batch_size
+ #         # img_input.size()=[1,3,2160,3840]\
+ #         # LUT.size()=[3,33,33,33]
+ #         assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
+ #         batch_size, num_chans, height, width = img_input.shape
+ #         dim = LUT.shape[1]  # M
+ #         img_size=img_input.size()
+ #         Cmax=255.0
+ #         s=Cmax/dim
+ #         r,g,b=torch.split(img_input,split_size_or_sections=1,dim=1)
+ #         # split [1,3,2160,3840] along dim=1 into three [1,1,2160,3840] parts
+ #         # r,g,b.size()=[1,1,2160,3840]
+ #         # r=img_input[:,0,:,:]
+ #         # g=img_input[:,1,:,:]
+ #         # b=img_input[:,2,:,:]
+ #         x=r/s
+ #         y=g/s
+ #         z=b/s
+ #         # tmptmp=self.test(LUT,img_input)
+ #         # x,y,z.size=[1,1,,2160,3840]
+ #         # x_0,y_0,z_0.size=[1,1,,2160,3840]
+ #         # x_1, y_1, z_1.size=[1,1,,2160,3840]
+ #         x_0,y_0,z_0=x.floor(),y.floor(),z.floor()
+ #         x_1, y_1, z_1 = x_0+1.0, y_0+1.0, z_0+1.0
+ #         u, v, w = x-x_0, y-y_0, z-z_0
+ #         # u,v,w.size=[1,1,2160,3840]
+ #         # print("x_0.size",x_0.size())
+ #         c_000 = self.test(LUT,torch.cat([x_0,y_0,z_0],dim=1))
+ #         print(c_000.size())
+ #         # the x_i are lattice vertices, size [1,1,2160,3840]
+ #         # the outputs c_xxx are the LUT values at those vertices, size [1,3,2160,3840]
+ #         c_100 = self.test(LUT,torch.cat([x_1,y_0,z_0],dim=1))
+ #         c_010 = self.test(LUT,torch.cat([x_0,y_1,z_0],dim=1))
+ #         c_110 = self.test(LUT,torch.cat([x_1,y_1,z_0],dim=1))
+ #         c_001 = self.test(LUT,torch.cat([x_0,y_0,z_1],dim=1))
+ #         c_101 = self.test(LUT,torch.cat([x_1,y_0,z_1],dim=1))
+ #         c_011 = self.test(LUT,torch.cat([x_0,y_1,z_1],dim=1))
+ #         c_111 = self.test(LUT,torch.cat([x_1,y_1,z_1],dim=1))
+
+ #         # c_000 = self.gen_Cout_ijk(LUT,x_0,y_0,z_0)
+ #         # # the x_i are lattice vertices, size [1,1,2160,3840]
+ #         # # the outputs c_xxx are the LUT values at those vertices, size [1,3,2160,3840]
+ #         # c_100 = self.gen_Cout_ijk(LUT,x_1,y_0,z_0)
+ #         # c_010 = self.gen_Cout_ijk(LUT,x_0,y_1,z_0)
+ #         # c_110 = self.gen_Cout_ijk(LUT,x_1,y_1,z_0)
+ #         # c_001 = self.gen_Cout_ijk(LUT,x_0,y_0,z_1)
+ #         # c_101 = self.gen_Cout_ijk(LUT,x_1,y_0,z_1)
+ #         # c_011 = self.gen_Cout_ijk(LUT,x_0,y_1,z_1)
+ #         # c_111 = self.gen_Cout_ijk(LUT,x_1,y_1,z_1)
+ #         c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
+ #                 (1.0-u)*(1.0-v)*(w)*c_001 + \
+ #                 (1.0-u)*(v)*(1.0-w)*c_010 + \
+ #                 (1.0-u)*(v)*(w)*c_011 + \
+ #                 (u)*(1.0-v)*(1.0-w)*c_100 + \
+ #                 (u)*(1.0-v)*(w)*c_101 + \
+ #                 (u)*(v)*(1.0-w)*c_110 + \
+ #                 (u)*(v)*(w)*c_111
+ #         # broadcasting; output is [1,3,2160,3840]
+ #         print("c_xyz",c_xyz.size())
+ #         return c_xyz
+
+ #         # id100 = x_0 + 1.0 + y_0 * dim + z_0 * dim * dim
+ #         # id010 = x_0 + (y_0 + 1.0) * dim + z_0 * dim * dim
+ #         # id110 = x_0 + 1.0 + (y_0 + 1.0) * dim + z_0 * dim * dim
+ #         # id001 = x_0 + y_0 * dim + (z_0 + 1.0) * dim * dim
+ #         # id101 = x_0 + 1.0 + y_0 * dim + (z_0 + 1.0) * dim * dim
+ #         # id011 = x_0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim
+ #         # id111 = x_0 + 1.0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim
+
+ #         # w000 = (1.0-u)*(1-v)*(1-w)
+ #         # # probably needs to be an element-wise product
+ #         # w100 = u*(1-v)*(1-w)
+ #         # w010 = (1-u)*v*(1-w)
+ #         # w110 = u*v*(1-w)
+ #         # w001 = (1-u)*(1-v)*w
+ #         # w101 = u*(1-v)*w
+ #         # w011 = (1-u)*v*w
+ #         # w111 = u*v*w
+ #         # output=
+
+ #         # print("v:",x_0,y_0,z_0)
+ #         # print("s:",x_0.size(),y_0.size(),z_0.size())
+ #         # u,v,w=u/s,v/s,w/s
+ #         # c_000 = self.gen_Cout_ijk(x_0,y_0,z_0)
+ #         # c_100 = self.gen_Cout_ijk(x_1,y_0,z_0)
+ #         # c_010 = self.gen_Cout_ijk(x_0,y_1,z_0)
+ #         # c_110 = self.gen_Cout_ijk(x_1,y_1,z_0)
+ #         # c_001 = self.gen_Cout_ijk(x_0,y_0,z_1)
+ #         # c_101 = self.gen_Cout_ijk(x_1,y_0,z_1)
+ #         # c_011 = self.gen_Cout_ijk(x_0,y_1,z_1)
+ #         # c_111 = self.gen_Cout_ijk(x_1,y_1,z_1)
+
+
+ #         # c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
+ #         #         (1.0-u)*(1.0-v)*(w)*c_001 + \
+ #         #         (1.0-u)*(v)*(1.0-w)*c_010 + \
+ #         #         (1.0-u)*(v)*(w)*c_011 + \
+ #         #         (u)*(1.0-v)*(1.0-w)*c_100 + \
+ #         #         (u)*(1.0-v)*(w)*c_101 + \
+ #         #         (u)*(v)*(1.0-w)*c_110 + \
+ #         #         (u)*(v)*(w)*c_111
+ #         # return c_xyz
+
+ class Tritri(nn.Module):
+
+     def __init__(self):
+         super(Tritri, self).__init__()
+
+     def forward(self,LUT,img):
+         img = (img - .5) * 2.
+         # grid_sample expects NxDxHxWx3 (1x1xHxWx3)
+         img = img.permute(0, 2, 3, 1)[:, None]
+         # add batch dim to LUT
+         LUT = LUT[None]
+         # grid sample
+         result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)
+         # drop added dimensions and permute back
+         result = result[:, :, 0].permute(0, 2, 3, 1)
+         return result
+
+
+
+ class bing_lut_trilinearInterplt(nn.Module):
+
+     def __init__(self):
+         super(bing_lut_trilinearInterplt, self).__init__()
+
+     def test(self,LUT,img_input):
+         # batch_size, num_chans, height, width = img_input.shape
+         # grid_height, grid_width = LUT.shape[1],LUT.shape[2]
+         grid_in=img_input.transpose(1,2).transpose(2,3)
+         # 1
+         # original img_input is NCHW; change to NHWC
+         xy_grid=grid_in[...,0:2]
+         yz_grid=grid_in[...,1:3]
+         # 23
+         # take only channels 0 and 1 of the 3 channels (0:2 excludes 2)
+
+         # the correct LUT layout should be [3,33,33,33]
+         # here it was mistakenly treated as [33,33,33,3]
+         input_LUT=LUT[:,:,:,0:1]
+         input_LUT_ori=input_LUT.squeeze(3)
+         # 45
+
+         # [3,33,33,33]->[3,33,33]; the dim=3 data is discarded
+
+         # input_LUT=LUT[:,:,0,:]
+         # input_LUT_ori=input_LUT.squeeze(2)
+         # # LUT [33,33,33,3] -> [33,33,3]; the dim=2 data is discarded
+
+         input_LUT=input_LUT_ori[0:2,...]
+         input_LUT2=input_LUT_ori[1:,...]
+         input_LUT=input_LUT.unsqueeze(0)
+         input_LUT2=input_LUT2.unsqueeze(0)
+         # 6-9
+
+         # both are [1,2,33,33]
+         # print(input_LUT.size())
+         # print("dtype:")
+         # print(input_LUT.dtype)
+         # print(input_LUT2.dtype)
+         # print(xy_grid.dtype)
+         # print(yz_grid.dtype)
+         # input_LUT.int()
+         # input_LUT2.int()
+         # xy_grid.int()
+         # yz_grid.int()
+
+         # # print(grid_in.size())
+         sampled_in_2d = F.grid_sample(input=input_LUT,grid=xy_grid, mode='nearest',align_corners=False)
+         # .view(batch_size, num_chans, num_d, grid_height, grid_width)
+         sampled_in_2d_2 = F.grid_sample(input=input_LUT2,grid=yz_grid, mode='nearest',align_corners=False)
+         # .view(batch_size, num_chans, num_d, grid_height, grid_width)
+         # 10
+         res=torch.cat([sampled_in_2d,sampled_in_2d_2[:,1:,:,:]],dim=1)
+         # print(res.size())
+         return res
+
+     def forward(self, LUT, img_input):
+         assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]'
+         # N=batch_size
+         # img_input.size()=[1,3,2160,3840]\
+         # LUT.size()=[3,33,33,33]
+         assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
+         # batch_size, num_chans, height, width = img_input.shape
+         dim = LUT.shape[1]  # M
+         # img_size=img_input.size()
+         # Cmax=1.00001
+         Cmax=10
+         s=Cmax/(dim-1.0)
+         s=torch.Tensor([s])
+         # thanks, rubber duck!!  # data types int64 and int32 do not match in BroadcastRel
+
+         r,g,b=torch.split(img_input,split_size_or_sections=1,dim=1)
+         # split [1,3,2160,3840] along dim=1 into three [1,1,2160,3840] parts
+         # r,g,b.size()=[1,1,2160,3840]
+         # r=img_input[:,0,:,:]
+         # g=img_input[:,1,:,:]
+         # b=img_input[:,2,:,:]
+         s=s.to(r.device)
+         x=r/s
+         y=g/s
+         z=b/s
+         # tmptmp=self.test(LUT,img_input)
+         # x,y,z.size=[1,1,,2160,3840]
+         # x_0,y_0,z_0.size=[1,1,,2160,3840]
+         # x_1, y_1, z_1.size=[1,1,,2160,3840]
+         x_0,y_0,z_0=x.floor(),y.floor(),z.floor()
+         x_1, y_1, z_1 = x_0+1.0, y_0+1.0, z_0+1.0
+         u, v, w = x-x_0, y-y_0, z-z_0
+         # u,v,w.size=[1,1,2160,3840]
+         # print("x_0.size",x_0.size())
+
+         c_000 = self.test(LUT,torch.cat([x_0,y_0,z_0],dim=1))
+         # print(c_000.size())
+         # the x_i are lattice vertices, size [1,1,2160,3840]
+         # the outputs c_xxx are the LUT values at those vertices, size [1,3,2160,3840]
+         c_100 = self.test(LUT,torch.cat([x_1,y_0,z_0],dim=1))
+         c_010 = self.test(LUT,torch.cat([x_0,y_1,z_0],dim=1))
+         c_110 = self.test(LUT,torch.cat([x_1,y_1,z_0],dim=1))
+         c_001 = self.test(LUT,torch.cat([x_0,y_0,z_1],dim=1))
+         c_101 = self.test(LUT,torch.cat([x_1,y_0,z_1],dim=1))
+         c_011 = self.test(LUT,torch.cat([x_0,y_1,z_1],dim=1))
+         c_111 = self.test(LUT,torch.cat([x_1,y_1,z_1],dim=1))
+
+         c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
+                 (1.0-u)*(1.0-v)*(w)*c_001 + \
+                 (1.0-u)*(v)*(1.0-w)*c_010 + \
+                 (1.0-u)*(v)*(w)*c_011 + \
+                 (u)*(1.0-v)*(1.0-w)*c_100 + \
+                 (u)*(1.0-v)*(w)*c_101 + \
+                 (u)*(v)*(1.0-w)*c_110 + \
+                 (u)*(v)*(w)*c_111
+         # broadcasting; output is [1,3,2160,3840]
+         print("c_xyz",c_xyz.size())
+         return c_xyz
+
+ class bing_lut_trilinearInterplt_backup(nn.Module):
+
+     def __init__(self):
+         super(bing_lut_trilinearInterplt, self).__init__()
+
+     def test(self,LUT,img_input):
+         # batch_size, num_chans, height, width = img_input.shape
+         # grid_height, grid_width = LUT.shape[1],LUT.shape[2]
+         grid_in=img_input.transpose(1,2).transpose(2,3)
+         # 1
+         # original img_input is NCHW; change to NHWC
+         xy_grid=grid_in[...,0:2]
+         yz_grid=grid_in[...,1:3]
+         # 23
+         # take only channels 0 and 1 of the 3 channels (0:2 excludes 2)
+
+         # the correct LUT layout should be [3,33,33,33]
+         # here it was mistakenly treated as [33,33,33,3]
+         input_LUT=LUT[:,:,:,0:1]
+         input_LUT_ori=input_LUT.squeeze(3)
+         # 45
+
+         # [3,33,33,33]->[3,33,33]; the dim=3 data is discarded
+
+         # input_LUT=LUT[:,:,0,:]
+         # input_LUT_ori=input_LUT.squeeze(2)
+         # # LUT [33,33,33,3] -> [33,33,3]; the dim=2 data is discarded
+
+         input_LUT=input_LUT_ori[0:2,...]
+         input_LUT2=input_LUT_ori[1:,...]
+         input_LUT=input_LUT.unsqueeze(0)
+         input_LUT2=input_LUT2.unsqueeze(0)
+         # 6-9
+
+         # both are [1,2,33,33]
+         # print(input_LUT.size())
+         # print("dtype:")
+         # print(input_LUT.dtype)
+         # print(input_LUT2.dtype)
+         # print(xy_grid.dtype)
+         # print(yz_grid.dtype)
+         # input_LUT.int()
+         # input_LUT2.int()
+         # xy_grid.int()
+         # yz_grid.int()
+
+         # # print(grid_in.size())
+         sampled_in_2d = F.grid_sample(input=input_LUT,grid=xy_grid, mode='nearest')
+         # .view(batch_size, num_chans, num_d, grid_height, grid_width)
+         sampled_in_2d_2 = F.grid_sample(input=input_LUT2,grid=yz_grid, mode='nearest')
+         # .view(batch_size, num_chans, num_d, grid_height, grid_width)
+         # 10
+         res=torch.cat([sampled_in_2d,sampled_in_2d_2[:,1:,:,:]],dim=1)
+         # print(res.size())
+         return res
+
+     def forward(self, LUT, img_input):
+         assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]'
+         # N=batch_size
+         # img_input.size()=[1,3,2160,3840]\
+         # LUT.size()=[3,33,33,33]
+         assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
+         # batch_size, num_chans, height, width = img_input.shape
+         dim = LUT.shape[1]  # M
+         # img_size=img_input.size()
+         Cmax=255.0
+         s=Cmax/dim
+         s=torch.Tensor([s])
+         # thanks, rubber duck!!  # data types int64 and int32 do not match in BroadcastRel
+
+         r,g,b=torch.split(img_input,split_size_or_sections=1,dim=1)
+         # split [1,3,2160,3840] along dim=1 into three [1,1,2160,3840] parts
+         # r,g,b.size()=[1,1,2160,3840]
+         # r=img_input[:,0,:,:]
+         # g=img_input[:,1,:,:]
+         # b=img_input[:,2,:,:]
+         x=r/s
+         y=g/s
+         z=b/s
+         # tmptmp=self.test(LUT,img_input)
+         # x,y,z.size=[1,1,,2160,3840]
+         # x_0,y_0,z_0.size=[1,1,,2160,3840]
+         # x_1, y_1, z_1.size=[1,1,,2160,3840]
+         x_0,y_0,z_0=x.floor(),y.floor(),z.floor()
+         x_1, y_1, z_1 = x_0+1.0, y_0+1.0, z_0+1.0
+         u, v, w = x-x_0, y-y_0, z-z_0
+         # u,v,w.size=[1,1,2160,3840]
+         # print("x_0.size",x_0.size())
+
+         c_000 = self.test(LUT,torch.cat([x_0,y_0,z_0],dim=1))
+         # print(c_000.size())
+         # the x_i are lattice vertices, size [1,1,2160,3840]
+         # the outputs c_xxx are the LUT values at those vertices, size [1,3,2160,3840]
+         c_100 = self.test(LUT,torch.cat([x_1,y_0,z_0],dim=1))
+         c_010 = self.test(LUT,torch.cat([x_0,y_1,z_0],dim=1))
+         c_110 = self.test(LUT,torch.cat([x_1,y_1,z_0],dim=1))
+         c_001 = self.test(LUT,torch.cat([x_0,y_0,z_1],dim=1))
+         c_101 = self.test(LUT,torch.cat([x_1,y_0,z_1],dim=1))
+         c_011 = self.test(LUT,torch.cat([x_0,y_1,z_1],dim=1))
+         c_111 = self.test(LUT,torch.cat([x_1,y_1,z_1],dim=1))
+
+         # c_000 = self.gen_Cout_ijk(LUT,x_0,y_0,z_0)
+         # # the x_i are lattice vertices, size [1,1,2160,3840]
+         # # the outputs c_xxx are the LUT values at those vertices, size [1,3,2160,3840]
+         # c_100 = self.gen_Cout_ijk(LUT,x_1,y_0,z_0)
+         # c_010 = self.gen_Cout_ijk(LUT,x_0,y_1,z_0)
+         # c_110 = self.gen_Cout_ijk(LUT,x_1,y_1,z_0)
+         # c_001 = self.gen_Cout_ijk(LUT,x_0,y_0,z_1)
+         # c_101 = self.gen_Cout_ijk(LUT,x_1,y_0,z_1)
+         # c_011 = self.gen_Cout_ijk(LUT,x_0,y_1,z_1)
+         # c_111 = self.gen_Cout_ijk(LUT,x_1,y_1,z_1)
+         c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
+                 (1.0-u)*(1.0-v)*(w)*c_001 + \
+                 (1.0-u)*(v)*(1.0-w)*c_010 + \
+                 (1.0-u)*(v)*(w)*c_011 + \
+                 (u)*(1.0-v)*(1.0-w)*c_100 + \
+                 (u)*(1.0-v)*(w)*c_101 + \
+                 (u)*(v)*(1.0-w)*c_110 + \
+                 (u)*(v)*(w)*c_111
+         # broadcasting; output is [1,3,2160,3840]
+         print("c_xyz",c_xyz.size())
+         return c_xyz
+
+
+
+     # @staticmethod
+     # def backward(ctx, lut_grad, x_grad):
+
+     #     lut, x, int_package, float_package = ctx.saved_variables
+     #     dim, shift, W, H, batch = int_package
+     #     dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
+     #     binsize = float(float_package[0])
+
+     #     assert 1 == trilinear.backward(x,
+     #                                    x_grad,
+     #                                    lut_grad,
+     #                                    dim,
+     #                                    shift,
+     #                                    binsize,
+     #                                    W,
+     #                                    H,
+     #                                    batch)
+     #     return lut_grad, x_grad
+
+ class Tri(nn.Module):
+     def __init__(self):
+         super(Tri,self).__init__()
+
+ if __name__=='__main__':
+     # input_features: shape [B, num_channels, depth, height, width]
+     # sampling_grid: shape [B,depth, height, 3]
+     data = torch.rand(1, 32, 16, 128, 128)
+     # data = torch.rand(1, 3, 16, 128, 128)
+     sampling_grid = (torch.rand(1, 256, 256, 3) - 0.5)*2.0
+     data = data.float().cuda(0)
+     sampling_grid = sampling_grid.float().cuda(0)
+     trilinear_interpolation = TrilinearIntepolation().cuda(0)
+     # LUT.type() torch.cuda.FloatTensor
+     # LUT.size() torch.Size([3, 33, 33, 33])
+     # img: torch.Size([1, 3, 2160, 3840])
+     data2 = torch.rand(1, 3,2160,3840)
+     # LUT2 = torch.rand(33,33,33,3)
+     LUT2 = torch.rand(3,33,33,33)
+
+     trilinear_interpolation2 = bing_lut_trilinearInterplt()
+     t_start = time.time()
+     interp_data2=trilinear_interpolation2(LUT2,data2)
+
+     # interpolated_data = trilinear_interpolation(data, sampling_grid)
+     # print(interpolated_data.shape)
+     torch.cuda.synchronize()
+     print('time per iteration ', time.time()-t_start)
+     # for i in range(100):
+     #     t_start = time.time()
+     #     interpolated_data = trilinear_interpolation(data, sampling_grid)
+     #     print(interpolated_data.shape)
+     #     torch.cuda.synchronize()
+     #     print('time per iteration ', time.time()-t_start)
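Tritri is the interpolation the demo actually uses: it rescales the RGB values of each pixel to [-1, 1] and treats them as 3-D sampling coordinates into the [1, 3, dim, dim, dim] LUT, letting F.grid_sample do the trilinear interpolation. A minimal shape-check sketch with random data (CPU only, not part of the committed code):

import torch
from models.trilinear_test import Tritri

lut = torch.rand(3, 33, 33, 33)     # [C, D, H, W] lookup table
img = torch.rand(1, 3, 256, 256)    # [N, C, H, W] image in [0, 1]
out = Tritri()(lut, img)            # returns NHWC: [1, 256, 256, 3]
out = out.permute(0, 3, 1, 2)       # back to NCHW, as done in app.py
print(out.shape)                    # torch.Size([1, 3, 256, 256])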
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch~=1.11.0
+ torchvision~=0.12.0
+ opencv-python~=4.5.5.64
+ pillow~=9.1.1
+ numpy~=1.22.3
+ scipy~=1.8.1
torchvision_x_functional.py ADDED
@@ -0,0 +1,554 @@
1
+ import collections
2
+ import numbers
3
+ from functools import wraps
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+ from scipy.ndimage.filters import gaussian_filter
10
+
11
+ __numpy_type_map = {
12
+ 'float64': torch.DoubleTensor,
13
+ 'float32': torch.FloatTensor,
14
+ 'float16': torch.HalfTensor,
15
+ 'int64': torch.LongTensor,
16
+ 'int32': torch.IntTensor,
17
+ 'int16': torch.ShortTensor,
18
+ 'uint16': torch.ShortTensor,
19
+ 'int8': torch.CharTensor,
20
+ 'uint8': torch.ByteTensor,
21
+ }
22
+
23
+ '''image functional utils
24
+
25
+ '''
26
+
27
+ # NOTE: all the function should recive the ndarray like image, should be W x H x C or W x H
28
+
29
+ # 如果将所有输出的维度够搞成height,width,channel 那么可以不用to_tensor??, 不行
30
+ def preserve_channel_dim(func):
31
+ """Preserve dummy channel dim."""
32
+ @wraps(func)
33
+ def wrapped_function(img, *args, **kwargs):
34
+ shape = img.shape
35
+ result = func(img, *args, **kwargs)
36
+ if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2:
37
+ result = np.expand_dims(result, axis=-1)
38
+ return result
39
+
40
+ return wrapped_function
41
+
42
+
43
+ def _is_tensor_image(img):
44
+ return torch.is_tensor(img) and img.ndimension() == 3
45
+
46
+
47
+ def _is_numpy_image(img):
48
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
49
+
50
+
51
+ def to_tensor(img):
52
+ '''convert numpy.ndarray to torch tensor. \n
53
+ if the image is uint8 , it will be divided by 255;\n
54
+ if the image is uint16 , it will be divided by 65535;\n
55
+ if the image is float , it will not be divided, we suppose your image range should between [0~1] ;\n
56
+
57
+ Arguments:
58
+ img {numpy.ndarray} -- image to be converted to tensor.
59
+ '''
60
+ if not _is_numpy_image(img):
61
+ raise TypeError('data should be numpy ndarray. but got {}'.format(type(img)))
62
+
63
+ if img.ndim == 2:
64
+ img = img[:, :, None]
65
+
66
+ if img.dtype == np.uint8:
67
+ img = img.astype(np.float32)/255
68
+ elif img.dtype == np.uint16:
69
+ img = img.astype(np.float32)/65535
70
+ elif img.dtype in [np.float32, np.float64]:
71
+ img = img.astype(np.float32)/1
72
+ else:
73
+ raise TypeError('{} is not support'.format(img.dtype))
74
+
75
+ img = torch.from_numpy(img.transpose((2, 0, 1)))
76
+
77
+ return img
78
+
79
+
80
+ def to_pil_image(tensor):
81
+ # TODO
82
+ pass
83
+
84
+
85
+ def to_tiff_image(tensor):
86
+ # TODO
87
+ pass
88
+
89
+
90
+ def normalize(tensor, mean, std, inplace=False):
91
+ """Normalize a tensor image with mean and standard deviation.
92
+
93
+ .. note::
94
+ This transform acts out of place by default, i.e., it does not mutates the input tensor.
95
+
96
+ See :class:`~torchsat.transforms.Normalize` for more details.
97
+
98
+ Args:
99
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
100
+ mean (sequence): Sequence of means for each channel.
101
+ std (sequence): Sequence of standard deviations for each channel.
102
+
103
+ Returns:
104
+ Tensor: Normalized Tensor image.
105
+ """
106
+ if not _is_tensor_image(tensor):
107
+ raise TypeError('tensor is not a torch image.')
108
+
109
+ if not inplace:
110
+ tensor = tensor.clone()
111
+
112
+ mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
113
+ std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
114
+ tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
115
+ return tensor
116
+
117
+ def noise(img, mode='gaussain', percent=0.02):
118
+ """
119
+ TODO: Not good for uint16 data
120
+ """
121
+ original_dtype = img.dtype
122
+ if mode == 'gaussian':
123
+ mean = 0
124
+ var = 0.1
125
+ sigma = var*0.5
126
+
127
+ if img.ndim == 2:
128
+ h, w = img.shape
129
+ gauss = np.random.normal(mean, sigma, (h, w))
130
+ else:
131
+ h, w, c = img.shape
132
+ gauss = np.random.normal(mean, sigma, (h, w, c))
133
+
134
+ if img.dtype not in [np.float32, np.float64]:
135
+ gauss = gauss * np.iinfo(img.dtype).max
136
+ img = np.clip(img.astype(np.float) + gauss, 0, np.iinfo(img.dtype).max)
137
+ else:
138
+ img = np.clip(img.astype(np.float) + gauss, 0, 1)
139
+
140
+ elif mode == 'salt':
141
+ print(img.dtype)
142
+ s_vs_p = 1
143
+ num_salt = np.ceil(percent * img.size * s_vs_p)
144
+ coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape])
145
+
146
+ if img.dtype in [np.float32, np.float64]:
147
+ img[coords] = 1
148
+ else:
149
+ img[coords] = np.iinfo(img.dtype).max
150
+ print(img.dtype)
151
+ elif mode == 'pepper':
152
+ s_vs_p = 0
153
+ num_pepper = np.ceil(percent * img.size * (1. - s_vs_p))
154
+ coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape])
155
+ img[coords] = 0
156
+
157
+ elif mode == 's&p':
158
+ s_vs_p = 0.5
159
+
160
+ # Salt mode
161
+ num_salt = np.ceil(percent * img.size * s_vs_p)
162
+ coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape])
163
+ if img.dtype in [np.float32, np.float64]:
164
+ img[coords] = 1
165
+ else:
166
+ img[coords] = np.iinfo(img.dtype).max
167
+
168
+ # Pepper mode
169
+ num_pepper = np.ceil(percent* img.size * (1. - s_vs_p))
170
+ coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape])
171
+ img[coords] = 0
172
+ else:
173
+ raise ValueError('not support mode for {}'.format(mode))
174
+
175
+ noisy = img.astype(original_dtype)
176
+
177
+ return noisy
178
+
179
+
180
+ def gaussian_blur(img, kernel_size):
181
+ # When sigma=0, it is computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`
182
+ return cv2.GaussianBlur(img, (kernel_size, kernel_size), sigmaX=0)
183
+
184
+
185
+ def adjust_brightness(img, value=0):
186
+ if img.dtype in [np.float, np.float32, np.float64, np.float128]:
187
+ dtype_min, dtype_max = 0, 1
188
+ dtype = np.float32
189
+ else:
190
+ dtype_min = np.iinfo(img.dtype).min
191
+ dtype_max = np.iinfo(img.dtype).max
192
+ dtype = np.iinfo(img.dtype)
193
+
194
+ result = np.clip(img.astype(np.float)+value, dtype_min, dtype_max).astype(dtype)
195
+
196
+ return result
197
+
198
+
199
+ def adjust_contrast(img, factor):
200
+ if img.dtype in [np.float, np.float32, np.float64, np.float128]:
201
+ dtype_min, dtype_max = 0, 1
202
+ dtype = np.float32
203
+ else:
204
+ dtype_min = np.iinfo(img.dtype).min
205
+ dtype_max = np.iinfo(img.dtype).max
206
+ dtype = np.iinfo(img.dtype)
207
+
208
+ result = np.clip(img.astype(np.float)*factor, dtype_min, dtype_max).astype(dtype)
209
+
210
+ return result
211
+
212
+ def adjust_saturation():
213
+ # TODO
214
+ pass
215
+
216
+ def adjust_hue():
217
+ # TODO
218
+ pass
219
+
220
+
221
+
222
+ def to_grayscale(img, output_channels=1):
223
+ """convert input ndarray image to gray sacle image.
224
+
225
+ Arguments:
226
+ img {ndarray} -- the input ndarray image
227
+
228
+ Keyword Arguments:
229
+ output_channels {int} -- output gray image channel (default: {1})
230
+
231
+ Returns:
232
+ ndarray -- gray scale ndarray image
233
+ """
234
+ if img.ndim == 2:
235
+ gray_img = img
236
+ elif img.shape[2] == 3:
237
+ gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
238
+ else:
239
+ gray_img = np.mean(img, axis=2)
240
+ gray_img = gray_img.astype(img.dtype)
241
+
242
+ if output_channels != 1:
243
+ gray_img = np.tile(gray_img, (output_channels, 1, 1))
244
+ gray_img = np.transpose(gray_img, [1,2,0])
245
+
246
+ return gray_img
247
+
248
+
249
+ def shift(img, top, left):
250
+ (h, w) = img.shape[0:2]
251
+ matrix = np.float32([[1, 0, left], [0, 1, top]])
252
+ dst = cv2.warpAffine(img, matrix, (w, h))
253
+
254
+ return dst
255
+
256
+
257
+ def rotate(img, angle, center=None, scale=1.0):
258
+ (h, w) = img.shape[:2]
259
+
260
+ if center is None:
261
+ center = (w / 2, h / 2)
262
+
263
+ M = cv2.getRotationMatrix2D(center, angle, scale)
264
+ rotated = cv2.warpAffine(img, M, (w, h))
265
+
266
+ return rotated
267
+
268
+
269
+ def resize(img, size, interpolation=Image.BILINEAR):
270
+ '''resize the image
271
+ TODO: opencv resize 之后图像就成了0~1了
272
+ Arguments:
273
+ img {ndarray} -- the input ndarray image
274
+ size {int, iterable} -- the target size, if size is intger, width and height will be resized to same \
275
+ otherwise, the size should be tuple (height, width) or list [height, width]
276
+
277
+
278
+ Keyword Arguments:
279
+ interpolation {Image} -- the interpolation method (default: {Image.BILINEAR})
280
+
281
+ Raises:
282
+ TypeError -- img should be ndarray
283
+ ValueError -- size should be intger or iterable vaiable and length should be 2.
284
+
285
+ Returns:
286
+ img -- resize ndarray image
287
+ '''
288
+
289
+ if not _is_numpy_image(img):
290
+ raise TypeError('img shoud be ndarray image [w, h, c] or [w, h], but got {}'.format(type(img)))
291
+ if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size)==2)):
292
+ raise ValueError('size should be intger or iterable vaiable(length is 2), but got {}'.format(type(size)))
293
+
294
+ if isinstance(size, int):
295
+ height, width = (size, size)
296
+ else:
297
+ height, width = (size[0], size[1])
298
+
299
+ return cv2.resize(img, (width, height), interpolation=interpolation)
300
+
301
+
302
+ def pad(img, padding, fill=0, padding_mode='constant'):
303
+ if isinstance(padding, int):
304
+ pad_left = pad_right = pad_top = pad_bottom = padding
305
+ if isinstance(padding, collections.Iterable) and len(padding) == 2:
306
+ pad_left = pad_right = padding[0]
307
+ pad_bottom = pad_top = padding[1]
308
+ if isinstance(padding, collections.Iterable) and len(padding) == 4:
309
+ pad_left = padding[0]
310
+ pad_top = padding[1]
311
+ pad_right = padding[2]
312
+ pad_bottom = padding[3]
313
+
314
+ if img.ndim == 2:
315
+ if padding_mode == 'constant':
316
+ img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode, constant_values=fill)
317
+ else:
318
+ img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
319
+ if img.ndim == 3:
320
+ if padding_mode == 'constant':
321
+ img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode, constant_values=fill)
322
+ else:
323
+ img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode)
324
+ return img
325
+
326
+
327
+ def crop(img, top, left, height, width):
328
+ '''crop image
329
+
330
+ Arguments:
331
+ img {ndarray} -- image to be croped
332
+ top {int} -- top size
333
+ left {int} -- left size
334
+ height {int} -- croped height
335
+ width {int} -- croped width
336
+ '''
337
+ if not _is_numpy_image(img):
338
+ raise TypeError('the input image should be numpy ndarray with dimension 2 or 3.'
339
+ 'but got {}'.format(type(img))
340
+ )
341
+
342
+ if width<0 or height<0 or left <0 or height<0:
343
+ raise ValueError('the input left, top, width, height should be greater than 0'
344
+ 'but got left={}, top={} width={} height={}'.format(left, top, width, height)
345
+ )
346
+ if img.ndim == 2:
347
+ img_height, img_width = img.shape
348
+ else:
349
+ img_height, img_width, _ = img.shape
350
+ if (left+width) > img_width or (top+height) > img_height:
351
+ raise ValueError('the input crop width and height should be small or \
352
+ equal to image width and height. ')
353
+
354
+ if img.ndim == 2:
355
+ return img[top:(top+height), left:(left+width)]
356
+ elif img.ndim == 3:
357
+ return img[top:(top+height), left:(left+width), :]
358
+
359
+
360
+ def center_crop(img, output_size):
361
+ '''crop image
362
+
363
+ Arguments:
364
+ img {ndarray} -- input image
365
+ output_size {number or sequence} -- the output image size. if sequence, should be [h, w]
366
+
367
+ Raises:
368
+ ValueError -- the input image is large than original image.
369
+
370
+ Returns:
371
+ ndarray image -- return croped ndarray image.
372
+ '''
373
+ if img.ndim == 2:
374
+ img_height, img_width = img.shape
375
+ else:
376
+ img_height, img_width, _ = img.shape
377
+
378
+ if isinstance(output_size, numbers.Number):
379
+ output_size = (int(output_size), int(output_size))
380
+ if output_size[0] > img_height or output_size[1] > img_width:
381
+ raise ValueError('the output_size should not greater than image size, but got {}'.format(output_size))
382
+
383
+ target_height, target_width = output_size
384
+
385
+ top = int(round((img_height - target_height)/2))
386
+ left = int(round((img_width - target_width)/2))
387
+
388
+ return crop(img, top, left, target_height, target_width)
389
+
390
+
391
+ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR):
392
+
393
+ img = crop(img, top, left, height, width)
394
+ img = resize(img, size, interpolation)
395
+ return img
396
+
397
+ def vflip(img):
398
+ return cv2.flip(img, 0)
399
+
400
+ def hflip(img):
401
+ return cv2.flip(img, 1)
402
+
403
+ def flip(img, flip_code):
404
+ return cv2.flip(img, flip_code)
405
+
406
+
407
+ def elastic_transform(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER_LINEAR,
408
+ border_mode=cv2.BORDER_REFLECT_101, random_state=None, approximate=False):
409
+ """Elastic deformation of images as described in [Simard2003]_ (with modifications).
410
+ Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5
411
+ .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
412
+ Convolutional Neural Networks applied to Visual Document Analysis", in
413
+ Proc. of the International Conference on Document Analysis and
414
+ Recognition, 2003.
415
+ """
416
+ if random_state is None:
417
+ random_state = np.random.RandomState(1234)
418
+
419
+ height, width = image.shape[:2]
420
+
421
+ # Random affine
422
+ center_square = np.float32((height, width)) // 2
423
+ square_size = min((height, width)) // 3
424
+ alpha = float(alpha)
425
+ sigma = float(sigma)
426
+ alpha_affine = float(alpha_affine)
427
+
428
+ pts1 = np.float32([center_square + square_size, [center_square[0] + square_size, center_square[1] - square_size],
429
+ center_square - square_size])
430
+ pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
431
+ matrix = cv2.getAffineTransform(pts1, pts2)
432
+
433
+ image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation, borderMode=border_mode)
434
+
435
+ if approximate:
436
+ # Approximate computation smooth displacement map with a large enough kernel.
437
+ # On large images (512+) this is approximately 2X times faster
438
+ dx = (random_state.rand(height, width).astype(np.float32) * 2 - 1)
439
+ cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
440
+ dx *= alpha
441
+
442
+ dy = (random_state.rand(height, width).astype(np.float32) * 2 - 1)
443
+ cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy)
444
+ dy *= alpha
445
+ else:
446
+ dx = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha)
447
+ dy = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha)
448
+
449
+ x, y = np.meshgrid(np.arange(width), np.arange(height))
450
+
451
+ mapx = np.float32(x + dx)
452
+ mapy = np.float32(y + dy)
453
+
454
+ return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode)
455
+
456
+
457
+ def bbox_shift(bboxes, top, left):
458
+ pass
459
+
460
+
461
+ def bbox_vflip(bboxes, img_height):
462
+ """vertical flip the bboxes
463
+ ...........
464
+ . .
465
+ . .
466
+ >...........<
467
+ . .
468
+ . .
469
+ ...........
470
+ Args:
471
+ bbox (ndarray): bbox ndarray [box_nums, 4]
472
+ flip_code (int, optional): [description]. Defaults to 0.
473
+ """
474
+ flipped = bboxes.copy()
475
+ flipped[...,1::2] = img_height - bboxes[...,1::2]
476
+ flipped = flipped[..., [0, 3, 2, 1]]
477
+ return flipped
478
+
479
+
480
+ def bbox_hflip(bboxes, img_width):
481
+ """horizontal flip the bboxes
482
+ ^
483
+ .............
484
+ . . .
485
+ . . .
486
+ . . .
487
+ . . .
488
+ .............
489
+ ^
490
+ Args:
491
+ bbox (ndarray): bbox ndarray [box_nums, 4]
492
+ flip_code (int, optional): [description]. Defaults to 0.
493
+ """
494
+ flipped = bboxes.copy()
495
+ flipped[..., 0::2] = img_width - bboxes[...,0::2]
496
+ flipped = flipped[..., [2, 1, 0, 3]]
497
+ return flipped
498
+
499
+
500
+ def bbox_resize(bboxes, img_size, target_size):
501
+ """resize the bbox
502
+
503
+ Args:
504
+ bboxes (ndarray): bbox ndarray [box_nums, 4]
505
+ img_size (tuple): the image height and width
506
+ target_size (int, or tuple): the target bbox size.
507
+ Int or Tuple, if tuple the shape should be (height, width)
508
+ """
509
+ if isinstance(target_size, numbers.Number):
510
+ target_size = (target_size, target_size)
511
+
512
+ ratio_height = target_size[0]/img_size[0]
513
+ ratio_width = target_size[1]/img_size[1]
514
+
515
+ return bboxes[...,]*[ratio_width,ratio_height,ratio_width,ratio_height]
516
+
517
+
518
+ def bbox_crop(bboxes, top, left, height, width):
519
+ '''crop bbox
520
+
521
+ Arguments:
522
+ img {ndarray} -- image to be croped
523
+ top {int} -- top size
524
+ left {int} -- left size
525
+ height {int} -- croped height
526
+ width {int} -- croped width
527
+ '''
528
+ croped_bboxes = bboxes.copy()
529
+
530
+ right = width + left
531
+ bottom = height + top
532
+
533
+ croped_bboxes[..., 0::2] = bboxes[..., 0::2].clip(left, right) - left
534
+ croped_bboxes[..., 1::2] = bboxes[..., 1::2].clip(top, bottom) - top
535
+
536
+ return croped_bboxes
537
+
538
+ def bbox_pad(bboxes, padding):
539
+ if isinstance(padding, int):
540
+ pad_left = pad_right = pad_top = pad_bottom = padding
541
+ if isinstance(padding, collections.Iterable) and len(padding) == 2:
542
+ pad_left = pad_right = padding[0]
543
+ pad_bottom = pad_top = padding[1]
544
+ if isinstance(padding, collections.Iterable) and len(padding) == 4:
545
+ pad_left = padding[0]
546
+ pad_top = padding[1]
547
+ pad_right = padding[2]
548
+ pad_bottom = padding[3]
549
+
550
+ pad_bboxes = bboxes.copy()
551
+ pad_bboxes[..., 0::2] = bboxes[..., 0::2] + pad_left
552
+ pad_bboxes[..., 1::2] = bboxes[..., 1::2] + pad_top
553
+
554
+ return pad_bboxes