chenzhicun committed
Commit ec08fea
Parent(s): b0cf94d

Initialize web demo.
Files changed:
- IdentityLUT33.txt (+0, -0)
- IdentityLUT64.txt (+0, -0)
- app.py (+119, -0)
- examples/example.jpg (+0, -0)
- models/models_x.py (+329, -0)
- models/trilinear_test.py (+608, -0)
- requirements.txt (+6, -0)
- torchvision_x_functional.py (+554, -0)
IdentityLUT33.txt
ADDED
The diff for this file is too large to render. See raw diff.
IdentityLUT64.txt
ADDED
The diff for this file is too large to render. See raw diff.
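Each identity LUT file is plain text with dim³ lines (33³ = 35,937 or 64³ = 262,144), one "R G B" float triple per lattice vertex; Generator3DLUT_identity in models/models_x.py reads line n = i·dim² + j·dim + k back into a [3, dim, dim, dim] tensor. A minimal sketch of how such a file could be regenerated (the channel-to-axis ordering is an assumption inferred from that parser, not confirmed by the source):

    def write_identity_lut(path, dim=33):
        # One lattice vertex per line, ordered n = i*dim*dim + j*dim + k,
        # to match the reader in models/models_x.py.
        step = 1.0 / (dim - 1)
        with open(path, 'w') as f:
            for i in range(dim):
                for j in range(dim):
                    for k in range(dim):
                        # Assumption: the first value varies fastest (with k).
                        f.write('%f %f %f\n' % (k * step, j * step, i * step))

    write_identity_lut('IdentityLUT33.txt', dim=33)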
app.py
ADDED
@@ -0,0 +1,119 @@
import gradio as gr
from PIL import Image
import torch
from torchvision import transforms
from models.models_x import *
import torchvision_x_functional as TF_x
import torchvision.transforms.functional as TF
import cv2
from timm.models.hub import download_cached_file


cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
trans = transforms.ToTensor()


# sRGB model: an identity LUT plus two learned residual LUTs, blended by a classifier
LUT0 = Generator3DLUT_identity()
LUT1 = Generator3DLUT_zero()
LUT2 = Generator3DLUT_zero()
classifier = Classifier()
trilinear_ = Tritri()
if cuda:
    LUT0 = LUT0.cuda()
    LUT1 = LUT1.cuda()
    LUT2 = LUT2.cuda()
    classifier = classifier.cuda()

# Load pretrained sRGB models
cache = download_cached_file('https://drive.google.com/uc?export=download&id=1tzeECo1m4MBqvfLv4H4SQ7by4YMEP17H',
                             check_hash=False, progress=True)
LUTs = torch.load(cache, map_location=torch.device('cpu'))
LUT0.load_state_dict(LUTs["0"])
LUT1.load_state_dict(LUTs["1"])
LUT2.load_state_dict(LUTs["2"])
LUT0.eval()
LUT1.eval()
LUT2.eval()

cache = download_cached_file('https://drive.google.com/uc?export=download&id=1rQ_p3NMRFxZ52MOYj0jPewYtD3JQTJGi',
                             check_hash=False, progress=True)
classifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
classifier.eval()


# XYZ model: same architecture with separately trained weights
XLUT0 = Generator3DLUT_identity()
XLUT1 = Generator3DLUT_zero()
XLUT2 = Generator3DLUT_zero()
Xclassifier = Classifier()
Xtrilinear_ = Tritri()
if cuda:
    XLUT0 = XLUT0.cuda()
    XLUT1 = XLUT1.cuda()
    XLUT2 = XLUT2.cuda()
    Xclassifier = Xclassifier.cuda()

# Load pretrained XYZ models
cache = download_cached_file('https://drive.google.com/uc?export=download&id=1ossTzgbgpZL4Jy5uhiRJDGfCWw9vOv0c',
                             check_hash=False, progress=True)
XLUTs = torch.load(cache, map_location=torch.device('cpu'))
XLUT0.load_state_dict(XLUTs["0"])
XLUT1.load_state_dict(XLUTs["1"])
XLUT2.load_state_dict(XLUTs["2"])
XLUT0.eval()
XLUT1.eval()
XLUT2.eval()

cache = download_cached_file('https://drive.google.com/uc?export=download&id=1279CoaqQZK-eK83283MERoRxtRbIgRew',
                             check_hash=False, progress=True)
Xclassifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
Xclassifier.eval()


def generate_LUT(img):
    # the classifier predicts one blending weight per basis LUT
    pred = classifier(img).squeeze()

    LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT  # + pred[3] * LUT3.LUT + pred[4] * LUT4.LUT

    return LUT


def generate_XLUT(img):
    pred = Xclassifier(img).squeeze()

    XLUT = pred[0] * XLUT0.LUT + pred[1] * XLUT1.LUT + pred[2] * XLUT2.LUT  # + pred[3] * LUT3.LUT + pred[4] * LUT4.LUT

    return XLUT


def inference(ori_image, models_n):
    with torch.no_grad():
        if models_n == 'sRGB':
            # img = Image.open(ori_image)
            # img = TF.to_tensor(img).type(Tensor)
            img = trans(ori_image)
            img = img.unsqueeze(0)
            LUT = generate_LUT(img)
            result = trilinear_(LUT, img)
            result = result.permute(0, 3, 1, 2)
            ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
            im = Image.fromarray(ndarr)
        elif models_n == 'XYZ':
            img = trans(ori_image)
            img = img.unsqueeze(0)
            XLUT = generate_XLUT(img)
            result = Xtrilinear_(XLUT, img)
            result = result.permute(0, 3, 1, 2)
            ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
            im = Image.fromarray(ndarr)
    return im


inputs = [gr.inputs.Image(type='pil', label='Image to enhance'),
          gr.inputs.Radio(choices=['sRGB', 'XYZ'], type="value", default="sRGB", label="Image color space")]
outputs = [gr.outputs.Image(type='pil', label='Enhanced image')]

title = 'LUT-based image enhancement demo'

gr.Interface(inference, inputs, outputs, title=title, allow_flagging='never',
             examples=[['./examples/example.jpg', 'sRGB']]).launch(enable_queue=True)
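For reference, the enhancement path can be exercised without the Gradio UI; a minimal sketch, assuming the pretrained weights above were downloaded successfully and everything runs on CPU:

    from PIL import Image

    img = Image.open('./examples/example.jpg').convert('RGB')
    out = inference(img, 'sRGB')   # same code path the UI calls; returns a PIL.Image
    out.save('enhanced.jpg')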
examples/example.jpg
ADDED
models/models_x.py
ADDED
@@ -0,0 +1,329 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch
import numpy as np
import math

from models.trilinear_test import bing_lut_trilinearInterplt, Tritri

import time
from PIL import Image
###########################################
# use this module for pytorch 1.x, together with trilinear_cpp
###########################################


def weights_init_normal_classifier(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.xavier_normal_(m.weight.data)

    elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

class resnet18_224(nn.Module):

    def __init__(self, out_dim=5, aug_test=False):
        super(resnet18_224, self).__init__()

        self.aug_test = aug_test
        net = models.resnet18(pretrained=True)
        # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
        # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

        self.upsample = nn.Upsample(size=(224, 224), mode='bilinear')
        net.fc = nn.Linear(512, out_dim)
        self.model = net

    def forward(self, x):

        x = self.upsample(x)
        if self.aug_test:
            # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0)
            x = torch.cat((x, torch.flip(x, [3])), 0)
        f = self.model(x)

        return f

##############################
# Discriminator
##############################


def discriminator_block(in_filters, out_filters, normalization=False):
    """Returns downsampling layers of each discriminator block"""
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
        # layers.append(nn.BatchNorm2d(out_filters))

    return layers

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 128),
            # *discriminator_block(128, 128),
            nn.Conv2d(128, 1, 8, padding=0)
        )

    def forward(self, img_input):
        return self.model(img_input)

class Classifier(nn.Module):
    def __init__(self, in_channels=3):
        super(Classifier, self).__init__()

        self.model = nn.Sequential(
            # nn.Downsample(size=(256,256), mode='bilinear'),
            nn.Upsample(size=(256, 256), mode='bilinear'),  # original

            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32, normalization=True),
            *discriminator_block(32, 64, normalization=True),
            *discriminator_block(64, 128, normalization=True),
            *discriminator_block(128, 128),
            # *discriminator_block(128, 128, normalization=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(128, 3, 8, padding=0),  # 3 outputs: one blending weight per LUT
        )

    def forward(self, img_input):
        return self.model(img_input)


class Classifier_unpaired(nn.Module):
    def __init__(self, in_channels=3):
        super(Classifier_unpaired, self).__init__()

        self.model = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 128),
            # *discriminator_block(128, 128),
            nn.Conv2d(128, 3, 8, padding=0),
        )

    def forward(self, img_input):
        return self.model(img_input)


class Generator3DLUT_identity(nn.Module):
    def __init__(self, dim=33):
        super(Generator3DLUT_identity, self).__init__()
        if dim == 33:
            file = open("IdentityLUT33.txt", 'r')
        elif dim == 64:
            file = open("IdentityLUT64.txt", 'r')
        lines = file.readlines()
        buffer = np.zeros((3, dim, dim, dim), dtype=np.float32)

        for i in range(0, dim):
            for j in range(0, dim):
                for k in range(0, dim):
                    n = i * dim * dim + j * dim + k
                    x = lines[n].split()
                    buffer[0, i, j, k] = float(x[0])
                    buffer[1, i, j, k] = float(x[1])
                    buffer[2, i, j, k] = float(x[2])
        self.LUT = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
        self.TrilinearInterpolation = Tritri()
        # self.trilinearItp = bing_lut_trilinearInterplt()

    def forward(self, x):
        # Tritri returns a single tensor (the original CUDA op returned a (lut, output) pair)
        output = self.TrilinearInterpolation(self.LUT, x)
        # output = self.trilinearItp(self.LUT, x)

        # self.LUT, output = self.TrilinearInterpolation(self.LUT, x)
        return output

class Generator3DLUT_zero(nn.Module):
    def __init__(self, dim=33):
        super(Generator3DLUT_zero, self).__init__()

        self.LUT = nn.Parameter(torch.zeros(3, dim, dim, dim, dtype=torch.float))
        self.TrilinearInterpolation = Tritri()
        # self.trilinearItp = bing_lut_trilinearInterplt()

    def forward(self, x):
        output = self.TrilinearInterpolation(self.LUT, x)
        # output = self.trilinearItp(self.LUT, x)

        return output

class LUT_all(nn.Module):
    def __init__(self,
                 path_LUT="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/LUTs_399.pth",
                 path_classifier="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/classifier_399.pth") -> None:
        super(LUT_all, self).__init__()
        self.classifier = Classifier()
        self.classifier.load_state_dict(torch.load(path_classifier))

        self.LUT0 = Generator3DLUT_identity()
        self.LUT1 = Generator3DLUT_zero()
        self.LUT2 = Generator3DLUT_zero()
        LUTs = torch.load(path_LUT)
        self.LUT0.load_state_dict(LUTs["0"])
        self.LUT1.load_state_dict(LUTs["1"])
        self.LUT2.load_state_dict(LUTs["2"])
        # self.trilinear_ = TrilinearInterpolation()
        # self.trilinear_ = bing_lut_trilinearInterplt()
        self.trilinear_ = Tritri()

    def forward(self, img):
        pred = self.classifier(img).squeeze()

        # squeeze drops the size-1 dimensions
        # LUT = pred[0] * self.LUT0.LUT
        LUT = pred[0] * self.LUT0.LUT + pred[1] * self.LUT1.LUT + pred[2] * self.LUT2.LUT
        output = self.trilinear_(LUT, img)
        # _, output = self.trilinear_(LUT, img)
        return output
        # return LUT



# class TrilinearInterpolationFunction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, lut, x):

#         x = x.contiguous()

#         output = x.new(x.size())
#         dim = lut.size()[-1]
#         shift = dim ** 3
#         binsize = 1.000001 / (dim - 1)
#         W = x.size(2)
#         H = x.size(3)
#         batch = x.size(0)
#         # the `trilinear` package is the authors' custom CUDA extension
#         assert 1 == trilinear.forward(lut,
#                                       x,
#                                       output,
#                                       dim,
#                                       shift,
#                                       binsize,
#                                       W,
#                                       H,
#                                       batch)

#         int_package = torch.IntTensor([dim, shift, W, H, batch])
#         float_package = torch.FloatTensor([binsize])
#         variables = [lut, x, int_package, float_package]

#         ctx.save_for_backward(*variables)

#         return lut, output

#     @staticmethod
#     def backward(ctx, lut_grad, x_grad):

#         lut, x, int_package, float_package = ctx.saved_variables
#         dim, shift, W, H, batch = int_package
#         dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
#         binsize = float(float_package[0])

#         assert 1 == trilinear.backward(x,
#                                        x_grad,
#                                        lut_grad,
#                                        dim,
#                                        shift,
#                                        binsize,
#                                        W,
#                                        H,
#                                        batch)
#         return lut_grad, x_grad


# class TrilinearInterpolation(torch.nn.Module):
#     def __init__(self):
#         super(TrilinearInterpolation, self).__init__()

#     def forward(self, lut, x):
#         return TrilinearInterpolationFunction.apply(lut, x)


class TV_3D(nn.Module):
    def __init__(self, dim=33):
        super(TV_3D, self).__init__()

        # first-difference weights; boundary bins count double
        self.weight_r = torch.ones(3, dim, dim, dim - 1, dtype=torch.float)
        self.weight_r[:, :, :, (0, dim - 2)] *= 2.0
        self.weight_g = torch.ones(3, dim, dim - 1, dim, dtype=torch.float)
        self.weight_g[:, :, (0, dim - 2), :] *= 2.0
        self.weight_b = torch.ones(3, dim - 1, dim, dim, dtype=torch.float)
        self.weight_b[:, (0, dim - 2), :, :] *= 2.0
        self.relu = torch.nn.ReLU()

    def forward(self, LUT):

        dif_r = LUT.LUT[:, :, :, :-1] - LUT.LUT[:, :, :, 1:]
        dif_g = LUT.LUT[:, :, :-1, :] - LUT.LUT[:, :, 1:, :]
        dif_b = LUT.LUT[:, :-1, :, :] - LUT.LUT[:, 1:, :, :]
        # tv: squared first differences penalize a non-smooth LUT
        tv = torch.mean(torch.mul((dif_r ** 2), self.weight_r)) + torch.mean(torch.mul((dif_g ** 2), self.weight_g)) + torch.mean(torch.mul((dif_b ** 2), self.weight_b))

        # mn: positive differences penalize decreases, encouraging monotonicity
        mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))

        return tv, mn


## new by bing ##
if __name__ == '__main__':
    def img_process_256(img):
        # Convert a PIL image (mode=RGB, size 3840x2160, 3 channels) to a tensor of shape [N,C,H,W] (here [1,3,256,256])
        img = img.resize((256, 256))
        trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        img = trans(img)
        img = torch.unsqueeze(img, 0)  # add a batch dimension
        print("img", img.size())
        # # convert from HWC to NCHW, N=1
        # img = np.array(img)
        return img

    def img_process_4k(img):
        # Convert a PIL image (mode=RGB, size 3840x2160, 3 channels) to a tensor of shape [N,C,H,W]
        trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        img = trans(img)
        img = torch.unsqueeze(img, 0)  # add a batch dimension
        print("img", img.size())
        # # convert from HWC to NCHW, N=1
        # img = np.array(img)
        return img


    img_ori = Image.open("/home/elle/bing/proj/code/download-4k-img/picture/%s" % ("X4_Animal2_BIC_g_03.png"))
    img = img_process_256(img_ori)
    img_4k = img_process_4k(img_ori)
    model = LUT_all()

    out = model(img_4k)
    print(out)
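To make the blending in this module concrete: the classifier emits three scalars that weight the identity LUT and the two residual LUTs into one fused [3, 33, 33, 33] lattice, which Tritri then samples; TV_3D supplies the smoothness (tv) and monotonicity (mn) penalties used during training. A minimal CPU sketch, assuming IdentityLUT33.txt sits in the working directory and using stand-in weights in place of real classifier outputs:

    import torch
    from models.models_x import Generator3DLUT_identity, Generator3DLUT_zero, TV_3D
    from models.trilinear_test import Tritri

    lut0 = Generator3DLUT_identity(dim=33)            # reads IdentityLUT33.txt
    lut1, lut2 = Generator3DLUT_zero(), Generator3DLUT_zero()
    w = torch.tensor([1.0, 0.1, -0.05])               # stand-in for classifier(img).squeeze()
    fused = w[0] * lut0.LUT + w[1] * lut1.LUT + w[2] * lut2.LUT   # [3, 33, 33, 33]

    img = torch.rand(1, 3, 64, 64)                    # NCHW input in [0, 1]
    out = Tritri()(fused, img)                        # NHWC output, [1, 64, 64, 3]

    tv, mn = TV_3D(dim=33)(lut0)                      # regularizers over the LUT lattice
    print(out.shape, tv.item(), mn.item())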
models/trilinear_test.py
ADDED
@@ -0,0 +1,608 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F


##new####
# https://github.com/tedyhabtegebrial/PyTorch-Trilinear-Interpolation
class TrilinearIntepolation(nn.Module):
    """TrilinearIntepolation in PyTorch."""

    def __init__(self):
        super(TrilinearIntepolation, self).__init__()

    def sample_at_integer_locs(self, input_feats, index_tensor):
        assert input_feats.ndimension() == 5, 'input_feats should be of shape [Batch,F,D,Height,Width]'
        assert index_tensor.ndimension() == 4, 'index_tensor should be of shape [Batch,Height,Width,3]'
        # first sample pixel locations using nearest neighbour interpolation
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = index_tensor.shape[1], index_tensor.shape[2]

        xy_grid = index_tensor[..., 0:2]
        # 0:2 includes index 0 but excludes 2, so this takes dims 0 and 1 of the last axis
        xy_grid[..., 0] = xy_grid[..., 0] - ((width - 1.0) / 2.0)
        xy_grid[..., 0] = xy_grid[..., 0] / ((width - 1.0) / 2.0)
        xy_grid[..., 1] = xy_grid[..., 1] - ((height - 1.0) / 2.0)
        xy_grid[..., 1] = xy_grid[..., 1] / ((height - 1.0) / 2.0)
        xy_grid = torch.clamp(xy_grid, min=-1.0, max=1.0)
        # clamp bounds every element between min and max
        sampled_in_2d = F.grid_sample(input=input_feats.view(batch_size, num_chans * num_d, height, width),
                                      grid=xy_grid, mode='nearest').view(batch_size, num_chans, num_d, grid_height, grid_width)
        # grid_sample (bilinear sampling): https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html?highlight=grid_sample#torch.nn.functional.grid_sample
        # view: https://blog.csdn.net/york1996/article/details/81949843
        z_grid = index_tensor[..., 2].view(batch_size, 1, 1, grid_height, grid_width)
        z_grid = z_grid.long().clamp(min=0, max=num_d - 1)
        # .long() casts the tensor to int64
        z_grid = z_grid.expand(batch_size, num_chans, 1, grid_height, grid_width)
        # expand broadcasts size-1 dimensions: https://blog.csdn.net/weixin_42782150/article/details/108615706
        # here expand broadcasts dim=1 up to num_chans
        sampled_in_3d = sampled_in_2d.gather(2, z_grid).squeeze(2)
        return sampled_in_3d


    def forward(self, input_feats, sampling_grid):
        assert input_feats.ndimension() == 5, 'input_feats should be of shape [B,F,D,H,W]'
        assert sampling_grid.ndimension() == 4, 'sampling_grid should be of shape [B,H,W,3]'
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = sampling_grid.shape[1], sampling_grid.shape[2]
        # make sure sampling grid lies between -1, 1
        sampling_grid = torch.clamp(sampling_grid, min=-1.0, max=1.0)
        # map to 0,1
        sampling_grid = (sampling_grid + 1) / 2.0
        # Scale grid to floating point pixel locations
        scaling_factor = torch.FloatTensor([width - 1.0, height - 1.0, num_d - 1.0]).to(input_feats.device).view(1, 1, 1, 3)
        sampling_grid = scaling_factor * sampling_grid
        # Now sampling grid is between [0, w-1; 0,h-1; 0,d-1]
        x, y, z = torch.split(sampling_grid, split_size_or_sections=1, dim=3)
        # (x,y,z) are the floating-point inputs (here, each pixel's RGB value)
        # (x_0,y_0,z_0) are those inputs floored
        # split the size-3 dimension of sampling_grid into chunks of size 1
        x_0, y_0, z_0 = torch.split(sampling_grid.floor(), split_size_or_sections=1, dim=3)
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0
        print("v:", x_0, y_0, z_0)
        print("s:", x_0.size(), y_0.size(), z_0.size())
        print("size,cat", torch.cat([x_0, y_0, z_0], dim=3).size())
        u, v, w = map(lambda x: x.view(batch_size, 1, grid_height, grid_width).expand(
            batch_size, num_chans, grid_height, grid_width), [u, v, w])
        c_000 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_0, z_0], dim=3))
        # torch.cat concatenates a sequence of tensors along a given dimension
        c_001 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_0, z_1], dim=3))
        c_010 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_1, z_0], dim=3))
        c_011 = self.sample_at_integer_locs(input_feats, torch.cat([x_0, y_1, z_1], dim=3))
        c_100 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_0, z_0], dim=3))
        c_101 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_0, z_1], dim=3))
        c_110 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_1, z_0], dim=3))
        c_111 = self.sample_at_integer_locs(input_feats, torch.cat([x_1, y_1, z_1], dim=3))
        c_xyz = (1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 + \
                (1.0 - u) * (1.0 - v) * (w) * c_001 + \
                (1.0 - u) * (v) * (1.0 - w) * c_010 + \
                (1.0 - u) * (v) * (w) * c_011 + \
                (u) * (1.0 - v) * (1.0 - w) * c_100 + \
                (u) * (1.0 - v) * (w) * c_101 + \
                (u) * (v) * (1.0 - w) * c_110 + \
                (u) * (v) * (w) * c_111
        return c_xyz

# class bing_lut_trilinearInterplt(nn.Module):

#     def __init__(self):
#         super(bing_lut_trilinearInterplt, self).__init__()

#     def test(self, LUT, img_input):
#         # batch_size, num_chans, height, width = img_input.shape
#         # grid_height, grid_width = LUT.shape[1], LUT.shape[2]
#         grid_in = img_input.transpose(1, 2).transpose(2, 3)
#         # img_input was NCHW; change to NHWC
#         xy_grid = grid_in[..., 0:2]
#         yz_grid = grid_in[..., 1:3]
#         # take only channels 0 and 1 of the 3 (0:2 excludes 2)
#         input_LUT = LUT[:, :, 0, :]
#         input_LUT_ori = input_LUT.squeeze(2)
#         # LUT [33,33,33,3] -> [33,33,3]; the dim=2 data is discarded
#         input_LUT = input_LUT_ori[..., 0:2]
#         input_LUT2 = input_LUT_ori[..., 1:]
#         print("input_LUT2.size()", input_LUT2.size())
#         # LUT [33,33,2]
#         input_LUT = input_LUT.transpose(1, 2).transpose(0, 1)
#         input_LUT2 = input_LUT2.transpose(1, 2).transpose(0, 1)
#         # LUT [2,33,33]
#         input_LUT = input_LUT.unsqueeze(0)
#         input_LUT2 = input_LUT2.unsqueeze(0)
#         print(input_LUT.size())
#         print(input_LUT2.size())
#         print(grid_in.size())
#         sampled_in_2d = F.grid_sample(input=input_LUT, grid=xy_grid, mode='nearest')
#         # .view(batch_size, num_chans, num_d, grid_height, grid_width)
#         sampled_in_2d_2 = F.grid_sample(input=input_LUT2, grid=yz_grid, mode='nearest')
#         # .view(batch_size, num_chans, num_d, grid_height, grid_width)

#         # print("sampled_in_2d.size()", sampled_in_2d.size())
#         # print("sampled_in_2d.size()", sampled_in_2d_2.size())
#         # # [1,2,2160,3840]
#         # print("ss")
#         # print(sampled_in_2d.size())
#         # print(sampled_in_2d_2.size())
#         res = torch.cat([sampled_in_2d, sampled_in_2d_2[:, 1:, :, :]], dim=1)
#         print(res.size())
#         return res
#         # z_grid = grid_in[..., 2]
#         # print(z_grid.size())
#         # # [1,2160,3840]
#         # print("sss")



#     def gen_Cout_ijk(self, LUT, x_i, y_i, z_i):
#         # def gen_Cout_ijk(LUT, x_i, y_i, z_i, channel=3):
#         # LUT size [3,33,33,33]
#         # x_i, y_i, z_i size [1,1,2160,3840]
#         # N=batch_size
#         # img_input.size()=[1,3,2160,3840]
#         # LUT.size()=[3,33,33,33]
#         # assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
#         channel = 3
#         batch_size, _, height, width = x_i.size()
#         print(batch_size, height, width)
#         output = torch.zeros([batch_size, channel, height, width])
#         # set the output size to [1,3,2160,3840]
#         if batch_size == 1:
#             # x_i = x_i.view(height*width)
#             # y_i = y_i.view(height*width)
#             # z_i = z_i.view(height*width)
#             x_i = x_i.view(height*width).long()
#             y_i = y_i.view(height*width).long()
#             z_i = z_i.view(height*width).long()
#             # x_i = x_i.view(1, height*width)
#             # y_i = y_i.view(1, height*width)
#             # z_i = z_i.view(1, height*width)
#             # 2-D tensor, [1, 2160*3840]
#             # xyz_i = torch.cat([x_i, y_i, z_i], dim=0)
#             # # xyz_i 2-D tensor, [3, 2160*3840]

#             # print("xyz_i.size()", xyz_i.size())
#         else:
#             print("error: batch size must be 1")
#         for i in range(height*width):
#             h_index = int(i / width)
#             w_index = int(i % width)
#             # print(h_index)
#             # print(w_index)
#             # print(x_i.size())
#             # print(batch_size)
#             # print(output.size())
#             # print(output[0, 0, h_index, w_index])
#             if (i % 10000 == 0):
#                 print(i)
#             output[batch_size-1, 0, h_index, w_index] = LUT[x_i[i], y_i[i], z_i[i], 0]
#             output[batch_size-1, 1, h_index, w_index] = LUT[x_i[i], y_i[i], z_i[i], 1]
#             output[batch_size-1, 2, h_index, w_index] = LUT[x_i[i], y_i[i], z_i[i], 2]

#         # x_i = x_i.view(batch_size, height*width)
#         # y_i = y_i.view(batch_size, height*width)
#         # z_i = z_i.view(batch_size, height*width)
#         # 1, 2160*3840


#         return output


#     def forward(self, LUT, img_input):
#         assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]'
#         # N=batch_size
#         # img_input.size()=[1,3,2160,3840]
#         # LUT.size()=[3,33,33,33]
#         assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
#         batch_size, num_chans, height, width = img_input.shape
#         dim = LUT.shape[1]  # M
#         img_size = img_input.size()
#         Cmax = 255.0
#         s = Cmax / dim
#         r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
#         # split [1,3,2160,3840] along dim 1 into three [1,1,2160,3840] parts
#         # r,g,b.size()=[1,1,2160,3840]
#         # r = img_input[:, 0, :, :]
#         # g = img_input[:, 1, :, :]
#         # b = img_input[:, 2, :, :]
#         x = r / s
#         y = g / s
#         z = b / s
#         # tmptmp = self.test(LUT, img_input)
#         # x,y,z.size=[1,1,2160,3840]
#         # x_0,y_0,z_0.size=[1,1,2160,3840]
#         # x_1,y_1,z_1.size=[1,1,2160,3840]
#         x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
#         x_1, y_1, z_1 = x_0+1.0, y_0+1.0, z_0+1.0
#         u, v, w = x-x_0, y-y_0, z-z_0
#         # u,v,w.size=[1,1,2160,3840]
#         # print("x_0.size", x_0.size())
#         c_000 = self.test(LUT, torch.cat([x_0, y_0, z_0], dim=1))
#         print(c_000.size())
#         # x_i are the lattice vertices, shape [1,1,2160,3840]
#         # the outputs c_xxx are the LUT values at those vertices, shape [1,3,2160,3840]
#         c_100 = self.test(LUT, torch.cat([x_1, y_0, z_0], dim=1))
#         c_010 = self.test(LUT, torch.cat([x_0, y_1, z_0], dim=1))
#         c_110 = self.test(LUT, torch.cat([x_1, y_1, z_0], dim=1))
#         c_001 = self.test(LUT, torch.cat([x_0, y_0, z_1], dim=1))
#         c_101 = self.test(LUT, torch.cat([x_1, y_0, z_1], dim=1))
#         c_011 = self.test(LUT, torch.cat([x_0, y_1, z_1], dim=1))
#         c_111 = self.test(LUT, torch.cat([x_1, y_1, z_1], dim=1))

#         # c_000 = self.gen_Cout_ijk(LUT, x_0, y_0, z_0)
#         # # x_i are the lattice vertices, shape [1,1,2160,3840]
#         # # the outputs c_xxx are the LUT values at those vertices, shape [1,3,2160,3840]
#         # c_100 = self.gen_Cout_ijk(LUT, x_1, y_0, z_0)
#         # c_010 = self.gen_Cout_ijk(LUT, x_0, y_1, z_0)
#         # c_110 = self.gen_Cout_ijk(LUT, x_1, y_1, z_0)
#         # c_001 = self.gen_Cout_ijk(LUT, x_0, y_0, z_1)
#         # c_101 = self.gen_Cout_ijk(LUT, x_1, y_0, z_1)
#         # c_011 = self.gen_Cout_ijk(LUT, x_0, y_1, z_1)
#         # c_111 = self.gen_Cout_ijk(LUT, x_1, y_1, z_1)
#         c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
#                 (1.0-u)*(1.0-v)*(w)*c_001 + \
#                 (1.0-u)*(v)*(1.0-w)*c_010 + \
#                 (1.0-u)*(v)*(w)*c_011 + \
#                 (u)*(1.0-v)*(1.0-w)*c_100 + \
#                 (u)*(1.0-v)*(w)*c_101 + \
#                 (u)*(v)*(1.0-w)*c_110 + \
#                 (u)*(v)*(w)*c_111
#         # broadcasting; output [1,3,2160,3840]
#         print("c_xyz", c_xyz.size())
#         return c_xyz

#         # id100 = x_0 + 1.0 + y_0 * dim + z_0 * dim * dim
#         # id010 = x_0 + (y_0 + 1.0) * dim + z_0 * dim * dim
#         # id110 = x_0 + 1.0 + (y_0 + 1.0) * dim + z_0 * dim * dim
#         # id001 = x_0 + y_0 * dim + (z_0 + 1.0) * dim * dim
#         # id101 = x_0 + 1.0 + y_0 * dim + (z_0 + 1.0) * dim * dim
#         # id011 = x_0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim
#         # id111 = x_0 + 1.0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim

#         # w000 = (1.0-u)*(1-v)*(1-w)
#         # # probably needs to become an element-wise product
#         # w100 = u*(1-v)*(1-w)
#         # w010 = (1-u)*v*(1-w)
#         # w110 = u*v*(1-w)
#         # w001 = (1-u)*(1-v)*w
#         # w101 = u*(1-v)*w
#         # w011 = (1-u)*v*w
#         # w111 = u*v*w
#         # output =

#         # print("v:", x_0, y_0, z_0)
#         # print("s:", x_0.size(), y_0.size(), z_0.size())
#         # u, v, w = u/s, v/s, w/s
#         # c_000 = self.gen_Cout_ijk(x_0, y_0, z_0)
#         # c_100 = self.gen_Cout_ijk(x_1, y_0, z_0)
#         # c_010 = self.gen_Cout_ijk(x_0, y_1, z_0)
#         # c_110 = self.gen_Cout_ijk(x_1, y_1, z_0)
#         # c_001 = self.gen_Cout_ijk(x_0, y_0, z_1)
#         # c_101 = self.gen_Cout_ijk(x_1, y_0, z_1)
#         # c_011 = self.gen_Cout_ijk(x_0, y_1, z_1)
#         # c_111 = self.gen_Cout_ijk(x_1, y_1, z_1)


#         # c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
#         #         (1.0-u)*(1.0-v)*(w)*c_001 + \
#         #         (1.0-u)*(v)*(1.0-w)*c_010 + \
#         #         (1.0-u)*(v)*(w)*c_011 + \
#         #         (u)*(1.0-v)*(1.0-w)*c_100 + \
#         #         (u)*(1.0-v)*(w)*c_101 + \
#         #         (u)*(v)*(1.0-w)*c_110 + \
#         #         (u)*(v)*(w)*c_111
#         # return c_xyz

class Tritri(nn.Module):

    def __init__(self):
        super(Tritri, self).__init__()

    def forward(self, LUT, img):
        img = (img - .5) * 2.
        # grid_sample expects NxDxHxWx3 (1x1xHxWx3)
        img = img.permute(0, 2, 3, 1)[:, None]
        # add batch dim to LUT
        LUT = LUT[None]
        # grid sample
        result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)
        # drop added dimensions and permute back
        result = result[:, :, 0].permute(0, 2, 3, 1)
        return result



class bing_lut_trilinearInterplt(nn.Module):

    def __init__(self):
        super(bing_lut_trilinearInterplt, self).__init__()

    def test(self, LUT, img_input):
        # batch_size, num_chans, height, width = img_input.shape
        # grid_height, grid_width = LUT.shape[1], LUT.shape[2]
        grid_in = img_input.transpose(1, 2).transpose(2, 3)
        # 1
        # img_input was NCHW; change to NHWC
        xy_grid = grid_in[..., 0:2]
        yz_grid = grid_in[..., 1:3]
        # 23
        # take only channels 0 and 1 of the 3 (0:2 excludes 2)

        # the correct LUT layout is [3,33,33,33]
        # it had mistakenly been treated as [33,33,33,3] here before
        input_LUT = LUT[:, :, :, 0:1]
        input_LUT_ori = input_LUT.squeeze(3)
        # 45

        # [3,33,33,33] -> [3,33,33]; the dim=3 data is discarded

        # input_LUT = LUT[:, :, 0, :]
        # input_LUT_ori = input_LUT.squeeze(2)
        # # LUT [33,33,33,3] -> [33,33,3]; the dim=2 data is discarded

        input_LUT = input_LUT_ori[0:2, ...]
        input_LUT2 = input_LUT_ori[1:, ...]
        input_LUT = input_LUT.unsqueeze(0)
        input_LUT2 = input_LUT2.unsqueeze(0)
        # 6-9

        # both are [1,2,33,33]
        # print(input_LUT.size())
        # print("dtype:")
        # print(input_LUT.dtype)
        # print(input_LUT2.dtype)
        # print(xy_grid.dtype)
        # print(yz_grid.dtype)
        # input_LUT.int()
        # input_LUT2.int()
        # xy_grid.int()
        # yz_grid.int()

        # # print(grid_in.size())
        sampled_in_2d = F.grid_sample(input=input_LUT, grid=xy_grid, mode='nearest', align_corners=False)
        # .view(batch_size, num_chans, num_d, grid_height, grid_width)
        sampled_in_2d_2 = F.grid_sample(input=input_LUT2, grid=yz_grid, mode='nearest', align_corners=False)
        # .view(batch_size, num_chans, num_d, grid_height, grid_width)
        # 10
        res = torch.cat([sampled_in_2d, sampled_in_2d_2[:, 1:, :, :]], dim=1)
        # print(res.size())
        return res

    def forward(self, LUT, img_input):
        assert img_input.ndimension() == 4, 'img_input should be of shape [N,C,H,W]'
        # N=batch_size
        # img_input.size()=[1,3,2160,3840]
        # LUT.size()=[3,33,33,33]
        assert LUT.ndimension() == 4, 'LUT should be of shape [C,M,M,M](M=33)'
        # batch_size, num_chans, height, width = img_input.shape
        dim = LUT.shape[1]  # M
        # img_size = img_input.size()
        # Cmax = 1.00001
        Cmax = 10
        s = Cmax / (dim - 1.0)
        s = torch.Tensor([s])
        # thanks, rubber duck!!  # data types int64 and int32 do not match in BroadcastRel

        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        # split [1,3,2160,3840] along dim 1 into three [1,1,2160,3840] parts
        # r,g,b.size()=[1,1,2160,3840]
        # r = img_input[:, 0, :, :]
        # g = img_input[:, 1, :, :]
        # b = img_input[:, 2, :, :]
        s = s.to(r.device)
        x = r / s
        y = g / s
        z = b / s
        # tmptmp = self.test(LUT, img_input)
        # x,y,z.size=[1,1,2160,3840]
        # x_0,y_0,z_0.size=[1,1,2160,3840]
        # x_1,y_1,z_1.size=[1,1,2160,3840]
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0
        # u,v,w.size=[1,1,2160,3840]
        # print("x_0.size", x_0.size())

        c_000 = self.test(LUT, torch.cat([x_0, y_0, z_0], dim=1))
        # print(c_000.size())
        # x_i are the lattice vertices, shape [1,1,2160,3840]
        # the outputs c_xxx are the LUT values at those vertices, shape [1,3,2160,3840]
        c_100 = self.test(LUT, torch.cat([x_1, y_0, z_0], dim=1))
        c_010 = self.test(LUT, torch.cat([x_0, y_1, z_0], dim=1))
        c_110 = self.test(LUT, torch.cat([x_1, y_1, z_0], dim=1))
        c_001 = self.test(LUT, torch.cat([x_0, y_0, z_1], dim=1))
        c_101 = self.test(LUT, torch.cat([x_1, y_0, z_1], dim=1))
        c_011 = self.test(LUT, torch.cat([x_0, y_1, z_1], dim=1))
        c_111 = self.test(LUT, torch.cat([x_1, y_1, z_1], dim=1))

        c_xyz = (1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 + \
                (1.0 - u) * (1.0 - v) * (w) * c_001 + \
                (1.0 - u) * (v) * (1.0 - w) * c_010 + \
                (1.0 - u) * (v) * (w) * c_011 + \
                (u) * (1.0 - v) * (1.0 - w) * c_100 + \
                (u) * (1.0 - v) * (w) * c_101 + \
                (u) * (v) * (1.0 - w) * c_110 + \
                (u) * (v) * (w) * c_111
        # broadcasting; output [1,3,2160,3840]
        print("c_xyz", c_xyz.size())
        return c_xyz

class bing_lut_trilinearInterplt_backup(nn.Module):

    def __init__(self):
        super(bing_lut_trilinearInterplt_backup, self).__init__()

    def test(self, LUT, img_input):
        # batch_size, num_chans, height, width = img_input.shape
        # grid_height, grid_width = LUT.shape[1], LUT.shape[2]
        grid_in = img_input.transpose(1, 2).transpose(2, 3)
        # 1
        # img_input was NCHW; change to NHWC
        xy_grid = grid_in[..., 0:2]
        yz_grid = grid_in[..., 1:3]
        # 23
        # take only channels 0 and 1 of the 3 (0:2 excludes 2)

        # the correct LUT layout is [3,33,33,33]
        # it had mistakenly been treated as [33,33,33,3] here before
        input_LUT = LUT[:, :, :, 0:1]
        input_LUT_ori = input_LUT.squeeze(3)
        # 45

        # [3,33,33,33] -> [3,33,33]; the dim=3 data is discarded

        # input_LUT = LUT[:, :, 0, :]
        # input_LUT_ori = input_LUT.squeeze(2)
        # # LUT [33,33,33,3] -> [33,33,3]; the dim=2 data is discarded

        input_LUT = input_LUT_ori[0:2, ...]
        input_LUT2 = input_LUT_ori[1:, ...]
        input_LUT = input_LUT.unsqueeze(0)
        input_LUT2 = input_LUT2.unsqueeze(0)
        # 6-9

        # both are [1,2,33,33]
        # print(input_LUT.size())
        # print("dtype:")
        # print(input_LUT.dtype)
        # print(input_LUT2.dtype)
        # print(xy_grid.dtype)
        # print(yz_grid.dtype)
        # input_LUT.int()
        # input_LUT2.int()
        # xy_grid.int()
        # yz_grid.int()

        # # print(grid_in.size())
        sampled_in_2d = F.grid_sample(input=input_LUT, grid=xy_grid, mode='nearest')
        # .view(batch_size, num_chans, num_d, grid_height, grid_width)
        sampled_in_2d_2 = F.grid_sample(input=input_LUT2, grid=yz_grid, mode='nearest')
        # .view(batch_size, num_chans, num_d, grid_height, grid_width)
        # 10
        res = torch.cat([sampled_in_2d, sampled_in_2d_2[:, 1:, :, :]], dim=1)
        # print(res.size())
        return res

    def forward(self, LUT, img_input):
        assert img_input.ndimension() == 4, 'img_input should be of shape [N,C,H,W]'
        # N=batch_size
        # img_input.size()=[1,3,2160,3840]
        # LUT.size()=[3,33,33,33]
        assert LUT.ndimension() == 4, 'LUT should be of shape [C,M,M,M](M=33)'
        # batch_size, num_chans, height, width = img_input.shape
        dim = LUT.shape[1]  # M
        # img_size = img_input.size()
        Cmax = 255.0
        s = Cmax / dim
        s = torch.Tensor([s])
        # thanks, rubber duck!!  # data types int64 and int32 do not match in BroadcastRel

        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        # split [1,3,2160,3840] along dim 1 into three [1,1,2160,3840] parts
        # r,g,b.size()=[1,1,2160,3840]
        # r = img_input[:, 0, :, :]
        # g = img_input[:, 1, :, :]
        # b = img_input[:, 2, :, :]
        x = r / s
        y = g / s
        z = b / s
        # tmptmp = self.test(LUT, img_input)
        # x,y,z.size=[1,1,2160,3840]
        # x_0,y_0,z_0.size=[1,1,2160,3840]
        # x_1,y_1,z_1.size=[1,1,2160,3840]
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0
        # u,v,w.size=[1,1,2160,3840]
        # print("x_0.size", x_0.size())

        c_000 = self.test(LUT, torch.cat([x_0, y_0, z_0], dim=1))
        # print(c_000.size())
        # x_i are the lattice vertices, shape [1,1,2160,3840]
        # the outputs c_xxx are the LUT values at those vertices, shape [1,3,2160,3840]
        c_100 = self.test(LUT, torch.cat([x_1, y_0, z_0], dim=1))
        c_010 = self.test(LUT, torch.cat([x_0, y_1, z_0], dim=1))
        c_110 = self.test(LUT, torch.cat([x_1, y_1, z_0], dim=1))
        c_001 = self.test(LUT, torch.cat([x_0, y_0, z_1], dim=1))
        c_101 = self.test(LUT, torch.cat([x_1, y_0, z_1], dim=1))
        c_011 = self.test(LUT, torch.cat([x_0, y_1, z_1], dim=1))
        c_111 = self.test(LUT, torch.cat([x_1, y_1, z_1], dim=1))

        # c_000 = self.gen_Cout_ijk(LUT, x_0, y_0, z_0)
        # # x_i are the lattice vertices, shape [1,1,2160,3840]
        # # the outputs c_xxx are the LUT values at those vertices, shape [1,3,2160,3840]
        # c_100 = self.gen_Cout_ijk(LUT, x_1, y_0, z_0)
        # c_010 = self.gen_Cout_ijk(LUT, x_0, y_1, z_0)
        # c_110 = self.gen_Cout_ijk(LUT, x_1, y_1, z_0)
        # c_001 = self.gen_Cout_ijk(LUT, x_0, y_0, z_1)
        # c_101 = self.gen_Cout_ijk(LUT, x_1, y_0, z_1)
        # c_011 = self.gen_Cout_ijk(LUT, x_0, y_1, z_1)
        # c_111 = self.gen_Cout_ijk(LUT, x_1, y_1, z_1)
        c_xyz = (1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 + \
                (1.0 - u) * (1.0 - v) * (w) * c_001 + \
                (1.0 - u) * (v) * (1.0 - w) * c_010 + \
                (1.0 - u) * (v) * (w) * c_011 + \
                (u) * (1.0 - v) * (1.0 - w) * c_100 + \
                (u) * (1.0 - v) * (w) * c_101 + \
                (u) * (v) * (1.0 - w) * c_110 + \
                (u) * (v) * (w) * c_111
        # broadcasting; output [1,3,2160,3840]
        print("c_xyz", c_xyz.size())
        return c_xyz



# @staticmethod
# def backward(ctx, lut_grad, x_grad):

#     lut, x, int_package, float_package = ctx.saved_variables
#     dim, shift, W, H, batch = int_package
#     dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
#     binsize = float(float_package[0])

#     assert 1 == trilinear.backward(x,
#                                    x_grad,
#                                    lut_grad,
#                                    dim,
#                                    shift,
#                                    binsize,
#                                    W,
#                                    H,
#                                    batch)
#     return lut_grad, x_grad

class Tri(nn.Module):
    def __init__(self):
        super(Tri, self).__init__()

if __name__ == '__main__':
    # input_features: shape [B, num_channels, depth, height, width]
    # sampling_grid: shape [B, depth, height, 3]
    data = torch.rand(1, 32, 16, 128, 128)
    # data = torch.rand(1, 3, 16, 128, 128)
    sampling_grid = (torch.rand(1, 256, 256, 3) - 0.5) * 2.0
    data = data.float().cuda(0)
    sampling_grid = sampling_grid.float().cuda(0)
    trilinear_interpolation = TrilinearIntepolation().cuda(0)
    # LUT.type() torch.cuda.FloatTensor
    # LUT.size() torch.Size([3, 33, 33, 33])
    # img: torch.Size([1, 3, 2160, 3840])
    data2 = torch.rand(1, 3, 2160, 3840)
    # LUT2 = torch.rand(33, 33, 33, 3)
    LUT2 = torch.rand(3, 33, 33, 33)

    trilinear_interpolation2 = bing_lut_trilinearInterplt()
    t_start = time.time()
    interp_data2 = trilinear_interpolation2(LUT2, data2)

    # interpolated_data = trilinear_interpolation(data, sampling_grid)
    # print(interpolated_data.shape)
    torch.cuda.synchronize()
    print('time per iteration ', time.time() - t_start)
    # for i in range(100):
    #     t_start = time.time()
    #     interpolated_data = trilinear_interpolation(data, sampling_grid)
    #     print(interpolated_data.shape)
    #     torch.cuda.synchronize()
    #     print('time per iteration ', time.time() - t_start)
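Tritri is the interpolation path the demo actually uses: it rescales the image's RGB values to [-1, 1] and treats them as a 3D sampling grid into the LUT volume via F.grid_sample, so the whole lookup is one batched trilinear-sampling call. A minimal sketch of its input/output contract on CPU:

    import torch
    from models.trilinear_test import Tritri

    lut = torch.rand(3, 33, 33, 33)   # [C, D, H, W] color lattice
    img = torch.rand(1, 3, 8, 8)      # NCHW image in [0, 1]
    out = Tritri()(lut, img)
    print(out.shape)                  # torch.Size([1, 8, 8, 3]) -- note the NHWC layout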
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch~=1.11.0
torchvision~=0.12.0
opencv-python~=4.5.5.64
pillow~=9.1.1
numpy~=1.22.3
scipy~=1.8.1
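Note that requirements.txt pins six packages, but app.py also imports gradio (the UI framework) and timm (for download_cached_file), which are not listed here. If they needed to be declared explicitly, a supplement would add two more lines:

    gradio
    timm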
torchvision_x_functional.py
ADDED
@@ -0,0 +1,554 @@
1 |
+
import collections
|
2 |
+
import numbers
|
3 |
+
from functools import wraps
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from scipy.ndimage.filters import gaussian_filter
|
10 |
+
|
11 |
+
__numpy_type_map = {
|
12 |
+
'float64': torch.DoubleTensor,
|
13 |
+
'float32': torch.FloatTensor,
|
14 |
+
'float16': torch.HalfTensor,
|
15 |
+
'int64': torch.LongTensor,
|
16 |
+
'int32': torch.IntTensor,
|
17 |
+
'int16': torch.ShortTensor,
|
18 |
+
'uint16': torch.ShortTensor,
|
19 |
+
'int8': torch.CharTensor,
|
20 |
+
'uint8': torch.ByteTensor,
|
21 |
+
}
|
22 |
+
|
23 |
+
'''image functional utils
|
24 |
+
|
25 |
+
'''
|
26 |
+
|
27 |
+
# NOTE: all the function should recive the ndarray like image, should be W x H x C or W x H
|
28 |
+
|
29 |
+
# 如果将所有输出的维度够搞成height,width,channel 那么可以不用to_tensor??, 不行
|
30 |
+
def preserve_channel_dim(func):
|
31 |
+
"""Preserve dummy channel dim."""
|
32 |
+
@wraps(func)
|
33 |
+
def wrapped_function(img, *args, **kwargs):
|
34 |
+
shape = img.shape
|
35 |
+
result = func(img, *args, **kwargs)
|
36 |
+
if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2:
|
37 |
+
result = np.expand_dims(result, axis=-1)
|
38 |
+
return result
|
39 |
+
|
40 |
+
return wrapped_function
|
41 |
+
|
42 |
+
|
43 |
+
def _is_tensor_image(img):
|
44 |
+
return torch.is_tensor(img) and img.ndimension() == 3
|
45 |
+
|
46 |
+
|
47 |
+
def _is_numpy_image(img):
|
48 |
+
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
49 |
+
|
50 |
+
|
51 |
+
def to_tensor(img):
|
52 |
+
'''convert numpy.ndarray to torch tensor. \n
|
53 |
+
if the image is uint8 , it will be divided by 255;\n
|
54 |
+
if the image is uint16 , it will be divided by 65535;\n
|
55 |
+
if the image is float , it will not be divided, we suppose your image range should between [0~1] ;\n
|
56 |
+
|
57 |
+
Arguments:
|
58 |
+
img {numpy.ndarray} -- image to be converted to tensor.
|
59 |
+
'''
|
60 |
+
if not _is_numpy_image(img):
|
61 |
+
raise TypeError('data should be numpy ndarray. but got {}'.format(type(img)))
|
62 |
+
|
63 |
+
if img.ndim == 2:
|
64 |
+
img = img[:, :, None]
|
65 |
+
|
66 |
+
if img.dtype == np.uint8:
|
67 |
+
img = img.astype(np.float32)/255
|
68 |
+
elif img.dtype == np.uint16:
|
69 |
+
img = img.astype(np.float32)/65535
|
70 |
+
elif img.dtype in [np.float32, np.float64]:
|
71 |
+
img = img.astype(np.float32)/1
|
72 |
+
else:
|
73 |
+
raise TypeError('{} is not support'.format(img.dtype))
|
74 |
+
|
75 |
+
img = torch.from_numpy(img.transpose((2, 0, 1)))
|
76 |
+
|
77 |
+
return img
|
78 |
+
|
79 |
+
|
80 |
+
def to_pil_image(tensor):
|
81 |
+
# TODO
|
82 |
+
pass
|
83 |
+
|
84 |
+
|
85 |
+
def to_tiff_image(tensor):
|
86 |
+
# TODO
|
87 |
+
pass
|
88 |
+
|
89 |
+
|
90 |
+
def normalize(tensor, mean, std, inplace=False):
|
91 |
+
"""Normalize a tensor image with mean and standard deviation.
|
92 |
+
|
93 |
+
.. note::
|
94 |
+
This transform acts out of place by default, i.e., it does not mutates the input tensor.
|
95 |
+
|
96 |
+
See :class:`~torchsat.transforms.Normalize` for more details.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
100 |
+
mean (sequence): Sequence of means for each channel.
|
101 |
+
std (sequence): Sequence of standard deviations for each channel.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Tensor: Normalized Tensor image.
|
105 |
+
"""
|
106 |
+
if not _is_tensor_image(tensor):
|
107 |
+
raise TypeError('tensor is not a torch image.')
|
108 |
+
|
109 |
+
if not inplace:
|
110 |
+
tensor = tensor.clone()
|
111 |
+
|
112 |
+
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
|
113 |
+
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
|
114 |
+
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
|
115 |
+
return tensor
|
116 |
+
|
117 |
+
def noise(img, mode='gaussain', percent=0.02):
|
118 |
+
"""
|
119 |
+
TODO: Not good for uint16 data
|
120 |
+
"""
|
121 |
+
original_dtype = img.dtype
|
122 |
+
if mode == 'gaussian':
|
123 |
+
mean = 0
|
124 |
+
var = 0.1
|
125 |
+
sigma = var*0.5
|
126 |
+
|
127 |
+
if img.ndim == 2:
|
128 |
+
h, w = img.shape
|
129 |
+
gauss = np.random.normal(mean, sigma, (h, w))
|
130 |
+
else:
|
131 |
+
h, w, c = img.shape
|
132 |
+
gauss = np.random.normal(mean, sigma, (h, w, c))
|
133 |
+
|
134 |
+
if img.dtype not in [np.float32, np.float64]:
|
135 |
+
gauss = gauss * np.iinfo(img.dtype).max
|
136 |
+
img = np.clip(img.astype(np.float) + gauss, 0, np.iinfo(img.dtype).max)
|
137 |
+
else:
|
138 |
+
img = np.clip(img.astype(np.float) + gauss, 0, 1)
|
139 |
+
|
140 |
+
elif mode == 'salt':
|
141 |
+
print(img.dtype)
|
142 |
+
s_vs_p = 1
|
143 |
+
num_salt = np.ceil(percent * img.size * s_vs_p)
|
144 |
+
coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape])
|
145 |
+
|
146 |
+
if img.dtype in [np.float32, np.float64]:
|
147 |
+
img[coords] = 1
|
148 |
+
else:
|
149 |
+
img[coords] = np.iinfo(img.dtype).max
|
150 |
+
print(img.dtype)
|
151 |
+
elif mode == 'pepper':
|
152 |
+
s_vs_p = 0
|
153 |
+
num_pepper = np.ceil(percent * img.size * (1. - s_vs_p))
|
154 |
+
coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape])
|
155 |
+
img[coords] = 0
|
156 |
+
|
157 |
+
elif mode == 's&p':
|
158 |
+
s_vs_p = 0.5
|
159 |
+
|
160 |
+
# Salt mode
|
161 |
+
num_salt = np.ceil(percent * img.size * s_vs_p)
|
162 |
+
coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape])
|
163 |
+
if img.dtype in [np.float32, np.float64]:
|
164 |
+
img[coords] = 1
|
165 |
+
else:
|
166 |
+
img[coords] = np.iinfo(img.dtype).max
|
167 |
+
|
168 |
+
# Pepper mode
|
169 |
+
num_pepper = np.ceil(percent* img.size * (1. - s_vs_p))
|
170 |
+
coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape])
|
171 |
+
img[coords] = 0
|
172 |
+
else:
|
173 |
+
raise ValueError('not support mode for {}'.format(mode))
|
174 |
+
|
175 |
+
noisy = img.astype(original_dtype)
|
176 |
+
|
177 |
+
return noisy
|
178 |
+
|
179 |
+
|
180 |
+


def gaussian_blur(img, kernel_size):
    # When sigma=0, it is computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`
    return cv2.GaussianBlur(img, (kernel_size, kernel_size), sigmaX=0)
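
# Usage sketch (illustrative only): cv2.GaussianBlur requires `kernel_size` to be a
# positive odd integer, e.g.:
#
#   blurred = gaussian_blur(img, kernel_size=5)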


def adjust_brightness(img, value=0):
    # np.float/np.float128 were removed from NumPy; test for floating dtypes instead
    if np.issubdtype(img.dtype, np.floating):
        dtype_min, dtype_max = 0, 1
    else:
        dtype_min = np.iinfo(img.dtype).min
        dtype_max = np.iinfo(img.dtype).max
    # keep the input dtype (np.iinfo(...) itself is not a valid dtype for astype)
    dtype = img.dtype

    result = np.clip(img.astype(np.float64) + value, dtype_min, dtype_max).astype(dtype)

    return result


def adjust_contrast(img, factor):
    if np.issubdtype(img.dtype, np.floating):
        dtype_min, dtype_max = 0, 1
    else:
        dtype_min = np.iinfo(img.dtype).min
        dtype_max = np.iinfo(img.dtype).max
    dtype = img.dtype

    result = np.clip(img.astype(np.float64) * factor, dtype_min, dtype_max).astype(dtype)

    return result
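
# Usage sketch (illustrative only; `img` is a placeholder array): brightness adds a
# constant offset, contrast scales pixel values; both clip to the dtype's valid range.
#
#   brighter = adjust_brightness(img, value=30)   # uint8 image: +30 per pixel
#   punchier = adjust_contrast(img, factor=1.2)   # scale values by 1.2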


def adjust_saturation():
    # TODO
    pass


def adjust_hue():
    # TODO
    pass


def to_grayscale(img, output_channels=1):
    """Convert an input ndarray image to a grayscale image.

    Arguments:
        img {ndarray} -- the input ndarray image

    Keyword Arguments:
        output_channels {int} -- number of channels of the output gray image (default: {1})

    Returns:
        ndarray -- grayscale ndarray image
    """
    if img.ndim == 2:
        gray_img = img
    elif img.shape[2] == 3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        gray_img = np.mean(img, axis=2)
    gray_img = gray_img.astype(img.dtype)

    if output_channels != 1:
        gray_img = np.tile(gray_img, (output_channels, 1, 1))
        gray_img = np.transpose(gray_img, [1, 2, 0])

    return gray_img
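
# Usage sketch (illustrative only; `rgb_img` is a placeholder array): collapse an RGB
# image to one channel, or replicate the gray channel so code expecting 3 channels works.
#
#   gray = to_grayscale(rgb_img)                      # shape (H, W)
#   gray3 = to_grayscale(rgb_img, output_channels=3)  # shape (H, W, 3)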


def shift(img, top, left):
    (h, w) = img.shape[0:2]
    # translation matrix: move the image `left` pixels along x and `top` pixels along y
    matrix = np.float32([[1, 0, left], [0, 1, top]])
    dst = cv2.warpAffine(img, matrix, (w, h))

    return dst


def rotate(img, angle, center=None, scale=1.0):
    (h, w) = img.shape[:2]

    if center is None:
        center = (w / 2, h / 2)

    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(img, M, (w, h))

    return rotated
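
# Usage sketch (illustrative only; `img` is a placeholder array): `shift` translates by
# (top, left) pixels, `rotate` turns the image around its center by default.
#
#   moved = shift(img, top=10, left=-5)
#   turned = rotate(img, angle=30)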


def resize(img, size, interpolation=cv2.INTER_LINEAR):
    '''resize the image
    TODO: after an OpenCV resize, the image values may end up in the 0~1 range

    Arguments:
        img {ndarray} -- the input ndarray image
        size {int, iterable} -- the target size; if size is an integer, width and height are resized to the same value,
            otherwise the size should be a tuple (height, width) or list [height, width]

    Keyword Arguments:
        interpolation {int} -- the OpenCV interpolation flag (default: {cv2.INTER_LINEAR})

    Raises:
        TypeError -- img should be ndarray
        ValueError -- size should be an integer or an iterable of length 2.

    Returns:
        img -- resized ndarray image
    '''
    if not _is_numpy_image(img):
        raise TypeError('img should be ndarray image [h, w, c] or [h, w], but got {}'.format(type(img)))
    # collections.Iterable was removed in Python 3.10; use collections.abc.Iterable
    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise ValueError('size should be an integer or an iterable of length 2, but got {}'.format(type(size)))

    if isinstance(size, int):
        height, width = (size, size)
    else:
        height, width = (size[0], size[1])

    # cv2.resize expects OpenCV interpolation flags; PIL constants such as
    # Image.BILINEAR map to different OpenCV modes, hence the cv2.INTER_LINEAR default.
    return cv2.resize(img, (width, height), interpolation=interpolation)
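
# Usage sketch (illustrative only; `img` is a placeholder array): an int resizes to a
# square, a (height, width) pair resizes each side; the flag is an OpenCV constant.
#
#   small = resize(img, 128)
#   wide = resize(img, (128, 256), interpolation=cv2.INTER_NEAREST)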


def pad(img, padding, fill=0, padding_mode='constant'):
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    if isinstance(padding, collections.abc.Iterable) and len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_bottom = pad_top = padding[1]
    if isinstance(padding, collections.abc.Iterable) and len(padding) == 4:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    if img.ndim == 2:
        if padding_mode == 'constant':
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode, constant_values=fill)
        else:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
    if img.ndim == 3:
        if padding_mode == 'constant':
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode, constant_values=fill)
        else:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode)
    return img
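
# Usage sketch (illustrative only; `img` is a placeholder array): padding accepts an
# int, a (left/right, top/bottom) pair, or a (left, top, right, bottom) 4-tuple.
#
#   padded = pad(img, 8)                   # 8 pixels on every side
#   padded = pad(img, (4, 2), fill=255)    # 4 left/right, 2 top/bottom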


def crop(img, top, left, height, width):
    '''crop image

    Arguments:
        img {ndarray} -- image to be cropped
        top {int} -- top offset
        left {int} -- left offset
        height {int} -- cropped height
        width {int} -- cropped width
    '''
    if not _is_numpy_image(img):
        raise TypeError('the input image should be numpy ndarray with dimension 2 or 3, '
                        'but got {}'.format(type(img))
                        )

    if width < 0 or height < 0 or left < 0 or top < 0:
        raise ValueError('the input left, top, width, height should be greater than 0, '
                         'but got left={}, top={} width={} height={}'.format(left, top, width, height)
                         )
    if img.ndim == 2:
        img_height, img_width = img.shape
    else:
        img_height, img_width, _ = img.shape
    if (left + width) > img_width or (top + height) > img_height:
        raise ValueError('the input crop width and height should be smaller than or '
                         'equal to the image width and height.')

    if img.ndim == 2:
        return img[top:(top + height), left:(left + width)]
    elif img.ndim == 3:
        return img[top:(top + height), left:(left + width), :]
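
# Usage sketch (illustrative only; `img` is a placeholder array): take a 100x100 window
# whose top-left corner is at (row=20, col=30); bounds are validated against the image.
#
#   patch = crop(img, top=20, left=30, height=100, width=100)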


def center_crop(img, output_size):
    '''center crop the image

    Arguments:
        img {ndarray} -- input image
        output_size {number or sequence} -- the output image size. if sequence, should be [h, w]

    Raises:
        ValueError -- the requested output size is larger than the input image.

    Returns:
        ndarray image -- the cropped ndarray image.
    '''
    if img.ndim == 2:
        img_height, img_width = img.shape
    else:
        img_height, img_width, _ = img.shape

    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    if output_size[0] > img_height or output_size[1] > img_width:
        raise ValueError('the output_size should not be greater than the image size, but got {}'.format(output_size))

    target_height, target_width = output_size

    top = int(round((img_height - target_height) / 2))
    left = int(round((img_width - target_width) / 2))

    return crop(img, top, left, target_height, target_width)


def resized_crop(img, top, left, height, width, size, interpolation=cv2.INTER_LINEAR):
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation)
    return img


def vflip(img):
    return cv2.flip(img, 0)


def hflip(img):
    return cv2.flip(img, 1)


def flip(img, flip_code):
    return cv2.flip(img, flip_code)
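
# Usage sketch (illustrative only; assumes `img` is a placeholder array of at least
# 200x200 pixels): compose the ops above -- crop a region, resize it, or flip an axis.
#
#   thumb = resized_crop(img, top=0, left=0, height=200, width=200, size=64)
#   mirrored = hflip(img)
#   upside_down = vflip(img)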


def elastic_transform(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER_LINEAR,
                      border_mode=cv2.BORDER_REFLECT_101, random_state=None, approximate=False):
    """Elastic deformation of images as described in [Simard2003]_ (with modifications).
    Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5
    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
         Convolutional Neural Networks applied to Visual Document Analysis", in
         Proc. of the International Conference on Document Analysis and
         Recognition, 2003.
    """
    if random_state is None:
        random_state = np.random.RandomState(1234)

    height, width = image.shape[:2]

    # Random affine
    center_square = np.float32((height, width)) // 2
    square_size = min((height, width)) // 3
    alpha = float(alpha)
    sigma = float(sigma)
    alpha_affine = float(alpha_affine)

    pts1 = np.float32([center_square + square_size, [center_square[0] + square_size, center_square[1] - square_size],
                       center_square - square_size])
    pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
    matrix = cv2.getAffineTransform(pts1, pts2)

    image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation, borderMode=border_mode)

    if approximate:
        # Approximate computation: smooth the displacement map with a large enough kernel.
        # On large images (512+) this is approximately 2x faster.
        dx = (random_state.rand(height, width).astype(np.float32) * 2 - 1)
        cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
        dx *= alpha

        dy = (random_state.rand(height, width).astype(np.float32) * 2 - 1)
        cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy)
        dy *= alpha
    else:
        # Exact computation using scipy.ndimage's gaussian_filter (imported at module top)
        dx = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha)
        dy = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha)

    x, y = np.meshgrid(np.arange(width), np.arange(height))

    mapx = np.float32(x + dx)
    mapy = np.float32(y + dy)

    return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode)
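
# Usage sketch (illustrative only; the parameter values below are placeholders, commonly
# tied to the image side length rather than prescribed by this module):
#
#   side = img.shape[1]
#   warped = elastic_transform(img, alpha=side * 2, sigma=side * 0.08,
#                              alpha_affine=side * 0.08)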


def bbox_shift(bboxes, top, left):
    pass


def bbox_vflip(bboxes, img_height):
    """vertically flip the bboxes
    ...........
    .         .
    .         .
    >...........<
    .         .
    .         .
    ...........
    Args:
        bboxes (ndarray): bbox ndarray [box_nums, 4]
        img_height (int): height of the image the bboxes belong to
    """
    flipped = bboxes.copy()
    flipped[..., 1::2] = img_height - bboxes[..., 1::2]
    flipped = flipped[..., [0, 3, 2, 1]]
    return flipped


def bbox_hflip(bboxes, img_width):
    """horizontally flip the bboxes
          ^
    .............
    .     .     .
    .     .     .
    .     .     .
    .     .     .
    .............
          ^
    Args:
        bboxes (ndarray): bbox ndarray [box_nums, 4]
        img_width (int): width of the image the bboxes belong to
    """
    flipped = bboxes.copy()
    flipped[..., 0::2] = img_width - bboxes[..., 0::2]
    flipped = flipped[..., [2, 1, 0, 3]]
    return flipped
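
# Usage sketch (illustrative only): bboxes are [x_min, y_min, x_max, y_max] rows;
# flipping mirrors the coordinates and reorders them so min stays below max.
#
#   boxes = np.array([[10, 20, 50, 80]], dtype=np.float32)
#   flipped = bbox_hflip(boxes, img_width=100)   # -> [[50, 20, 90, 80]]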


def bbox_resize(bboxes, img_size, target_size):
    """resize the bboxes

    Args:
        bboxes (ndarray): bbox ndarray [box_nums, 4]
        img_size (tuple): the source image height and width
        target_size (int, or tuple): the target image size used to rescale the boxes;
            int or tuple, if tuple the shape should be (height, width)
    """
    if isinstance(target_size, numbers.Number):
        target_size = (target_size, target_size)

    ratio_height = target_size[0] / img_size[0]
    ratio_width = target_size[1] / img_size[1]

    return bboxes * [ratio_width, ratio_height, ratio_width, ratio_height]
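
# Usage sketch (illustrative only): rescale boxes from a 100x100 image to 200x50
# (height x width); x coords scale by the width ratio, y coords by the height ratio.
#
#   boxes = np.array([[10, 20, 50, 80]], dtype=np.float32)
#   scaled = bbox_resize(boxes, img_size=(100, 100), target_size=(200, 50))
#   # -> [[5., 40., 25., 160.]]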


def bbox_crop(bboxes, top, left, height, width):
    '''crop the bboxes

    Arguments:
        bboxes {ndarray} -- bboxes to be cropped, [box_nums, 4]
        top {int} -- top offset
        left {int} -- left offset
        height {int} -- cropped height
        width {int} -- cropped width
    '''
    cropped_bboxes = bboxes.copy()

    right = width + left
    bottom = height + top

    cropped_bboxes[..., 0::2] = bboxes[..., 0::2].clip(left, right) - left
    cropped_bboxes[..., 1::2] = bboxes[..., 1::2].clip(top, bottom) - top

    return cropped_bboxes


def bbox_pad(bboxes, padding):
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    if isinstance(padding, collections.abc.Iterable) and len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_bottom = pad_top = padding[1]
    if isinstance(padding, collections.abc.Iterable) and len(padding) == 4:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    pad_bboxes = bboxes.copy()
    pad_bboxes[..., 0::2] = bboxes[..., 0::2] + pad_left
    pad_bboxes[..., 1::2] = bboxes[..., 1::2] + pad_top

    return pad_bboxes
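
# Usage sketch (illustrative only): keep boxes consistent with an image crop or pad --
# cropping clips and shifts coordinates, padding shifts them by the left/top margins.
#
#   boxes = np.array([[10, 20, 50, 80]], dtype=np.float32)
#   boxes_in_crop = bbox_crop(boxes, top=10, left=10, height=60, width=60)
#   boxes_padded = bbox_pad(boxes, padding=(4, 2))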