import time

import torch
import torch.nn as nn
import torch.nn.functional as F


def _trilinear_blend(corner, u, v, w):
    """Combine the eight lattice-corner values with trilinear weights.

    Args:
        corner: callable ``corner(i, j, k)`` returning the sampled value at the
            lattice corner offset by ``i``, ``j``, ``k`` (each 0 or 1) along
            x, y and z respectively.
        u, v, w: fractional offsets along x, y and z (``x - floor(x)`` etc. in
            the callers); must broadcast against the corner values.

    Returns:
        The weighted sum of the eight corners.
    """
    return ((1.0 - u) * (1.0 - v) * (1.0 - w) * corner(0, 0, 0)
            + (1.0 - u) * (1.0 - v) * w * corner(0, 0, 1)
            + (1.0 - u) * v * (1.0 - w) * corner(0, 1, 0)
            + (1.0 - u) * v * w * corner(0, 1, 1)
            + u * (1.0 - v) * (1.0 - w) * corner(1, 0, 0)
            + u * (1.0 - v) * w * corner(1, 0, 1)
            + u * v * (1.0 - w) * corner(1, 1, 0)
            + u * v * w * corner(1, 1, 1))


# Adapted from https://github.com/tedyhabtegebrial/PyTorch-Trilinear-Interpolation
class TrilinearIntepolation(nn.Module):
    """Trilinear interpolation of a 5-D feature volume at arbitrary 3-D points.

    ``forward(input_feats, sampling_grid)`` takes ``input_feats`` of shape
    ``[B, C, D, H, W]`` and ``sampling_grid`` of shape ``[B, Hg, Wg, 3]`` holding
    normalized ``(x, y, z)`` coordinates in ``[-1, 1]`` (x runs along W, y along
    H, z along D) and returns a tensor of shape ``[B, C, Hg, Wg]``.
    """

    def __init__(self):
        super(TrilinearIntepolation, self).__init__()

    def sample_at_integer_locs(self, input_feats, index_tensor):
        """Gather ``input_feats`` at integer lattice locations.

        Args:
            input_feats: tensor of shape ``[B, C, D, H, W]``.
            index_tensor: tensor of shape ``[B, Hg, Wg, 3]`` whose last axis holds
                integer-valued ``(x, y, z)`` pixel indices. Indices outside
                ``[0, W-1] x [0, H-1] x [0, D-1]`` are clamped to the border.
                ``index_tensor`` itself is not modified.

        Returns:
            Tensor of shape ``[B, C, Hg, Wg]``.
        """
        assert input_feats.ndimension() == 5, 'input_feats should be of shape [Batch,F,D,Height,Width]'
        assert index_tensor.ndimension() == 4, 'index_tensor should be of shape [Batch,Height,Width,3]'
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = index_tensor.shape[1], index_tensor.shape[2]

        # Map pixel indices 0..N-1 onto grid_sample's [-1, 1] range
        # (index 0 -> -1, index N-1 -> +1). Computed out of place so the
        # caller's tensor is untouched; max(N - 1, 1) avoids 0/0 when N == 1.
        x_norm = index_tensor[..., 0] * (2.0 / max(width - 1, 1)) - 1.0
        y_norm = index_tensor[..., 1] * (2.0 / max(height - 1, 1)) - 1.0
        xy_grid = torch.stack([x_norm, y_norm], dim=-1).clamp(min=-1.0, max=1.0)

        # Fold depth into channels and do a nearest-neighbour lookup in (x, y).
        # align_corners=True matches the normalization above (-1/+1 address the
        # centres of the first/last pixels); with the align_corners=False default
        # the last index can round one pixel past the border and read zeros.
        sampled_in_2d = F.grid_sample(
            input=input_feats.reshape(batch_size, num_chans * num_d, height, width),
            grid=xy_grid,
            mode='nearest',
            align_corners=True,
        ).view(batch_size, num_chans, num_d, grid_height, grid_width)

        # Pick the depth slice per output location by gathering along D.
        z_grid = index_tensor[..., 2].reshape(batch_size, 1, 1, grid_height, grid_width)
        z_grid = z_grid.long().clamp(min=0, max=num_d - 1)
        z_grid = z_grid.expand(batch_size, num_chans, 1, grid_height, grid_width)
        return sampled_in_2d.gather(2, z_grid).squeeze(2)

    def forward(self, input_feats, sampling_grid):
        """Trilinearly sample ``input_feats`` ``[B, C, D, H, W]`` at ``sampling_grid``
        ``[B, Hg, Wg, 3]`` (normalized, clamped to ``[-1, 1]``); returns ``[B, C, Hg, Wg]``.
        """
        assert input_feats.ndimension() == 5, 'input_feats should be of shape [B,F,D,H,W]'
        assert sampling_grid.ndimension() == 4, 'sampling_grid should be of shape [B,H,W,3]'
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = sampling_grid.shape[1], sampling_grid.shape[2]

        # Clamp to [-1, 1], map to [0, 1], then scale to floating-point pixel
        # coordinates in [0, W-1] x [0, H-1] x [0, D-1].
        sampling_grid = torch.clamp(sampling_grid, min=-1.0, max=1.0)
        sampling_grid = (sampling_grid + 1) / 2.0
        scaling_factor = torch.tensor(
            [width - 1.0, height - 1.0, num_d - 1.0],
            dtype=sampling_grid.dtype,
            device=input_feats.device,
        ).view(1, 1, 1, 3)
        sampling_grid = scaling_factor * sampling_grid

        # Fractional location, its lower lattice corner and its upper corner.
        # Upper corners that land past the last index are clamped in
        # sample_at_integer_locs; their weight is 0 in that case.
        x, y, z = torch.split(sampling_grid, split_size_or_sections=1, dim=3)
        x_0, y_0, z_0 = torch.split(sampling_grid.floor(), split_size_or_sections=1, dim=3)
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0

        # Interpolation weights broadcast over channels: [B, C, Hg, Wg].
        u, v, w = (
            frac.view(batch_size, 1, grid_height, grid_width).expand(
                batch_size, num_chans, grid_height, grid_width)
            for frac in (x - x_0, y - y_0, z - z_0)
        )

        xs, ys, zs = (x_0, x_1), (y_0, y_1), (z_0, z_1)

        def corner(i, j, k):
            return self.sample_at_integer_locs(
                input_feats, torch.cat([xs[i], ys[j], zs[k]], dim=3))

        return _trilinear_blend(corner, u, v, w)


