import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

# Code adapted from the rotation continuity repo (https://github.com/papagina/RotationContinuity)
#
# Every helper below works on plain tensors and allocates any constants on the
# device/dtype of its input, so the module runs on CPU as well as on CUDA.
# ``torch.autograd.Variable`` is no longer used here (it is a no-op wrapper in
# modern PyTorch); the import is kept so that code re-importing it from this
# module keeps working.


def compute_pose_from_rotation_matrix(T_pose, r_matrix):
    """Rotate every joint of a rest pose by each rotation matrix in a batch.

    Args:
        T_pose: ``joint_num * 3`` elements, read as ``(joint_num, 3)`` points.
        r_matrix: ``batch * 9`` elements, read as ``(batch, 3, 3)`` matrices.

    Returns:
        ``(batch, joint_num, 3)`` tensor whose entry ``[b, j]`` is
        ``r_matrix[b] @ T_pose[j]``.
    """
    batch = r_matrix.shape[0]
    joint_num = T_pose.shape[0]
    # Let matmul broadcast (batch,1,3,3) @ (1,joint_num,3,1) instead of
    # materialising expanded copies of both operands.
    rotations = r_matrix.view(batch, 1, 3, 3)
    points = T_pose.view(1, joint_num, 3, 1)
    out_poses = torch.matmul(rotations, points)  # batch*joint_num*3*1
    return out_poses.reshape(batch, joint_num, 3)


def normalize_vector(v, return_mag=False):
    """Scale each row of a ``(batch, n)`` tensor to unit Euclidean length.

    The magnitude is floored at ``1e-8`` so an all-zero row is divided by a
    tiny positive number instead of zero.

    Args:
        v: ``(batch, n)`` tensor.
        return_mag: if true, also return the (floored) row magnitudes.

    Returns:
        The normalised ``(batch, n)`` tensor, or ``(normalised, magnitudes)``
        with ``magnitudes`` of shape ``(batch,)`` when ``return_mag`` is set.
    """
    batch = v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))  # batch
    v_mag = torch.clamp(v_mag, min=1e-8)
    v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
    v = v / v_mag
    if return_mag:
        return v, v_mag[:, 0]
    return v


def cross_product(u, v):
    """Row-wise 3D cross product ``u x v`` for ``(batch, 3)`` tensors."""
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
    return torch.stack((i, j, k), 1)  # batch*3


def compute_rotation_matrix_from_ortho6d(ortho6d):
    """Build rotation matrices from the 6D representation (two raw 3-vectors).

    Gram-Schmidt via cross products: ``x`` is the normalised first vector,
    ``z`` the normalised ``x x y_raw`` and ``y = z x x``.

    Args:
        ortho6d: ``(batch, 6)`` tensor; columns 0-2 are ``x_raw``, 3-5 ``y_raw``.

    Returns:
        ``(batch, 3, 3)`` tensor whose columns are ``x``, ``y``, ``z``.
    """
    x_raw = ortho6d[:, 0:3]  # batch*3
    y_raw = ortho6d[:, 3:6]  # batch*3
    x = normalize_vector(x_raw)
    z = normalize_vector(cross_product(x, y_raw))
    y = cross_product(z, x)
    return torch.stack((x, y, z), 2)  # batch*3*3


def stereographic_project(a):
    """Stereographically project normalised rows of ``a`` from the last-axis pole.

    Each row is first normalised; then ``out = a[:, :-1] / (1 - a[:, -1])``.
    A row equal to the pole ``(0, ..., 0, 1)`` divides by zero.

    Args:
        a: ``(batch, n)`` tensor.

    Returns:
        ``(batch, n - 1)`` tensor.
    """
    dim = a.shape[1]
    a = normalize_vector(a)
    # Keep the divisor 2-D, (batch, 1), so it broadcasts across each row's
    # coordinates rather than across the batch.
    return a[:, 0:dim - 1] / (1 - a[:, dim - 1:dim])


