# Spaces:
# Runtime error
# Runtime error
from re import A | |
import time | |
from turtle import width | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
##new#### | |
# https://github.com/tedyhabtegebrial/PyTorch-Trilinear-Interpolation | |
class TrilinearIntepolation(nn.Module):
    """Trilinear interpolation of a 5-D feature volume in PyTorch.

    The eight lattice corners surrounding every query point are fetched with
    nearest-neighbour lookups (``sample_at_integer_locs``) and blended with
    the standard trilinear weights in ``forward``.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _index_to_normalized(coord, size):
        """Map pixel indices in [0, size-1] onto grid_sample's [-1, 1] range.

        Uses the ``align_corners=True`` convention (index 0 -> -1, index
        size-1 -> +1). ``max(size - 1, 1)`` avoids a division by zero on a
        size-1 axis, where every coordinate then resolves to index 0.
        """
        return 2.0 * coord / max(size - 1, 1) - 1.0

    def sample_at_integer_locs(self, input_feats, index_tensor):
        """Gather features at integer lattice locations (nearest neighbour).

        Args:
            input_feats: tensor of shape [B, F, D, H, W].
            index_tensor: tensor of shape [B, Hg, Wg, 3] whose last axis holds
                integer-valued (x, y, z) indices into the W, H and D axes of
                ``input_feats``. Out-of-range indices are clamped to the border.
                The tensor is not modified.

        Returns:
            Tensor of shape [B, F, Hg, Wg].
        """
        assert input_feats.ndimension() == 5, 'input_feats should be of shape [Batch,F,D,Height,Width]'
        assert index_tensor.ndimension() == 4, 'index_tensor should be of shape [Batch,Height,Width,3]'
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = index_tensor.shape[1], index_tensor.shape[2]

        # Build a fresh normalized (x, y) grid rather than writing into a slice
        # of index_tensor, which would silently mutate the caller's tensor.
        xy_grid = torch.stack(
            (self._index_to_normalized(index_tensor[..., 0], width),
             self._index_to_normalized(index_tensor[..., 1], height)),
            dim=-1,
        ).clamp(min=-1.0, max=1.0)

        # Fold depth into the channel axis so one 2-D grid_sample covers every
        # depth slice. align_corners=True must match the normalization above:
        # with the default (False) the last row/column lands on a rounding tie
        # that can resolve out of range and read back as zero padding.
        sampled_in_2d = F.grid_sample(
            input=input_feats.reshape(batch_size, num_chans * num_d, height, width),
            grid=xy_grid,
            mode='nearest',
            align_corners=True,
        ).view(batch_size, num_chans, num_d, grid_height, grid_width)

        # Pick the requested depth slice for each output location.
        z_grid = index_tensor[..., 2].reshape(batch_size, 1, 1, grid_height, grid_width)
        z_grid = z_grid.long().clamp(min=0, max=num_d - 1)
        # Broadcast the depth index across all feature channels for gather().
        z_grid = z_grid.expand(batch_size, num_chans, 1, grid_height, grid_width)
        return sampled_in_2d.gather(2, z_grid).squeeze(2)

    def forward(self, input_feats, sampling_grid):
        """Trilinearly interpolate ``input_feats`` at ``sampling_grid``.

        Args:
            input_feats: tensor of shape [B, F, D, H, W].
            sampling_grid: tensor of shape [B, Hg, Wg, 3] whose last axis holds
                (x, y, z) coordinates in [-1, 1] along the W, H and D axes of
                ``input_feats``; values outside that range are clamped.

        Returns:
            Tensor of shape [B, F, Hg, Wg].
        """
        assert input_feats.ndimension() == 5, 'input_feats should be of shape [B,F,D,H,W]'
        assert sampling_grid.ndimension() == 4, 'sampling_grid should be of shape [B,H,W,3]'
        batch_size, num_chans, num_d, height, width = input_feats.shape
        grid_height, grid_width = sampling_grid.shape[1], sampling_grid.shape[2]

        # [-1, 1] -> [0, 1] -> floating-point lattice coordinates
        # in [0, W-1] x [0, H-1] x [0, D-1].
        sampling_grid = (torch.clamp(sampling_grid, min=-1.0, max=1.0) + 1) / 2.0
        scaling_factor = torch.tensor(
            [width - 1.0, height - 1.0, num_d - 1.0],
            dtype=torch.float32,
            device=input_feats.device,
        ).view(1, 1, 1, 3)
        sampling_grid = scaling_factor * sampling_grid

        # Lower corner (floor), upper corner (+1) and fractional offsets.
        x, y, z = torch.split(sampling_grid, split_size_or_sections=1, dim=3)
        x_0, y_0, z_0 = torch.split(sampling_grid.floor(), split_size_or_sections=1, dim=3)
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        # Offsets [B, Hg, Wg, 1] -> [B, F, Hg, Wg] so they broadcast against the corners.
        u, v, w = (
            offset.view(batch_size, 1, grid_height, grid_width).expand(
                batch_size, num_chans, grid_height, grid_width)
            for offset in (x - x_0, y - y_0, z - z_0)
        )

        def corner(cx, cy, cz):
            """Features at the lattice corner (cx, cy, cz), shape [B, F, Hg, Wg]."""
            return self.sample_at_integer_locs(input_feats, torch.cat([cx, cy, cz], dim=3))

        c_000 = corner(x_0, y_0, z_0)
        c_001 = corner(x_0, y_0, z_1)
        c_010 = corner(x_0, y_1, z_0)
        c_011 = corner(x_0, y_1, z_1)
        c_100 = corner(x_1, y_0, z_0)
        c_101 = corner(x_1, y_0, z_1)
        c_110 = corner(x_1, y_1, z_0)
        c_111 = corner(x_1, y_1, z_1)

        return ((1.0 - u) * (1.0 - v) * (1.0 - w) * c_000
                + (1.0 - u) * (1.0 - v) * w * c_001
                + (1.0 - u) * v * (1.0 - w) * c_010
                + (1.0 - u) * v * w * c_011
                + u * (1.0 - v) * (1.0 - w) * c_100
                + u * (1.0 - v) * w * c_101
                + u * v * (1.0 - w) * c_110
                + u * v * w * c_111)
# class bing_lut_trilinearInterplt(nn.Module): | |
# def __init__(self): | |
# super(bing_lut_trilinearInterplt, self).__init__() | |
# def test(self,LUT,img_input): | |
# # batch_size, num_chans, height, width = img_input.shape | |
# # grid_height, grid_width = LUT.shape[1],LUT.shape[2] | |
# grid_in=img_input.transpose(1,2).transpose(2,3) | |
# # 原本img_input NCHW,改成 NHWC | |
# xy_grid=grid_in[...,0:2] | |
# yz_grid=grid_in[...,1:3] | |
# #只取3通道中的第0和第1通道(0:2不含2) | |
# input_LUT=LUT[:,:,0,:] | |
# input_LUT_ori=input_LUT.squeeze(2) | |
# # LUT[33,33,33,3]->[33,33,3],把dim=2的数据丢掉了 | |
# input_LUT=input_LUT_ori[...,0:2] | |
# input_LUT2=input_LUT_ori[...,1:] | |
# print("input_LUT2.size()",input_LUT2.size()) | |
# # LUT[33,33,2] | |
# input_LUT=input_LUT.transpose(1,2).transpose(0,1) | |
# input_LUT2=input_LUT2.transpose(1,2).transpose(0,1) | |
# # LUT[2,33,33] | |
# input_LUT=input_LUT.unsqueeze(0) | |
# input_LUT2=input_LUT2.unsqueeze(0) | |
# print(input_LUT.size()) | |
# print(input_LUT2.size()) | |
# print(grid_in.size()) | |
# sampled_in_2d = F.grid_sample(input=input_LUT,grid=xy_grid, mode='nearest') | |
# # .view(batch_size, num_chans, num_d, grid_height, grid_width) | |
# sampled_in_2d_2 = F.grid_sample(input=input_LUT2,grid=yz_grid, mode='nearest') | |
# # .view(batch_size, num_chans, num_d, grid_height, grid_width) | |
# # print("sampled_in_2d.size()",sampled_in_2d.size()) | |
# # print("sampled_in_2d.size()",sampled_in_2d_2.size()) | |
# # # [1,2,2160,3840] | |
# # print("ss") | |
# # print(sampled_in_2d.size()) | |
# # print(sampled_in_2d_2.size()) | |
# res=torch.cat([sampled_in_2d,sampled_in_2d_2[:,1:,:,:]],dim=1) | |
# print(res.size()) | |
# return res | |
# # z_grid = grid_in[..., 2] | |
# # print(z_grid.size()) | |
# # # [1,2160,3840] | |
# # print("sss") | |
# def gen_Cout_ijk(self,LUT,x_i,y_i,z_i): | |
# # def gen_Cout_ijk(LUT,x_i,y_i,z_i,channel=3): | |
# # LUT size [3,33,33,33] | |
# # x_i,y_i,z_i size [1,1,2160,3840] | |
# # N=batch_size | |
# #img_input.size()=[1,3,2160,3840]\ | |
# # LUT.size()=[3,33,33,33] | |
# # assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)' | |
# channel=3 | |
# batch_size,_,height,width=x_i.size() | |
# print(batch_size,height,width) | |
# output=torch.zeros([batch_size,channel,height,width]) | |
# # 设置输出大小为[1,3,2160,3840] | |
# if batch_size==1: | |
# # x_i=x_i.view(height*width) | |
# # y_i=y_i.view(height*width) | |
# # z_i=z_i.view(height*width) | |
# x_i=x_i.view(height*width).long() | |
# y_i=y_i.view(height*width).long() | |
# z_i=z_i.view(height*width).long() | |
# # x_i=x_i.view(1, height*width) | |
# # y_i=y_i.view(1, height*width) | |
# # z_i=z_i.view(1, height*width) | |
# # 2维tensor,[1, 2160*3840] | |
# # xyz_i=torch.cat([x_i,y_i,z_i],dim=0) | |
# # # xyz_i 2维tensor,[3, 2160*3840] | |
# # print("xyz_i.size()",xyz_i.size()) | |
# else: | |
# print("error:batch size must be 1") | |
# for i in range(height*width): | |
# h_index=int(i/width) | |
# w_index=int(i%width) | |
# # print(h_index) | |
# # print(w_index) | |
# # print(x_i.size()) | |
# # print(batch_size) | |
# # print(output.size()) | |
# # print(output[0,0,h_index,w_index]) | |
# if(i%10000==0): | |
# print(i) | |
# output[batch_size-1,0,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],0] | |
# output[batch_size-1,1,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],1] | |
# output[batch_size-1,2,h_index,w_index]=LUT[x_i[i],y_i[i],z_i[i],2] | |
# # x_i=x_i.view(batch_size,height*width) | |
# # y_i=y_i.view(batch_size,height*width) | |
# # z_i=z_i.view(batch_size,height*width) | |
# # 1,2160*3840 | |
# return output | |
# def forward(self, LUT, img_input): | |
# assert img_input.ndimension()==4, 'img_input should be of shape [N,C,H,W]' | |
# # N=batch_size | |
# #img_input.size()=[1,3,2160,3840]\ | |
# # LUT.size()=[3,33,33,33] | |
# assert LUT.ndimension()==4, 'LUT should be of shape [C,M,M,M](M=33)' | |
# batch_size, num_chans, height, width = img_input.shape | |
# dim = LUT.shape[1] # M | |
# img_size=img_input.size() | |
# Cmax=255.0 | |
# s=Cmax/dim | |
# r,g,b=torch.split(img_input,split_size_or_sections=1,dim=1) | |
# # 将[1,3,2160,3840]以维度为1切成[1,1,2160,3840]的三部分 | |
# #r,g,b.size()=[1,1,2160,3840] | |
# # r=img_input[:,0,:,:] | |
# # g=img_input[:,1,:,:] | |
# # b=img_input[:,2,:,:] | |
# x=r/s | |
# y=g/s | |
# z=b/s | |
# # tmptmp=self.test(LUT,img_input) | |
# # x,y,z.size=[1,1,,2160,3840] | |
# # x_0,y_0,z_0.size=[1,1,,2160,3840] | |
# # x_1, y_1, z_1.size=[1,1,,2160,3840] | |
# x_0,y_0,z_0=x.floor(),y.floor(),z.floor() | |
# x_1, y_1, z_1 = x_0+1.0, y_0+1.0, z_0+1.0 | |
# u, v, w = x-x_0, y-y_0, z-z_0 | |
# # u,v,w.size=[1,1,2160,3840] | |
# # print("x_0.size",x_0.size()) | |
# c_000 = self.test(LUT,torch.cat([x_0,y_0,z_0],dim=1)) | |
# print(c_000.size()) | |
# # x_i是顶点,大小为[1,1,2160,3840] | |
# # 输出c_xxx是对应顶点的LUT的值,大小为[1,3,2160,3840] | |
# c_100 = self.test(LUT,torch.cat([x_1,y_0,z_0],dim=1)) | |
# c_010 = self.test(LUT,torch.cat([x_0,y_1,z_0],dim=1)) | |
# c_110 = self.test(LUT,torch.cat([x_1,y_1,z_0],dim=1)) | |
# c_001 = self.test(LUT,torch.cat([x_0,y_0,z_1],dim=1)) | |
# c_101 = self.test(LUT,torch.cat([x_1,y_0,z_1],dim=1)) | |
# c_011 = self.test(LUT,torch.cat([x_0,y_1,z_1],dim=1)) | |
# c_111 = self.test(LUT,torch.cat([x_1,y_1,z_1],dim=1)) | |
# # c_000 = self.gen_Cout_ijk(LUT,x_0,y_0,z_0) | |
# # # x_i是顶点,大小为[1,1,2160,3840] | |
# # # 输出c_xxx是对应顶点的LUT的值,大小为[1,3,2160,3840] | |
# # c_100 = self.gen_Cout_ijk(LUT,x_1,y_0,z_0) | |
# # c_010 = self.gen_Cout_ijk(LUT,x_0,y_1,z_0) | |
# # c_110 = self.gen_Cout_ijk(LUT,x_1,y_1,z_0) | |
# # c_001 = self.gen_Cout_ijk(LUT,x_0,y_0,z_1) | |
# # c_101 = self.gen_Cout_ijk(LUT,x_1,y_0,z_1) | |
# # c_011 = self.gen_Cout_ijk(LUT,x_0,y_1,z_1) | |
# # c_111 = self.gen_Cout_ijk(LUT,x_1,y_1,z_1) | |
# c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \ | |
# (1.0-u)*(1.0-v)*(w)*c_001 + \ | |
# (1.0-u)*(v)*(1.0-w)*c_010 + \ | |
# (1.0-u)*(v)*(w)*c_011 + \ | |
# (u)*(1.0-v)*(1.0-w)*c_100 + \ | |
# (u)*(1.0-v)*(w)*c_101 + \ | |
# (u)*(v)*(1.0-w)*c_110 + \ | |
# (u)*(v)*(w)*c_111 | |
# # 广播机制,输出[1,3,2160,3840] | |
# print("c_xyz",c_xyz.size()) | |
# return c_xyz | |
# # id100 = x_0 + 1.0 + y_0 * dim + z_0 * dim * dim | |
# # id010 = x_0 + (y_0 + 1.0) * dim + z_0 * dim * dim | |
# # id110 = x_0 + 1.0 + (y_0 + 1.0) * dim + z_0 * dim * dim | |
# # id001 = x_0 + y_0 * dim + (z_0 + 1.0) * dim * dim | |
# # id101 = x_0 + 1.0 + y_0 * dim + (z_0 + 1.0) * dim * dim | |
# # id011 = x_0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim | |
# # id111 = x_0 + 1.0 + (y_0 + 1.0) * dim + (z_0 + 1.0) * dim * dim | |
# # w000 = (1.0-u)*(1-v)*(1-w) | |
# # #大概也许得改成点乘 | |
# # w100 = u*(1-v)*(1-w) | |
# # w010 = (1-u)*v*(1-w) | |
# # w110 = u*v*(1-w) | |
# # w001 = (1-u)*(1-v)*w | |
# # w101 = u*(1-v)*w | |
# # w011 = (1-u)*v*w | |
# # w111 = u*v*w | |
# # output= | |
# # print("v:",x_0,y_0,z_0) | |
# # print("s:",x_0.size(),y_0.size(),z_0.size()) | |
# # u,v,w=u/s,v/s,w/s | |
# # c_000 = self.gen_Cout_ijk(x_0,y_0,z_0) | |
# # c_100 = self.gen_Cout_ijk(x_1,y_0,z_0) | |
# # c_010 = self.gen_Cout_ijk(x_0,y_1,z_0) | |
# # c_110 = self.gen_Cout_ijk(x_1,y_1,z_0) | |
# # c_001 = self.gen_Cout_ijk(x_0,y_0,z_1) | |
# # c_101 = self.gen_Cout_ijk(x_1,y_0,z_1) | |
# # c_011 = self.gen_Cout_ijk(x_0,y_1,z_1) | |
# # c_111 = self.gen_Cout_ijk(x_1,y_1,z_1) | |
# # c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \ | |
# # (1.0-u)*(1.0-v)*(w)*c_001 + \ | |
# # (1.0-u)*(v)*(1.0-w)*c_010 + \ | |
# # (1.0-u)*(v)*(w)*c_011 + \ | |
# # (u)*(1.0-v)*(1.0-w)*c_100 + \ | |
# # (u)*(1.0-v)*(w)*c_101 + \ | |
# # (u)*(v)*(1.0-w)*c_110 + \ | |
# # (u)*(v)*(w)*c_111 | |
# # return c_xyz | |
class Tritri(nn.Module):
    """Apply a 3-D LUT to an image with a single 5-D ``F.grid_sample`` call."""

    def __init__(self):
        super().__init__()

    def forward(self, LUT, img):
        """Sample ``LUT`` trilinearly at every pixel colour of ``img``.

        Args:
            LUT: 4-D table [C, D, H, W]; a batch axis is added before sampling.
            img: 4-D image [N, 3, H, W]. Its channels become (x, y, z) grid
                coordinates. Presumably the values are in [0, 1], which maps onto
                grid_sample's [-1, 1] range — TODO confirm with callers.

        Returns:
            Channels-last tensor [N, H, W, C]. Sampling is clamped to the
            border and uses align_corners=True.
        """
        # [0, 1] -> [-1, 1]
        coords = (img - .5) * 2.
        # [N, C, H, W] -> [N, H, W, C] -> [N, 1, H, W, C]: a depth-1 sampling grid.
        coords = coords.permute(0, 2, 3, 1).unsqueeze(1)
        table = LUT.unsqueeze(0)
        sampled = F.grid_sample(table, coords, mode='bilinear', padding_mode='border', align_corners=True)
        # [N, C, 1, H, W] -> [N, C, H, W] -> [N, H, W, C]
        return sampled.squeeze(2).permute(0, 2, 3, 1)
class bing_lut_trilinearInterplt(nn.Module):
    """Experimental trilinear interpolation of a 3-D colour LUT over an NCHW image.

    NOTE(review): the corner lookup in ``test`` is not a true 3-D LUT lookup
    (see its docstring), so treat the output as experimental.
    """

    def __init__(self):
        super(bing_lut_trilinearInterplt, self).__init__()

    def test(self, LUT, img_input):
        """Look up LUT values at the lattice coordinates stored in ``img_input``.

        Args:
            LUT: 4-D tensor; ``forward`` describes it as [C, M, M, M].
            img_input: [N, 3, H, W] tensor whose channels hold (x, y, z)
                lattice coordinates.

        Returns:
            [1, 3, H, W] tensor: channels 0-1 come from ``LUT[0:2, :, :, 0]``
            sampled at (x, y), and channel 2 comes from ``LUT[2, :, :, 0]``
            sampled at (y, z). The LUT slices are given batch size 1, so
            grid_sample requires N == 1.

        NOTE(review): F.grid_sample expects grid coordinates normalized to
        [-1, 1], but raw lattice indices are passed through unchanged. Indices
        greater than 1 therefore fall outside the LUT and sample as 0 under the
        default zero padding. Only index 0 of the LUT's last axis is ever read.
        The intended LUT axis order must be confirmed before fixing this.
        """
        # NCHW -> NHWC so the coordinate channels become the grid's last axis.
        grid_in = img_input.transpose(1, 2).transpose(2, 3)
        xy_grid = grid_in[..., 0:2]  # channels 0 and 1
        yz_grid = grid_in[..., 1:3]  # channels 1 and 2

        # Original author's note: the correct LUT layout is [3, 33, 33, 33];
        # it had at one point been mistaken for [33, 33, 33, 3].
        # Keep only index 0 of the last axis: [C, M, M, M] -> [C, M, M].
        input_LUT_ori = LUT[:, :, :, 0:1].squeeze(3)
        # Channels 0-1 and channels 1-2, each with a batch axis: [1, 2, M, M].
        input_LUT = input_LUT_ori[0:2, ...].unsqueeze(0)
        input_LUT2 = input_LUT_ori[1:, ...].unsqueeze(0)

        sampled_in_2d = F.grid_sample(input=input_LUT, grid=xy_grid, mode='nearest', align_corners=False)
        sampled_in_2d_2 = F.grid_sample(input=input_LUT2, grid=yz_grid, mode='nearest', align_corners=False)
        # Channels 0-1 from the (x, y) lookup plus channel 2 from the (y, z) lookup.
        return torch.cat([sampled_in_2d, sampled_in_2d_2[:, 1:, :, :]], dim=1)

    def forward(self, LUT, img_input):
        """Interpolate ``LUT`` at every pixel of ``img_input``.

        Args:
            LUT: 4-D tensor, described by the assertion as [C, M, M, M] (M=33).
            img_input: [N, 3, H, W] image. Each channel is divided by
                s = Cmax / (M - 1) with Cmax = 10, so values in [0, 10] map
                onto lattice coordinates [0, M - 1].

        Returns:
            Tensor of the eight ``test`` corner lookups blended with trilinear
            weights ([1, 3, H, W] for a single image).
        """
        assert img_input.ndimension() == 4, 'img_input should be of shape [N,C,H,W]'
        assert LUT.ndimension() == 4, 'LUT should be of shape [C,M,M,M](M=33)'
        dim = LUT.shape[1]  # M
        Cmax = 10
        s = Cmax / (dim - 1.0)
        # Kept as a 1-element tensor: per the original author's note this
        # avoided a "data types int64 and int32 do not match in BroadcastRel"
        # error.
        s = torch.Tensor([s])
        # [N, 3, H, W] -> three [N, 1, H, W] tensors.
        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        s = s.to(r.device)
        x = r / s
        y = g / s
        z = b / s

        # Lower lattice corner, upper corner and fractional offsets, each [N, 1, H, W].
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0

        # Corner values, each [1, 3, H, W].
        c_000 = self.test(LUT, torch.cat([x_0, y_0, z_0], dim=1))
        c_100 = self.test(LUT, torch.cat([x_1, y_0, z_0], dim=1))
        c_010 = self.test(LUT, torch.cat([x_0, y_1, z_0], dim=1))
        c_110 = self.test(LUT, torch.cat([x_1, y_1, z_0], dim=1))
        c_001 = self.test(LUT, torch.cat([x_0, y_0, z_1], dim=1))
        c_101 = self.test(LUT, torch.cat([x_1, y_0, z_1], dim=1))
        c_011 = self.test(LUT, torch.cat([x_0, y_1, z_1], dim=1))
        c_111 = self.test(LUT, torch.cat([x_1, y_1, z_1], dim=1))

        # The [N, 1, H, W] weights broadcast over the 3 output channels.
        c_xyz = (1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 + \
                (1.0 - u) * (1.0 - v) * (w) * c_001 + \
                (1.0 - u) * (v) * (1.0 - w) * c_010 + \
                (1.0 - u) * (v) * (w) * c_011 + \
                (u) * (1.0 - v) * (1.0 - w) * c_100 + \
                (u) * (1.0 - v) * (w) * c_101 + \
                (u) * (v) * (1.0 - w) * c_110 + \
                (u) * (v) * (w) * c_111
        return c_xyz
class bing_lut_trilinearInterplt_backup(nn.Module):
    """Backup variant of an experimental 3-D colour-LUT trilinear interpolator.

    NOTE(review): the corner lookup in ``test`` is not a true 3-D LUT lookup
    (see its docstring), so treat the output as experimental.
    """

    def __init__(self):
        # Previously called super() with a different, unrelated class, which
        # raised TypeError on every instantiation.
        super().__init__()

    def test(self, LUT, img_input):
        """Look up LUT values at the lattice coordinates stored in ``img_input``.

        Args:
            LUT: 4-D tensor; ``forward`` describes it as [C, M, M, M].
            img_input: [N, 3, H, W] tensor whose channels hold (x, y, z)
                lattice coordinates.

        Returns:
            [1, 3, H, W] tensor: channels 0-1 come from ``LUT[0:2, :, :, 0]``
            sampled at (x, y), and channel 2 comes from ``LUT[2, :, :, 0]``
            sampled at (y, z). The LUT slices are given batch size 1, so
            grid_sample requires N == 1.

        NOTE(review): F.grid_sample expects grid coordinates normalized to
        [-1, 1], but raw lattice indices are passed through unchanged, and only
        index 0 of the LUT's last axis is read. The intended LUT axis order must
        be confirmed before fixing this.
        """
        # NCHW -> NHWC so the coordinate channels become the grid's last axis.
        grid_in = img_input.transpose(1, 2).transpose(2, 3)
        xy_grid = grid_in[..., 0:2]  # channels 0 and 1
        yz_grid = grid_in[..., 1:3]  # channels 1 and 2

        # Keep only index 0 of the last axis: [C, M, M, M] -> [C, M, M].
        input_LUT_ori = LUT[:, :, :, 0:1].squeeze(3)
        # Channels 0-1 and channels 1-2, each with a batch axis: [1, 2, M, M].
        input_LUT = input_LUT_ori[0:2, ...].unsqueeze(0)
        input_LUT2 = input_LUT_ori[1:, ...].unsqueeze(0)

        # align_corners=False is PyTorch's default; stating it silences the
        # "default behavior has changed" warning without changing results.
        sampled_in_2d = F.grid_sample(input=input_LUT, grid=xy_grid, mode='nearest', align_corners=False)
        sampled_in_2d_2 = F.grid_sample(input=input_LUT2, grid=yz_grid, mode='nearest', align_corners=False)
        # Channels 0-1 from the (x, y) lookup plus channel 2 from the (y, z) lookup.
        return torch.cat([sampled_in_2d, sampled_in_2d_2[:, 1:, :, :]], dim=1)

    def forward(self, LUT, img_input):
        """Interpolate ``LUT`` at every pixel of ``img_input``.

        Args:
            LUT: 4-D tensor, described by the assertion as [C, M, M, M] (M=33).
            img_input: [N, 3, H, W] image. Each channel is divided by
                s = Cmax / M with Cmax = 255.

        Returns:
            Tensor of the eight ``test`` corner lookups blended with trilinear
            weights ([1, 3, H, W] for a single image).
        """
        assert img_input.ndimension() == 4, 'img_input should be of shape [N,C,H,W]'
        assert LUT.ndimension() == 4, 'LUT should be of shape [C,M,M,M](M=33)'
        dim = LUT.shape[1]  # M
        Cmax = 255.0
        s = Cmax / dim
        # Kept as a 1-element tensor: per the original author's note this
        # avoided a "data types int64 and int32 do not match in BroadcastRel"
        # error.
        s = torch.Tensor([s])
        # [N, 3, H, W] -> three [N, 1, H, W] tensors.
        r, g, b = torch.split(img_input, split_size_or_sections=1, dim=1)
        # A 1-element CPU tensor cannot be combined with CUDA inputs; match devices.
        s = s.to(r.device)
        x = r / s
        y = g / s
        z = b / s

        # Lower lattice corner, upper corner and fractional offsets, each [N, 1, H, W].
        x_0, y_0, z_0 = x.floor(), y.floor(), z.floor()
        x_1, y_1, z_1 = x_0 + 1.0, y_0 + 1.0, z_0 + 1.0
        u, v, w = x - x_0, y - y_0, z - z_0

        # Corner values, each [1, 3, H, W].
        c_000 = self.test(LUT, torch.cat([x_0, y_0, z_0], dim=1))
        c_100 = self.test(LUT, torch.cat([x_1, y_0, z_0], dim=1))
        c_010 = self.test(LUT, torch.cat([x_0, y_1, z_0], dim=1))
        c_110 = self.test(LUT, torch.cat([x_1, y_1, z_0], dim=1))
        c_001 = self.test(LUT, torch.cat([x_0, y_0, z_1], dim=1))
        c_101 = self.test(LUT, torch.cat([x_1, y_0, z_1], dim=1))
        c_011 = self.test(LUT, torch.cat([x_0, y_1, z_1], dim=1))
        c_111 = self.test(LUT, torch.cat([x_1, y_1, z_1], dim=1))

        # The [N, 1, H, W] weights broadcast over the 3 output channels.
        c_xyz = (1.0 - u) * (1.0 - v) * (1.0 - w) * c_000 + \
                (1.0 - u) * (1.0 - v) * (w) * c_001 + \
                (1.0 - u) * (v) * (1.0 - w) * c_010 + \
                (1.0 - u) * (v) * (w) * c_011 + \
                (u) * (1.0 - v) * (1.0 - w) * c_100 + \
                (u) * (1.0 - v) * (w) * c_101 + \
                (u) * (v) * (1.0 - w) * c_110 + \
                (u) * (v) * (w) * c_111
        return c_xyz
# @staticmethod | |
# def backward(ctx, lut_grad, x_grad): | |
# lut, x, int_package, float_package = ctx.saved_variables | |
# dim, shift, W, H, batch = int_package | |
# dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch) | |
# binsize = float(float_package[0]) | |
# assert 1 == trilinear.backward(x, | |
# x_grad, | |
# lut_grad, | |
# dim, | |
# shift, | |
# binsize, | |
# W, | |
# H, | |
# batch) | |
# return lut_grad, x_grad | |
class Tri(nn.Module):
    """Placeholder module: registers no parameters and defines no forward()."""

    def __init__(self):
        super().__init__()
if __name__ == '__main__':
    # Quick smoke/timing run. Fall back to the CPU when CUDA is unavailable
    # instead of crashing on .cuda(0) / torch.cuda.synchronize().
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # TrilinearIntepolation inputs:
    #   input_feats: [B, num_channels, depth, height, width]
    #   sampling_grid: [B, Hg, Wg, 3] with values in [-1, 1]
    data = torch.rand(1, 32, 16, 128, 128)
    sampling_grid = (torch.rand(1, 256, 256, 3) - 0.5) * 2.0
    data = data.float().to(device)
    sampling_grid = sampling_grid.float().to(device)
    trilinear_interpolation = TrilinearIntepolation().to(device)

    # 3-D LUT inputs: image [1, 3, 2160, 3840] and LUT [3, 33, 33, 33].
    # These stay on the CPU, as in the original timing run.
    data2 = torch.rand(1, 3, 2160, 3840)
    LUT2 = torch.rand(3, 33, 33, 33)
    trilinear_interpolation2 = bing_lut_trilinearInterplt()

    t_start = time.time()
    interp_data2 = trilinear_interpolation2(LUT2, data2)
    if device.type == 'cuda':
        # Wait for any queued GPU work so the timing is meaningful.
        torch.cuda.synchronize()
    print('time per iteration ', time.time() - t_start)