class Tritri(nn.Module):
    """Apply a 3-D LUT to an image with a single trilinear ``F.grid_sample`` call.

    ``forward(LUT, img)``:
        LUT: 4-D tensor ``[C, D, H, W]``; a batch dimension is added here and
            broadcast (without copying) to the image batch size.
        img: ``[N, 3, H_img, W_img]``. Values are mapped with ``(img - 0.5) * 2``
            to grid coordinates, so presumably they lie in ``[0, 1]`` — TODO
            confirm. Channels 0/1/2 act as grid_sample's x/y/z, i.e. they index
            the LUT's W/H/D axes. Out-of-range values clamp to the LUT border.
        Returns ``[N, H_img, W_img, C]`` (channels-last).
    """

    def __init__(self):
        super(Tritri, self).__init__()

    def forward(self, LUT, img):
        img = (img - .5) * 2.
        # grid_sample expects an N x D_out x H_out x W_out x 3 grid: N x 1 x H x W x 3.
        img = img.permute(0, 2, 3, 1)[:, None]
        # grid_sample requires input and grid to share the batch size.
        LUT = LUT[None].expand(img.shape[0], *LUT.shape)
        result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)
        # Drop the singleton output depth: [N, C, H, W] -> [N, H, W, C].
        result = result[:, :, 0].permute(0, 2, 3, 1)
        return result


def _lut_lookup_2d(LUT, img_input):
    """Per-pixel LUT lookup shared by the ``bing_lut_trilinearInterplt*`` modules.

    Args:
        LUT: 4-D tensor (the callers assert the layout ``[C, M, M, M]``).
        img_input: ``[N, 3, H, W]``; its channels are used directly as
            ``grid_sample`` coordinates.

    Returns:
        ``[N, 3, H, W]``: channels 0-1 are LUT channels 0-1 sampled at
        (ch0, ch1); channel 2 is LUT channel 2 sampled at (ch1, ch2).

    NOTE(review): only index 0 of the LUT's last axis is consulted, so this is
    a 2-D approximation rather than a full 3-D lookup. Also, grid_sample treats
    the coordinates as normalized to [-1, 1], while the callers pass lattice
    indices (floor of the scaled colour); indices above 1 fall outside the grid
    and read zero padding. Confirm the intended behaviour.
    """
    # NCHW -> NHWC so the colour channels become the grid's last axis.
    grid_in = img_input.permute(0, 2, 3, 1)
    xy_grid = grid_in[..., 0:2]
    yz_grid = grid_in[..., 1:3]

    # [C, M, M, M] -> [C, M, M]: keep index 0 of the last axis.
    lut_slice = LUT[:, :, :, 0]
    lut_xy = lut_slice[0:2].unsqueeze(0)  # [1, 2, M, M]
    lut_yz = lut_slice[1:].unsqueeze(0)   # [1, C - 1, M, M]

    sampled_xy = F.grid_sample(input=lut_xy, grid=xy_grid, mode='nearest', align_corners=False)
    sampled_yz = F.grid_sample(input=lut_yz, grid=yz_grid, mode='nearest', align_corners=False)
    return torch.cat([sampled_xy, sampled_yz[:, 1:]], dim=1)