def stereographic_unproject(a, axis=None):
    """
    Inverse of stereographic projection: increases dimension by one.

    With ``s2 = |a|^2`` per row, the output holds ``2 * a / (s2 + 1)`` with the
    extra coordinate ``(s2 - 1) / (s2 + 1)`` inserted at column ``axis``.

    Args:
        a: ``(batch, n)`` tensor.
        axis: index of the inserted coordinate; defaults to ``n`` (appended).

    Returns:
        ``(batch, n + 1)`` tensor on the same device and dtype as ``a``.
    """
    batch = a.shape[0]
    if axis is None:
        axis = a.shape[1]
    s2 = torch.pow(a, 2).sum(1)  # batch
    ans = a.new_zeros(batch, a.shape[1] + 1)
    unproj = 2 * a / (s2 + 1).unsqueeze(1)  # batch*n
    ans[:, :axis] = unproj[:, :axis]  # empty when axis == 0
    ans[:, axis] = (s2 - 1) / (s2 + 1)  # batch
    # Note that this is a no-op if the default option (last axis) is used
    ans[:, axis + 1:] = unproj[:, axis:]
    return ans


def compute_rotation_matrix_from_ortho5d(a):
    """Build rotation matrices from the 5D representation.

    Columns 2-4 are scaled, inverse-stereographically projected to 4D (new
    coordinate at index 0) and divided by the norm of the last three of those
    four values. The result is appended to columns 0-1 and decoded as 6D.

    Args:
        a: ``(batch, 5)`` tensor.

    Returns:
        ``(batch, 3, 3)`` tensor.
    """
    batch = a.shape[0]
    proj_scale = torch.tensor(
        [np.sqrt(2) + 1, np.sqrt(2) + 1, np.sqrt(2)], dtype=a.dtype, device=a.device
    ).view(1, 3)
    u = stereographic_unproject(a[:, 2:5] * proj_scale, axis=0)  # batch*4
    norm = torch.sqrt(torch.pow(u[:, 1:], 2).sum(1))  # batch
    u = u / norm.view(batch, 1)  # batch*4
    b = torch.cat((a[:, 0:2], u), 1)  # batch*6
    return compute_rotation_matrix_from_ortho6d(b)


def _rotation_matrix_from_quaternion_components(qw, qx, qy, qz):
    """Build ``(batch, 3, 3)`` matrices from quaternion components.

    Each argument holds one component per batch element (any shape with
    ``batch`` elements; it is flattened to ``(batch, 1)``). The quaternion
    must already have unit norm for the result to be a rotation.
    """
    qw, qx, qy, qz = (q.reshape(-1, 1) for q in (qw, qx, qy, qz))
    xx = qx * qx
    yy = qy * qy
    zz = qz * qz
    xy = qx * qy
    xz = qx * qz
    yz = qy * qz
    xw = qx * qw
    yw = qy * qw
    zw = qz * qw
    row0 = torch.cat((1 - 2*yy - 2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1)  # batch*3
    row1 = torch.cat((2*xy + 2*zw, 1 - 2*xx - 2*zz, 2*yz - 2*xw), 1)  # batch*3
    row2 = torch.cat((2*xz - 2*yw, 2*yz + 2*xw, 1 - 2*xx - 2*yy), 1)  # batch*3
    return torch.stack((row0, row1, row2), 1)  # batch*3*3


def compute_rotation_matrix_from_quaternion(quaternion):
    """Convert ``(batch, 4)`` quaternions ``(w, x, y, z)`` to rotation matrices.

    The input is normalised first, so any non-zero 4-vector is accepted.
    """
    quat = normalize_vector(quaternion)
    return _rotation_matrix_from_quaternion_components(
        quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]
    )


def compute_rotation_matrix_from_axisAngle(axisAngle):
    """Convert ``(batch, 4)`` ``(angle, x, y, z)`` rows to rotation matrices.

    The angle column is squashed to ``tanh(angle) * pi`` (open interval
    ``(-pi, pi)``); the axis is normalised. The rotation is by that angle
    about that axis (half-angle quaternion).
    """
    theta = torch.tanh(axisAngle[:, 0]) * np.pi  # (-pi, pi)
    half = theta * 0.5
    sin = torch.sin(half)
    axis = normalize_vector(axisAngle[:, 1:4])  # batch*3
    return _rotation_matrix_from_quaternion_components(
        torch.cos(half), axis[:, 0] * sin, axis[:, 1] * sin, axis[:, 2] * sin
    )


