# ImageEnhancement / models /trilinear_test.py
# chenzhicun
# 初始化web demo. (initialize web demo)
# ec08fea
from re import A
import time
from turtle import width
import torch
import torch.nn as nn
import torch.nn.functional as F
##new####
# https://github.com/tedyhabtegebrial/PyTorch-Trilinear-Interpolation
class TrilinearIntepolation(nn.Module):
    """Trilinear interpolation of a 5-D feature volume at arbitrary 3-D points.

    Adapted from https://github.com/tedyhabtegebrial/PyTorch-Trilinear-Interpolation.
    For every sampling point the eight surrounding lattice voxels are fetched
    with nearest-neighbour lookups (:meth:`sample_at_integer_locs`) and blended
    with the standard trilinear weights.
    """

    def __init__(self):
        super(TrilinearIntepolation, self).__init__()

    def sample_at_integer_locs(self, input_feats, index_tensor):
        """Read ``input_feats`` at integer voxel coordinates.

        Args:
            input_feats: tensor of shape [B, F, D, H, W].
            index_tensor: tensor of shape [B, Hg, Wg, 3]; the last axis holds
                (x, y, z) voxel indices along (W, H, D). Out-of-range indices
                are clamped to the nearest border voxel. The tensor is not
                modified.

        Returns:
            Tensor of shape [B, F, Hg, Wg].
        """
        assert input_feats.ndimension()==5, 'input_feats should be of shape [Batch,F,D,Height,Width]'
        assert index_tensor.ndimension()==4, 'index_tensor should be of shape [Batch,Height,Width,3]'
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = index_tensor.shape[1], index_tensor.shape[2]

        # Map x in [0, W-1] and y in [0, H-1] onto grid_sample's [-1, 1] range.
        # Computed out-of-place so the caller's index_tensor (of which the
        # slice is a view) is left untouched. max(..., 1) keeps a size-1 axis
        # from dividing by zero; every index on such an axis maps to voxel 0.
        xy = index_tensor[..., 0:2].to(input_feats.dtype)
        half_extent = xy.new_tensor([max(width - 1.0, 1.0) / 2.0,
                                     max(height - 1.0, 1.0) / 2.0])
        xy_grid = torch.clamp((xy - half_extent) / half_extent, min=-1.0, max=1.0)

        # Fold depth into the channel axis so a single 2-D nearest lookup
        # fetches every depth slice at (x, y). align_corners=True makes -1 / +1
        # address the centres of the first / last pixel, which is what the
        # mapping above assumes; the default (False) shifts samples by half a
        # pixel and can round the last index out of bounds (zero padding).
        flat_feats = input_feats.reshape(batch_size, num_chans * num_d, height, width)
        sampled_in_2d = F.grid_sample(flat_feats, xy_grid, mode='nearest', align_corners=True)
        sampled_in_2d = sampled_in_2d.reshape(batch_size, num_chans, num_d, grid_height, grid_width)

        # Pick depth slice z for every output location: gather along the depth
        # axis (dim 2) with the index broadcast over all feature channels.
        z_grid = index_tensor[..., 2].reshape(batch_size, 1, 1, grid_height, grid_width)
        z_grid = z_grid.long().clamp(min=0, max=num_d - 1)
        z_grid = z_grid.expand(batch_size, num_chans, 1, grid_height, grid_width)
        return sampled_in_2d.gather(2, z_grid).squeeze(2)

    def forward(self, input_feats, sampling_grid):
        """Trilinearly sample ``input_feats`` at the points in ``sampling_grid``.

        Args:
            input_feats: tensor of shape [B, F, D, H, W].
            sampling_grid: tensor of shape [B, Hg, Wg, 3] holding normalised
                (x, y, z) coordinates along (W, H, D). Values are clamped to
                [-1, 1]; -1 and +1 address the first and last voxel.

        Returns:
            Tensor of shape [B, F, Hg, Wg].
        """
        assert input_feats.ndimension()==5, 'input_feats should be of shape [B,F,D,H,W]'
        assert sampling_grid.ndimension()==4, 'sampling_grid should be of shape [B,H,W,3]'
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = sampling_grid.shape[1], sampling_grid.shape[2]

        # [-1, 1] -> [0, 1] -> floating-point voxel coordinates in
        # [0, W-1] x [0, H-1] x [0, D-1].
        sampling_grid = (torch.clamp(sampling_grid, min=-1.0, max=1.0) + 1.0) / 2.0
        scaling_factor = sampling_grid.new_tensor(
            [width - 1.0, height - 1.0, num_d - 1.0]).view(1, 1, 1, 3)
        sampling_grid = sampling_grid * scaling_factor

        # Lower corner of the enclosing lattice cell, its upper neighbour and
        # the fractional offsets (u, v, w) that act as interpolation weights.
        lower = sampling_grid.floor()
        x_0, y_0, z_0 = torch.split(lower, split_size_or_sections=1, dim=3)
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = torch.split(sampling_grid - lower, split_size_or_sections=1, dim=3)

        # [B, Hg, Wg, 1] -> [B, F, Hg, Wg] so the weights line up with samples.
        u, v, w = (t.reshape(batch_size, 1, grid_height, grid_width)
                   .expand(batch_size, num_chans, grid_height, grid_width)
                   for t in (u, v, w))

        def corner(xs, ys, zs):
            # Feature values at one of the eight cell corners.
            return self.sample_at_integer_locs(input_feats, torch.cat([xs, ys, zs], dim=3))

        c_000 = corner(x_0, y_0, z_0)
        c_001 = corner(x_0, y_0, z_1)
        c_010 = corner(x_0, y_1, z_0)
        c_011 = corner(x_0, y_1, z_1)
        c_100 = corner(x_1, y_0, z_0)
        c_101 = corner(x_1, y_0, z_1)
        c_110 = corner(x_1, y_1, z_0)
        c_111 = corner(x_1, y_1, z_1)
        return ((1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 +
                (1.0 - u) * (1.0 - v) * w * c_001 +
                (1.0 - u) * v * (1.0 - w) * c_010 +
                (1.0 - u) * v * w * c_011 +
                u * (1.0 - v) * (1.0 - w) * c_100 +
                u * (1.0 - v) * w * c_101 +
                u * v * (1.0 - w) * c_110 +
                u * v * w * c_111)
# class bing_lut_trilinearInterplt(nn.Module):
# def __init__(self):
# super(bing_lut_trilinearInterplt, self).__init__()
# def test(self,LUT,img_input):
# # batch_size, num_chans, height, width = img_input.shape
# # grid_height, grid_width = LUT.shape[1],LUT.shape[2]
# grid_in=img_input.transpose(1,2).transpose(2,3)
# # 原本img_input NCHW,改成 NHWC
# xy_grid=grid_in[...,0:2]
# yz_grid=grid_in[...,1:3]
# #只取3通道中的第0和第1通道(0:2不含2)
# input_LUT=LUT[:,:,0,:]
# input_LUT_ori=input_LUT.squeeze(2)
# # LUT[33,33,33,3]->[33,33,3],把dim=2的数据丢掉了
# input_LUT=input_LUT_ori[...,0:2]
# input_LUT2=input_LUT_ori[...,1:]
# print("input_LUT2.size()",input_LUT2.size())
# # LUT[33,33,2]
# input_LUT=input_LUT.transpose(1,2).transpose(0,1)
# input_LUT2=input_LUT2.transpose(1,2).transpose(0,1)
# # LUT[2,33,33]
# input_LUT=input_LUT.unsqueeze(0)
# input_LUT2=input_LUT2.unsqueeze(0)
# print(input_LUT.size())
# print(input_LUT2.size())
# print(grid_in.size())
# sampled_in_2d = F.grid_sample(input=input_LUT,grid=xy_grid, mode='nearest')
# # .view(batch_size, num_chans, num_d, grid_height, grid_width)
# sampled_in_2d_2 = F.grid_sample(input=input_LUT2,grid=yz_grid, mode='nearest')
# # .view(batch_size, num_chans, num_d, grid_height, grid_width)
# # print("sampled_in_2d.size()",sampled_in_2d.size())
# # print("sampled_in_2d.size()",sampled_in_2d_2.size())
# # # [1,2,2160,3840]
# # print("ss")
# # print(sampled_in_2d.size())
# # print(sampled_in_2d_2.size())
# res=torch.cat([sampled_in_2d,sampled_in_2d_2[:,1:,:,:]],dim=1)
# print(res.size())
# return res
# # z_grid = grid_in[..., 2]
# # print(z_grid.size())
# # # [1,2160,3840]
# # print("sss")
# def gen_Cout_ijk(self,LUT,x_i,y_i,z_i):
# # def gen_Cout_ijk(LUT,x_i,y_i,z_i,channel=3):
# # LUT size [3,33,33,33]
# # x_i,y_i,z_i size [1,1,2160,3840]
# # N=batch_size
# #img_input.size()=[1,3,2160,3840]\
# # LUT.size()=[3,33,33,33]
# # assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
# channel=3
# batch_size,_,height,width=x_i.size()
# print(batch_size,height,width)
# output=torch.zeros([batch_size,channel,height,width])
# # 设置输出大小为[1,3,2160,3840]
# if batch_size==1:
# # x_i=x_i.view(height*width)
# # y_i=y_i.view(height*width)
# # z_i=z_i.view(height*width)
# x_i=x_i.view(height*width).long()
# y_i=y_i.view(height*width).long()
# z_i=z_i.view(height*width).long()
# # x_i=x_i.view(1, height*width)
# # y_i=y_i.view(1, height*width)
# # z_i=z_i.view(1, height*width)
# # 2维tensor,[1, 2160*3840]
# # xyz_i=torch.cat([x_i,y_i,z_i],dim=0)
# # # xyz_i 2维tensor,[3, 2160*3840]
# # print("xyz_i.size()",xyz_i.size())
# else:
# print("error:batch size must be 1")
# for i in range(height*width):
# h_index=int(i/width)
# w_index=int(i%width)
# # print(h_index)
# # print(w_index)
# # print(x_i.size())
# # print(batch_size)
# # print(output.size())
# # print(output[0,0,h_index,w_index])
# if(i%10000==0):
# print(i)
# output[batch_size-1,0,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],0]
# output[batch_size-1,1,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],1]
# output[batch_size-1,2,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],2]
# # x_i=x_i.view(batch_size,height*width)
# # y_i=y_i.view(batch_size,height*width)
# # z_i=z_i.view(batch_size,height*width)
# # 1,2160*3840
# return output
# def forward(self, LUT, img_input):
# assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]'
# # N=batch_size
# #img_input.size()=[1,3,2160,3840]\
# # LUT.size()=[3,33,33,33]
# assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
# batch_size, num_chans, height, width = img_input.shape
# dim = LUT.shape[1] # M
# img_size=img_input.size()
# Cmax=255.0
# s=Cmax/dim
# r,g,b=torch.split(img_input,split_size_or_sections=1,dim=1)
# # 将[1,3,2160,3840]以维度为1切成[1,1,2160,3840]的三部分
# #r,g,b.size()=[1,1,2160,3840]
# # r=img_input[:,0,:,:]
# # g=img_input[:,1,:,:]
# # b=img_input[:,2,:,:]
# x=r/s
# y=g/s
# z=b/s
# # tmptmp=self.test(LUT,img_input)
# # x,y,z.size=[1,1,,2160,3840]
# # x_0,y_0,z_0.size=[1,1,,2160,3840]
# # x_1, y_1, z_1.size=[1,1,,2160,3840]
# x_0,y_0,z_0=x.floor(),y.floor(),z.floor()
# x_1, y_1, z_1 = x_0+1.0, y_0+1.0, z_0+1.0
# u, v, w = x-x_0, y-y_0, z-z_0
# # u,v,w.size=[1,1,2160,3840]
# # print("x_0.size",x_0.size())
# c_000 = self.test(LUT,torch.cat([x_0,y_0,z_0],dim=1))
# print(c_000.size())
# # x_i是顶点,大小为[1,1,2160,3840]
# # 输出c_xxx是对应顶点的LUT的值,大小为[1,3,2160,3840]
# c_100 = self.test(LUT,torch.cat([x_1,y_0,z_0],dim=1))
# c_010 = self.test(LUT,torch.cat([x_0,y_1,z_0],dim=1))
# c_110 = self.test(LUT,torch.cat([x_1,y_1,z_0],dim=1))
# c_001 = self.test(LUT,torch.cat([x_0,y_0,z_1],dim=1))
# c_101 = self.test(LUT,torch.cat([x_1,y_0,z_1],dim=1))
# c_011 = self.test(LUT,torch.cat([x_0,y_1,z_1],dim=1))
# c_111 = self.test(LUT,torch.cat([x_1,y_1,z_1],dim=1))
# # c_000 = self.gen_Cout_ijk(LUT,x_0,y_0,z_0)
# # # x_i是顶点,大小为[1,1,2160,3840]
# # # 输出c_xxx是对应顶点的LUT的值,大小为[1,3,2160,3840]
# # c_100 = self.gen_Cout_ijk(LUT,x_1,y_0,z_0)
# # c_010 = self.gen_Cout_ijk(LUT,x_0,y_1,z_0)
# # c_110 = self.gen_Cout_ijk(LUT,x_1,y_1,z_0)
# # c_001 = self.gen_Cout_ijk(LUT,x_0,y_0,z_1)
# # c_101 = self.gen_Cout_ijk(LUT,x_1,y_0,z_1)
# # c_011 = self.gen_Cout_ijk(LUT,x_0,y_1,z_1)
# # c_111 = self.gen_Cout_ijk(LUT,x_1,y_1,z_1)
# c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
# (1.0-u)*(1.0-v)*(w)*c_001 + \
# (1.0-u)*(v)*(1.0-w)*c_010 + \
# (1.0-u)*(v)*(w)*c_011 + \
# (u)*(1.0-v)*(1.0-w)*c_100 + \
# (u)*(1.0-v)*(w)*c_101 + \
# (u)*(v)*(1.0-w)*c_110 + \
# (u)*(v)*(w)*c_111
# # 广播机制,输出[1,3,2160,3840]
# print("c_xyz",c_xyz.size())
# return c_xyz
# # id100 = x_0 + 1.0 + y_0 * dim + z_0 * dim * dim
# # id010 = x_0 + (y_0 + 1.0) * dim + z_0 * dim * dim
# # id110 = x_0 + 1.0 + (y_0 + 1.0) * dim + z_0 * dim * dim
# # id001 = x_0 + y_0 * dim + (z_0 + 1.0) * dim * dim
# # id101 = x_0 + 1.0 + y_0 * dim + (z_0 + 1.0) * dim * dim
# # id011 = x_0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim
# # id111 = x_0 + 1.0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim
# # w000 = (1.0-u)*(1-v)*(1-w)
# # #大概也许得改成点乘
# # w100 = u*(1-v)*(1-w)
# # w010 = (1-u)*v*(1-w)
# # w110 = u*v*(1-w)
# # w001 = (1-u)*(1-v)*w
# # w101 = u*(1-v)*w
# # w011 = (1-u)*v*w
# # w111 = u*v*w
# # output=
# # print("v:",x_0,y_0,z_0)
# # print("s:",x_0.size(),y_0.size(),z_0.size())
# # u,v,w=u/s,v/s,w/s
# # c_000 = self.gen_Cout_ijk(x_0,y_0,z_0)
# # c_100 = self.gen_Cout_ijk(x_1,y_0,z_0)
# # c_010 = self.gen_Cout_ijk(x_0,y_1,z_0)
# # c_110 = self.gen_Cout_ijk(x_1,y_1,z_0)
# # c_001 = self.gen_Cout_ijk(x_0,y_0,z_1)
# # c_101 = self.gen_Cout_ijk(x_1,y_0,z_1)
# # c_011 = self.gen_Cout_ijk(x_0,y_1,z_1)
# # c_111 = self.gen_Cout_ijk(x_1,y_1,z_1)
# # c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
# # (1.0-u)*(1.0-v)*(w)*c_001 + \
# # (1.0-u)*(v)*(1.0-w)*c_010 + \
# # (1.0-u)*(v)*(w)*c_011 + \
# # (u)*(1.0-v)*(1.0-w)*c_100 + \
# # (u)*(1.0-v)*(w)*c_101 + \
# # (u)*(v)*(1.0-w)*c_110 + \
# # (u)*(v)*(w)*c_111
# # return c_xyz
class Tritri(nn.Module):
    """Apply a 3-D colour LUT to a batch of images with trilinear interpolation.

    Implemented with ``F.grid_sample`` on a 5-D input, where
    ``mode='bilinear'`` performs trilinear interpolation.
    """

    def __init__(self):
        super(Tritri, self).__init__()

    def forward(self, LUT, img):
        """Look up every pixel of ``img`` in ``LUT``.

        Args:
            LUT: tensor of shape [C, D, H, W]. Image channel 0 indexes the
                LUT's W axis, channel 1 its H axis and channel 2 its D axis.
            img: tensor of shape [N, 3, H_img, W_img] with values in [0, 1];
                values outside are clamped to the LUT border
                (``padding_mode='border'``).

        Returns:
            Tensor of shape [N, H_img, W_img, C] (channels-last).
        """
        # [0, 1] -> grid_sample's normalised [-1, 1] range.
        grid = (img - .5) * 2.
        # A 5-D grid_sample expects the grid as [N, D_out, H_out, W_out, 3];
        # use D_out = 1: [N, 3, H, W] -> [N, 1, H, W, 3].
        grid = grid.permute(0, 2, 3, 1)[:, None]
        # grid_sample requires input and grid to share the batch size, so
        # broadcast the single LUT over the batch (a view, no copy).
        lut = LUT[None].expand(grid.shape[0], *LUT.shape)
        result = F.grid_sample(lut, grid, mode='bilinear', padding_mode='border', align_corners=True)
        # [N, C, 1, H, W] -> [N, C, H, W] -> channels-last [N, H, W, C].
        return result[:, :, 0].permute(0, 2, 3, 1)
class bing_lut_trilinearInterplt(nn.Module):
    """Experimental 3-D LUT lookup with trilinear blending (work in progress).

    ``forward`` scales each pixel into lattice coordinates, fetches values for
    the eight surrounding lattice vertices through :meth:`test` and blends
    them with the usual trilinear weights.

    NOTE(review): as written this is NOT a correct 3-D LUT lookup; see the
    notes on :meth:`test` and :meth:`forward` before relying on its output.
    """

    def __init__(self):
        super(bing_lut_trilinearInterplt, self).__init__()

    def test(self, LUT, img_input):
        """Fetch LUT values at the lattice vertices given per pixel in ``img_input``.

        Args:
            LUT: tensor of shape [C, M, M, M] with C >= 3.
            img_input: tensor of shape [N, 3, H, W] holding (x, y, z) vertex
                coordinates per pixel. N must be 1: the LUT planes get batch
                size 1 below and grid_sample requires matching batch sizes.

        Returns:
            Tensor of shape [1, 3, H, W].

        NOTE(review): known problems, kept unchanged pending a rewrite:
          * vertex coordinates go straight into ``F.grid_sample`` as the
            grid, which expects coordinates normalised to [-1, 1]; index 0
            lands mid-plane and indices above 1 fall outside the range and
            read back as zero (default ``padding_mode='zeros'``).
          * only the plane ``LUT[:, :, :, 0]`` is read.
          * output channels 0-1 are LUT channels 0-1 sampled at (x, y);
            output channel 2 is LUT channel 2 sampled at (y, z).
        """
        # NCHW -> NHWC so the coordinate triple becomes the last axis.
        grid_in = img_input.permute(0, 2, 3, 1)
        xy_grid = grid_in[..., 0:2]
        yz_grid = grid_in[..., 1:3]
        # [C, M, M, M] -> [C, M, M]: keep only index 0 of the last LUT axis.
        lut_plane = LUT[:, :, :, 0]
        # Two 2-channel planes, each with a batch axis: [1, 2, M, M].
        plane_01 = lut_plane[0:2].unsqueeze(0)
        plane_12 = lut_plane[1:].unsqueeze(0)
        sampled_xy = F.grid_sample(input=plane_01, grid=xy_grid, mode='nearest', align_corners=False)
        sampled_yz = F.grid_sample(input=plane_12, grid=yz_grid, mode='nearest', align_corners=False)
        # Channels 0-1 from the (x, y) lookup, channel 2 from the (y, z) lookup.
        return torch.cat([sampled_xy, sampled_yz[:, 1:, :, :]], dim=1)

    def forward(self, LUT, img_input):
        """Trilinearly blend LUT vertex values for every pixel of ``img_input``.

        Args:
            LUT: tensor of shape [C, M, M, M].
            img_input: tensor of shape [N, C, H, W] (N must be 1, see :meth:`test`).

        Returns:
            Tensor of shape [1, 3, H, W].

        NOTE(review): pixels are divided by ``s = Cmax / (M - 1)`` with
        ``Cmax = 10``, so lattice coordinates reach at most ``max(pixel) *
        (M - 1) / 10``. If ``img_input`` is in [0, 1] (TODO confirm the
        intended range) only the lowest few lattice cells are ever used.
        """
        assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]'
        assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
        dim = LUT.shape[1]  # M, lattice points per axis
        Cmax = 10
        s = Cmax / (dim - 1.0)
        # Kept as a 1-element tensor: the author reported an
        # "int64 and int32 do not match in BroadcastRel" error with a plain
        # scalar (presumably from an export/backend toolchain; verify).
        s = torch.Tensor([s])
        # Split [N, 3, H, W] into three [N, 1, H, W] channel planes.
        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        s = s.to(r.device)
        x = r / s
        y = g / s
        z = b / s
        # Lower / upper lattice corners and fractional offsets, all [N, 1, H, W].
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0
        # Corner values, each [1, 3, H, W].
        c_000 = self.test(LUT, torch.cat([x_0, y_0, z_0], dim=1))
        c_100 = self.test(LUT, torch.cat([x_1, y_0, z_0], dim=1))
        c_010 = self.test(LUT, torch.cat([x_0, y_1, z_0], dim=1))
        c_110 = self.test(LUT, torch.cat([x_1, y_1, z_0], dim=1))
        c_001 = self.test(LUT, torch.cat([x_0, y_0, z_1], dim=1))
        c_101 = self.test(LUT, torch.cat([x_1, y_0, z_1], dim=1))
        c_011 = self.test(LUT, torch.cat([x_0, y_1, z_1], dim=1))
        c_111 = self.test(LUT, torch.cat([x_1, y_1, z_1], dim=1))
        # [N, 1, H, W] weights broadcast against [1, 3, H, W] corner values.
        return ((1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 +
                (1.0 - u) * (1.0 - v) * w * c_001 +
                (1.0 - u) * v * (1.0 - w) * c_010 +
                (1.0 - u) * v * w * c_011 +
                u * (1.0 - v) * (1.0 - w) * c_100 +
                u * (1.0 - v) * w * c_101 +
                u * v * (1.0 - w) * c_110 +
                u * v * w * c_111)
class bing_lut_trilinearInterplt_backup(nn.Module):
    """Earlier variant of the experimental 2-D-``grid_sample`` LUT lookup.

    Same structure as the active experimental class, but pixels are scaled
    with ``Cmax = 255``.

    NOTE(review): as written this is NOT a correct 3-D LUT lookup; see the
    notes on :meth:`test`.
    """

    def __init__(self):
        # Must name this class itself; naming a sibling class makes super()
        # raise TypeError because self is not an instance of that class.
        super(bing_lut_trilinearInterplt_backup, self).__init__()

    def test(self, LUT, img_input):
        """Fetch LUT values at the lattice vertices given per pixel in ``img_input``.

        Args:
            LUT: tensor of shape [C, M, M, M] with C >= 3.
            img_input: tensor of shape [N, 3, H, W] holding (x, y, z) vertex
                coordinates per pixel. N must be 1: the LUT planes get batch
                size 1 below and grid_sample requires matching batch sizes.

        Returns:
            Tensor of shape [1, 3, H, W].

        NOTE(review): vertex coordinates go straight into ``F.grid_sample``,
        which expects coordinates normalised to [-1, 1]; indices above 1 read
        back as zero. Only the plane ``LUT[:, :, :, 0]`` is read, and output
        channel 2 is sampled at (y, z) while channels 0-1 use (x, y).
        """
        # NCHW -> NHWC so the coordinate triple becomes the last axis.
        grid_in = img_input.permute(0, 2, 3, 1)
        xy_grid = grid_in[..., 0:2]
        yz_grid = grid_in[..., 1:3]
        # [C, M, M, M] -> [C, M, M]: keep only index 0 of the last LUT axis.
        lut_plane = LUT[:, :, :, 0]
        plane_01 = lut_plane[0:2].unsqueeze(0)
        plane_12 = lut_plane[1:].unsqueeze(0)
        # align_corners=False is PyTorch's default; stated explicitly.
        sampled_xy = F.grid_sample(input=plane_01, grid=xy_grid, mode='nearest', align_corners=False)
        sampled_yz = F.grid_sample(input=plane_12, grid=yz_grid, mode='nearest', align_corners=False)
        return torch.cat([sampled_xy, sampled_yz[:, 1:, :, :]], dim=1)

    def forward(self, LUT, img_input):
        """Trilinearly blend LUT vertex values for every pixel of ``img_input``.

        Args:
            LUT: tensor of shape [C, M, M, M].
            img_input: tensor of shape [N, C, H, W]; presumably 0..255 pixel
                values given ``Cmax = 255`` (TODO confirm). N must be 1.

        Returns:
            Tensor of shape [1, 3, H, W].
        """
        assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]'
        assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)'
        dim = LUT.shape[1]  # M, lattice points per axis
        Cmax = 255.0
        # M lattice points spanning [0, Cmax] are Cmax / (M - 1) apart.
        s = Cmax / (dim - 1.0)
        # Kept as a 1-element tensor: the author reported an
        # "int64 and int32 do not match in BroadcastRel" error with a plain
        # scalar (presumably from an export/backend toolchain; verify).
        s = torch.Tensor([s])
        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        # Follow the input onto its device (torch.Tensor() allocates on CPU).
        s = s.to(r.device)
        x = r / s
        y = g / s
        z = b / s
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0
        c_000 = self.test(LUT, torch.cat([x_0, y_0, z_0], dim=1))
        c_100 = self.test(LUT, torch.cat([x_1, y_0, z_0], dim=1))
        c_010 = self.test(LUT, torch.cat([x_0, y_1, z_0], dim=1))
        c_110 = self.test(LUT, torch.cat([x_1, y_1, z_0], dim=1))
        c_001 = self.test(LUT, torch.cat([x_0, y_0, z_1], dim=1))
        c_101 = self.test(LUT, torch.cat([x_1, y_0, z_1], dim=1))
        c_011 = self.test(LUT, torch.cat([x_0, y_1, z_1], dim=1))
        c_111 = self.test(LUT, torch.cat([x_1, y_1, z_1], dim=1))
        # [N, 1, H, W] weights broadcast against [1, 3, H, W] corner values.
        return ((1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 +
                (1.0 - u) * (1.0 - v) * w * c_001 +
                (1.0 - u) * v * (1.0 - w) * c_010 +
                (1.0 - u) * v * w * c_011 +
                u * (1.0 - v) * (1.0 - w) * c_100 +
                u * (1.0 - v) * w * c_101 +
                u * v * (1.0 - w) * c_110 +
                u * v * w * c_111)
# @staticmethod
# def backward(ctx, lut_grad, x_grad):
# lut, x, int_package, float_package = ctx.saved_variables
# dim, shift, W, H, batch = int_package
# dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
# binsize = float(float_package[0])
# assert 1 == trilinear.backward(x,
# x_grad,
# lut_grad,
# dim,
# shift,
# binsize,
# W,
# H,
# batch)
# return lut_grad, x_grad
class Tri(nn.Module):
    """Placeholder module: no parameters, no submodules and no forward yet."""

    def __init__(self):
        super().__init__()
if __name__=='__main__':
    # Rough timing harness for bing_lut_trilinearInterplt. Runs on CUDA
    # device 0 when available and on the CPU otherwise.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def _sync():
        # CUDA kernels run asynchronously; wait for them before reading the
        # clock so the measurement covers the actual work.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    # LUT of shape [3, 33, 33, 33] applied to a 4K image [1, 3, 2160, 3840].
    LUT2 = torch.rand(3, 33, 33, 33, device=device)
    data2 = torch.rand(1, 3, 2160, 3840, device=device)
    trilinear_interpolation2 = bing_lut_trilinearInterplt()

    _sync()
    t_start = time.time()
    interp_data2 = trilinear_interpolation2(LUT2, data2)
    _sync()
    print('time per iteration ', time.time() - t_start)