class bing_lut_trilinearInterplt(nn.Module):
    """Experimental LUT trilinear interpolation built from 2-D nearest lookups.

    ``forward(LUT, img_input)`` takes ``LUT`` ``[C, M, M, M]`` and ``img_input``
    ``[N, 3, H, W]`` and returns an ``[N, 3, H, W]`` tensor.
    """

    def __init__(self):
        super(bing_lut_trilinearInterplt, self).__init__()

    def test(self, LUT, img_input):
        """Return LUT values at the lattice points in ``img_input`` (see ``_lut_lookup_2d``)."""
        return _lut_lookup_2d(LUT, img_input)

    def forward(self, LUT, img_input):
        assert img_input.ndimension() == 4, 'img_input should be of shape [N,C,H,W]'
        assert LUT.ndimension() == 4, 'LUT should be of shape [C,M,M,M](M=33)'
        dim = LUT.shape[1]  # M

        # NOTE(review): Cmax = 10 is hard-coded; confirm the expected input range.
        Cmax = 10
        # Lattice step, wrapped in a tensor: the original author notes this
        # avoided a "data types int64 and int32 do not match in BroadcastRel" error.
        s = torch.Tensor([Cmax / (dim - 1.0)]).to(img_input.device)

        # Split [N, 3, H, W] into three [N, 1, H, W] colour planes.
        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        x, y, z = r / s, g / s, b / s
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0

        xs, ys, zs = (x_0, x_1), (y_0, y_1), (z_0, z_1)

        def corner(i, j, k):
            return self.test(LUT, torch.cat([xs[i], ys[j], zs[k]], dim=1))

        # u, v, w ([N, 1, H, W]) broadcast against the [N, 3, H, W] corner values.
        return _trilinear_blend(corner, u, v, w)


class bing_lut_trilinearInterplt_backup(nn.Module):
    """Earlier variant of ``bing_lut_trilinearInterplt`` using a lattice step of 255 / M."""

    def __init__(self):
        super(bing_lut_trilinearInterplt_backup, self).__init__()

    def test(self, LUT, img_input):
        """Return LUT values at the lattice points in ``img_input`` (see ``_lut_lookup_2d``)."""
        return _lut_lookup_2d(LUT, img_input)

    def forward(self, LUT, img_input):
        assert img_input.ndimension() == 4, 'img_input should be of shape [N,C,H,W]'
        assert LUT.ndimension() == 4, 'LUT should be of shape [C,M,M,M](M=33)'
        dim = LUT.shape[1]  # M

        Cmax = 255.0
        # Lattice step as a tensor on the image's device (a CPU 1-element tensor
        # cannot be combined with a CUDA image).
        s = torch.Tensor([Cmax / dim]).to(img_input.device)

        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        x, y, z = r / s, g / s, b / s
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0

        xs, ys, zs = (x_0, x_1), (y_0, y_1), (z_0, z_1)

        def corner(i, j, k):
            return self.test(LUT, torch.cat([xs[i], ys[j], zs[k]], dim=1))

        return _trilinear_blend(corner, u, v, w)


class Tri(nn.Module):
    """Placeholder module: no parameters and no ``forward`` defined yet."""

    def __init__(self):
        super(Tri, self).__init__()


if __name__ == '__main__':
    # Benchmark the LUT module on a 4K image ([1, 3, 2160, 3840]) with a
    # [3, 33, 33, 33] LUT, both on CPU.
    data2 = torch.rand(1, 3, 2160, 3840)
    LUT2 = torch.rand(3, 33, 33, 33)
    trilinear_interpolation2 = bing_lut_trilinearInterplt()

    t_start = time.time()
    interp_data2 = trilinear_interpolation2(LUT2, data2)
    # Only synchronize when CUDA exists; the script also runs on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print('time per iteration ', time.time() - t_start)