def compute_rotation_matrix_from_Rodriguez(rod):
    """Convert ``(batch, 3)`` vectors ``axis * theta`` to rotation matrices.

    NOTE(review): this uses the quaternion ``(cos theta, sin theta * axis)``
    rather than half angles, so the resulting matrix rotates by ``2 * |rod|``
    about ``rod / |rod|``. It is left unchanged to preserve the behaviour of
    existing callers (e.g. trained models); confirm whether this is intended.
    """
    axis, theta = normalize_vector(rod, return_mag=True)
    sin = torch.sin(theta)
    return _rotation_matrix_from_quaternion_components(
        torch.cos(theta), axis[:, 0] * sin, axis[:, 1] * sin, axis[:, 2] * sin
    )


def _quaternion_components_from_hopf(theta, phi, tao):
    """Map Hopf coordinates to quaternion components ``(qw, qx, qy, qz)``."""
    qw = torch.cos(theta / 2) * torch.cos(tao / 2)
    qx = torch.cos(theta / 2) * torch.sin(tao / 2)
    qy = torch.sin(theta / 2) * torch.cos(phi + tao / 2)
    qz = torch.sin(theta / 2) * torch.sin(phi + tao / 2)
    return qw, qx, qy, qz


def compute_rotation_matrix_from_hopf(hopf):
    """Convert ``(batch, 3)`` unconstrained Hopf parameters to rotation matrices.

    Columns are squashed with ``tanh``: theta into ``(0, pi)``, phi and tao
    into ``(0, 2*pi)``.
    """
    theta = (torch.tanh(hopf[:, 0]) + 1.0) * np.pi / 2.0  # [0, pi]
    phi = (torch.tanh(hopf[:, 1]) + 1.0) * np.pi  # [0, 2pi)
    tao = (torch.tanh(hopf[:, 2]) + 1.0) * np.pi  # [0, 2pi)
    return _rotation_matrix_from_quaternion_components(
        *_quaternion_components_from_hopf(theta, phi, tao)
    )


def _rotation_matrix_from_euler_terms(c1, s1, c2, s2, c3, s3):
    """Return ``Rx(a) @ Rz(b) @ Ry(c)`` from ``(batch, 1)`` cos/sin of a, b, c."""
    row1 = torch.cat((c2*c3, -s2, c2*s3), 1).view(-1, 1, 3)  # batch*1*3
    row2 = torch.cat((c1*s2*c3 + s1*s3, c1*c2, c1*s2*s3 - s1*c3), 1).view(-1, 1, 3)  # batch*1*3
    row3 = torch.cat((s1*s2*c3 - c1*s3, s1*c2, s1*s2*s3 + c1*c3), 1).view(-1, 1, 3)  # batch*1*3
    return torch.cat((row1, row2, row3), 1)  # batch*3*3


def compute_rotation_matrix_from_euler(euler):
    """Convert ``(batch, 3)`` Euler angles to rotation matrices.

    Returns ``Rx(euler[:, 0]) @ Rz(euler[:, 2]) @ Ry(euler[:, 1])``, i.e. the
    intrinsic order X-Z'-Y'' (equivalently extrinsic Y-Z-X).
    """
    c1 = torch.cos(euler[:, 0:1])  # batch*1
    s1 = torch.sin(euler[:, 0:1])
    c2 = torch.cos(euler[:, 2:3])
    s2 = torch.sin(euler[:, 2:3])
    c3 = torch.cos(euler[:, 1:2])
    s3 = torch.sin(euler[:, 1:2])
    return _rotation_matrix_from_euler_terms(c1, s1, c2, s2, c3, s3)


def compute_rotation_matrix_from_euler_sin_cos(euler_sin_cos):
    """Convert ``(batch, 6)`` sin/cos pairs to rotation matrices.

    Columns are ``(sin a, cos a, sin b, cos b, sin c, cos c)`` and the result
    is ``Rx(a) @ Rz(b) @ Ry(c)`` (intrinsic X-Z'-Y''). The pairs are used as
    given, without renormalisation.
    """
    s1 = euler_sin_cos[:, 0:1]
    c1 = euler_sin_cos[:, 1:2]
    s2 = euler_sin_cos[:, 2:3]
    c2 = euler_sin_cos[:, 3:4]
    s3 = euler_sin_cos[:, 4:5]
    c3 = euler_sin_cos[:, 5:6]
    return _rotation_matrix_from_euler_terms(c1, s1, c2, s2, c3, s3)


def compute_geodesic_distance_from_two_matrices(m1, m2):
    """Angle in radians, in ``[0, pi]``, of ``m1 @ m2^T`` for ``(batch, 3, 3)`` rotations."""
    m = torch.bmm(m1, m2.transpose(1, 2))  # batch*3*3
    return compute_angle_from_r_matrices(m)


def compute_angle_from_r_matrices(m):
    """Rotation angle in radians, in ``[0, pi]``, of ``(batch, 3, 3)`` rotation matrices.

    Uses ``cos(theta) = (trace(m) - 1) / 2``, clamped to ``[-1, 1]`` to absorb
    numerical drift. NOTE: the derivative of ``acos`` is unbounded at +-1, so
    gradients can be inf/NaN for (near-)identity or 180-degree rotations.
    """
    cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2
    cos = torch.clamp(cos, -1.0, 1.0)
    return torch.acos(cos)


def get_sampled_rotation_matrices_by_quat(batch, device="cuda"):
    """Sample ``batch`` rotation matrices from normalised Gaussian quaternions.

    The samples are drawn with the CPU generator and then moved to ``device``
    (default ``"cuda"``, matching the previous hard-coded behaviour).
    """
    quat = torch.randn(batch, 4).to(device)
    return compute_rotation_matrix_from_quaternion(quat)


def get_sampled_rotation_matrices_by_hpof(batch, device="cuda"):
    """Sample ``batch`` rotation matrices from uniformly drawn Hopf coordinates.

    theta ~ U[0, pi), phi ~ U[0, 2pi), tao ~ U[0, 2pi), drawn with NumPy's
    global RNG. NOTE(review): a uniform theta is presumably not uniform over
    SO(3) under this parameterisation; confirm if uniformity matters.
    """
    theta = torch.tensor(np.random.uniform(0, 1, batch) * np.pi, dtype=torch.float32, device=device)  # [0, pi]
    phi = torch.tensor(np.random.uniform(0, 2, batch) * np.pi, dtype=torch.float32, device=device)  # [0, 2pi)
    tao = torch.tensor(np.random.uniform(0, 2, batch) * np.pi, dtype=torch.float32, device=device)  # [0, 2pi)
    return _rotation_matrix_from_quaternion_components(
        *_quaternion_components_from_hopf(theta, phi, tao)
    )


def get_sampled_rotation_matrices_by_axisAngle(batch, return_quaternion=False, device="cuda"):
    """Sample ``batch`` rotation matrices from random axis/angle pairs.

    theta ~ U[-pi, pi) (NumPy RNG); axis is a normalised Gaussian 3-vector
    (torch CPU RNG, then moved to ``device``). The quaternion is
    ``(cos theta, sin theta * axis)``, so the rotation angle is ``2 * theta``.

    Returns:
        ``(batch, 3, 3)`` matrices, plus the ``(batch, 4)`` ``(w, x, y, z)``
        quaternions when ``return_quaternion`` is true.
    """
    theta = torch.tensor(np.random.uniform(-1, 1, batch) * np.pi, dtype=torch.float32, device=device)  # [-pi, pi)
    sin = torch.sin(theta)
    axis = normalize_vector(torch.randn(batch, 3).to(device))  # batch*3
    qw = torch.cos(theta)
    qx = axis[:, 0] * sin
    qy = axis[:, 1] * sin
    qz = axis[:, 2] * sin
    quaternion = torch.stack((qw, qx, qy, qz), 1)  # batch*4
    matrix = _rotation_matrix_from_quaternion_components(qw, qx, qy, qz)
    if return_quaternion:
        return matrix, quaternion
    return matrix