# EscherNet / CaPE.py
# Uploaded to the Hugging Face Hub by kxic via huggingface_hub
# ("Upload folder using huggingface_hub", commit e371ddd, 10.1 kB).
import numpy as np
import einops
import torch
from scipy.spatial.transform import Rotation as R
########################## 6DoF CaPE ####################################
class CaPE_6DoF:
    """CaPE pose encoding for full 6-DoF poses given as 4x4 matrices.

    Features are treated as a sequence of 4-vectors. Query chunks are
    right-multiplied by P_q^{-T} and key chunks by P_k, so every chunk pair
    contributes f_q P_q^{-T} P_k^T f_k^T to the attention score. Replacing every
    pose P with P @ D for one shared invertible D leaves that product unchanged
    (D^{-T} D^T = I), so the scores are invariant to a shared global transform.
    """

    def cape_embed(self, f, P):
        """
        Apply CaPE on a feature tensor.

        :param f: feature tensor of shape [..., d]; d must be divisible by 4
            (otherwise the reshape raises RuntimeError)
        :param P: 4x4 transformation matrices of shape [..., 4, 4] whose leading
            dims broadcast against f.shape[:-1]
        :return: tensor of shape [..., d] in which every consecutive 4-chunk of
            f has been right-multiplied by its P (f @ P)
        """
        # '... (d k) -> ... d k' with k = 4: split the features into 4-vectors.
        chunks = f.reshape(*f.shape[:-1], -1, 4)
        # '... d k -> ... (d k)': flatten back after the matrix product.
        return (chunks @ P).flatten(-2)

    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Compute attention logits q @ k^T with CaPE pose encoding.

        query = f1 @ p1^{-T}, key = f2 @ p2 (applied per 4-chunk of features).

        :param f1: query features, shape b (t1 l) d
        :param f2: key features, shape b (t2 l) d
        :param p1: query poses [b, t1, 4, 4]
        :param p2: key poses [b, t2, 4, 4]
        :return: attention scores [b, t1*l, t2*l]
        :raises AssertionError: if the token counts are not an exact multiple of
            the view counts, or the two sides have different tokens per view
        """
        t1, t2 = p1.shape[1], p2.shape[1]
        # Without this check floor division hides mismatches; with fewer tokens
        # than views l becomes 0 and broadcasting silently yields an empty map.
        assert f1.shape[1] % t1 == 0 and f2.shape[1] % t2 == 0, (
            f"token counts ({f1.shape[1]}, {f2.shape[1]}) must be multiples "
            f"of the view counts ({t1}, {t2})"
        )
        l = f1.shape[1] // t1
        assert l == f2.shape[1] // t2, "query and key must have the same tokens per view"
        # Each view's pose is shared by its l consecutive tokens,
        # i.e. 'b t m n -> b (t l) m n'.
        p1_inv_t = torch.inverse(p1).transpose(-1, -2).repeat_interleave(l, dim=1)
        p2_tokens = p2.repeat_interleave(l, dim=1)
        query = self.cape_embed(f1, p1_inv_t)  # [b, l*t1, d]: f1 @ p1^{-T}
        key = self.cape_embed(f2, p2_tokens)  # [b, l*t2, d]: f2 @ p2
        return query @ key.transpose(1, 2)  # [b, l*t1, l*t2]
################### 6DoF Verification ###################################
def euler_to_matrix(alpha, beta, gamma, x, y, z, *, degrees=True):
    """
    Build a 4x4 homogeneous transform from 'xyz' Euler angles and a translation.

    :param alpha, beta, gamma: rotation angles for scipy's 'xyz' sequence.
        They are interpreted as DEGREES by default; pass degrees=False for radians.
    :param x, y, z: translation components
    :param degrees: forwarded to scipy's Rotation.from_euler
    :return: np.ndarray [4, 4] (float64): [[R, t], [0, 0, 0, 1]]
    """
    transform = np.eye(4)
    transform[:3, :3] = R.from_euler('xyz', [alpha, beta, gamma], degrees=degrees).as_matrix()
    transform[:3, 3] = [x, y, z]
    return transform
def random_6dof_pose(B, T):
    """
    Sample random 6-DoF poses for the verification below.

    Each pose draws six uniform [0, 1) values from torch's RNG: three Euler
    angles, which euler_to_matrix interprets as degrees (so the rotations are
    below one degree per axis), followed by an x, y, z translation.

    :param B: batch size
    :param T: number of views per batch entry
    :return: float32 tensor [B, T, 4, 4]
    """
    params = torch.rand([B, T, 6]).numpy().reshape(-1, 6)
    matrices = [torch.from_numpy(euler_to_matrix(*row)) for row in params]
    return torch.stack(matrices).reshape(B, T, 4, 4).float()
bs = 6  # batch size
t1 = 3  # num of target views in each batch, can be arbitrary number
t2 = 5  # num of reference views in each batch, can be arbitrary number
l = 10  # len of token
d = 16  # dim of token feature; 6DoF CaPE splits it into 4-vectors, so d % 4 == 0
assert d % 4 == 0
# random init query and key, flattened to [bs, t * l, d]
f1 = torch.rand(bs, t1, l, d)  # query
f2 = torch.rand(bs, t2, l, d)  # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')
# random init poses p1, p2 and a per-batch delta, [bs, t, 4, 4]
p1 = random_6dof_pose(bs, t1)  # [bs, t1, 4, 4]
p2 = random_6dof_pose(bs, t2)  # [bs, t2, 4, 4]
p_delta = random_6dof_pose(bs, 1)  # [bs, 1, 4, 4]
# the same delta is shared by every view of p1 and p2 within a batch
p1_delta = einops.repeat(p_delta, 'b 1 m n -> b t m n', t=t1)
p2_delta = einops.repeat(p_delta, 'b 1 m n -> b t m n', t=t2)
# run attention with CaPE 6DoF
cape_6dof = CaPE_6DoF()
att = cape_6dof.attn_with_CaPE(f1, f2, p1, p2)
att_delta = cape_6dof.attn_with_CaPE(f1, f2, p1 @ p1_delta, p2 @ p2_delta)
# condition: att scores must be unaffected by the shared delta_p.
# An explicit atol is needed: with the default atol=1e-8 the check is purely
# relative, and float32 round-off on near-zero scores could fail it spuriously.
assert torch.allclose(att, att_delta, rtol=1e-3, atol=1e-5)
print("6DoF CaPE Verified")
########################## 4DoF CaPE ####################################
class CaPE_4DoF:
    """CaPE pose encoding for 4-DoF poses given as angle vectors [..., pose_dim].

    Each pose component is expanded over a block of feature dims and applied as
    a 2-D rotation to consecutive feature pairs. A query pair rotated by
    theta_q and a key pair rotated by theta_k have dot product
    q0 . R(theta_k - theta_q) k0, so adding the same offset to every pose leaves
    the attention scores unchanged.
    """

    def rotate_every_two(self, x):
        """
        Map every consecutive feature pair (x1, x2) to (-x2, x1).

        :param x: tensor [..., d] with even d
        :return: tensor of the same shape
        """
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = pairs.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def cape(self, x, p):
        """
        Expand poses into per-feature rotation angles.

        :param x: features [..., d]; only the size d is used
        :param p: poses [..., n]
        :return: angles [..., d] where each pose component is repeated d // n
            times consecutively ('b l n -> b l (n k)')
        :raises AssertionError: if d is not divisible by 2 * n
        """
        d, n = x.shape[-1], p.shape[-1]
        # k = d // n must be even so both members of a feature pair share one angle.
        assert d % (2 * n) == 0, f"feature dim {d} must be divisible by 2 * pose dim ({2 * n})"
        return p.repeat_interleave(d // n, dim=-1)

    def cape_embed(self, qq, kk, p1, p2):
        """
        Rotate query and key features by their camera poses.

        :param qq: query feature map [b, l_q, feature_dim]
        :param kk: key feature map [b, l_k, feature_dim]
        :param p1: query pose [b, l_q, pose_dim]
        :param p2: key pose [b, l_k, pose_dim]
        :return: tuple (q, k) of rotated query / key features with the same
            shapes as qq / kk (not an attention map)
        """
        assert p1.shape[-1] == p2.shape[-1]
        assert qq.shape[-1] == kk.shape[-1]
        assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0]
        assert p1.shape[1] == qq.shape[1]
        assert p2.shape[1] == kk.shape[1]
        m1 = self.cape(qq, p1)
        m2 = self.cape(kk, p2)
        # x * cos + rotate_every_two(x) * sin rotates each pair by its angle.
        q = (qq * m1.cos()) + (self.rotate_every_two(qq) * m1.sin())
        k = (kk * m2.cos()) + (self.rotate_every_two(kk) * m2.sin())
        return q, k

    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Compute attention logits q @ k^T with 4-DoF CaPE pose encoding.

        :param f1: query features, shape b (t1 l) d
        :param f2: key features, shape b (t2 l) d
        :param p1: query poses [b, t1, pose_dim]
        :param p2: key poses [b, t2, pose_dim]
        :return: attention scores [b, t1*l, t2*l]
        :raises AssertionError: if the token counts are not an exact multiple of
            the view counts, or the two sides have different tokens per view
        """
        t1, t2 = p1.shape[1], p2.shape[1]
        # Without this check floor division hides mismatches; with fewer tokens
        # than views l becomes 0 and the result is silently empty.
        assert f1.shape[1] % t1 == 0 and f2.shape[1] % t2 == 0, (
            f"token counts ({f1.shape[1]}, {f2.shape[1]}) must be multiples "
            f"of the view counts ({t1}, {t2})"
        )
        l = f1.shape[1] // t1
        assert l == f2.shape[1] // t2, "query and key must have the same tokens per view"
        # Each view's pose is shared by its l consecutive tokens ('b t m -> b (t l) m').
        p1_tokens = p1.repeat_interleave(l, dim=1)  # [b, l*t1, pose_dim]
        p2_tokens = p2.repeat_interleave(l, dim=1)  # [b, l*t2, pose_dim]
        query, key = self.cape_embed(f1, f2, p1_tokens, p2_tokens)
        return query @ key.transpose(1, 2)  # [b, l*t1, l*t2]
################### 4DoF Verification ###################################
def random_4dof_pose(B, T):
    """
    Sample random 4-DoF poses (theta, azimuth, radius, 0) for verification.

    theta ~ U[0, pi), azimuth ~ U[0, 2pi), radius ~ U[0, pi); the last entry is 0.

    :param B: batch size
    :param T: number of views per batch entry
    :return: float32 tensor [B, T, 4]
    """
    scale = torch.tensor([torch.pi, 2 * torch.pi, torch.pi])
    angles = torch.rand([B, T, 3]) * scale
    padding = torch.zeros([B, T, 1])
    return torch.cat([angles, padding], dim=-1).float()
def look_at(origin, target, up):
    """
    Build 4x4 poses whose rotation columns are (right, new_up, forward).

    forward is the unit vector from `origin` to `target`,
    right = normalize(forward x up) and new_up = normalize(forward x right).
    The translation column is `target`.

    :param origin: tensor [..., 3]
    :param target: tensor [..., 3]
    :param up: up hint, tensor [..., 3]
    :return: tensor [..., 4, 4] with the dtype/device of the inputs and a
        bottom row of [0, 0, 0, 1]
    NOTE: produces NaNs when target == origin or when forward is parallel to `up`
    (the normalisations divide by zero).
    """
    forward = target - origin
    forward = forward / torch.linalg.norm(forward, dim=-1, keepdim=True)
    right = torch.linalg.cross(forward, up)
    right = right / torch.linalg.norm(right, dim=-1, keepdim=True)
    new_up = torch.linalg.cross(forward, right)
    new_up = new_up / torch.linalg.norm(new_up, dim=-1, keepdim=True)
    top = torch.stack((right, new_up, forward, target), dim=-1)  # [..., 3, 4]
    # Build the homogeneous row with the same dtype/device as the rotation
    # block, for any number of leading batch dims.
    bottom = torch.zeros_like(top[..., :1, :])  # [..., 1, 4]
    bottom[..., 0, 3] = 1
    return torch.cat([top, bottom], dim=-2)
def pose_4dof2matrix(pose_4dof):
    """
    Convert 4-DoF poses to 4x4 matrices via look_at.

    The spherical point (polar angle theta from +z, azimuth in the x-y plane,
    radius) becomes look_at's target and hence the translation column; origin
    is 0 and the world +z axis is the up hint. The 4th component is not read.

    :param pose_4dof: [b, t, 4] (theta, azimuth, radius, unused)
    :return: pose 4x4 matrix: [b, t, 4, 4]
    """
    theta, azimuth, radius = (pose_4dof[:, :, i] for i in range(3))
    sin_theta = torch.sin(theta)
    direction = torch.stack(
        [sin_theta * torch.cos(azimuth), sin_theta * torch.sin(azimuth), torch.cos(theta)],
        dim=-1,
    )
    position = direction * radius.unsqueeze(-1)
    up_hint = torch.zeros_like(position)
    up_hint[:, :, 2] = 1
    return look_at(torch.zeros_like(position), position, up_hint)
def pose_matrix24dof(pose_matrix):
    """
    Recover (theta, azimuth, radius, 0) from the translation column of 4x4 poses.

    Only pose_matrix[..., :3, 3] is read; the rotation block is ignored.

    :param pose_matrix: [..., 4, 4]
    :return: pose_4dof: [..., 4] with
        theta  - polar angle measured from +z, in [0, pi]
        azimuth - angle in the x-y plane, wrapped into [0, 2pi)
        radius - Euclidean norm of the translation (never wrapped)
        0      - padding
    """
    x, y, z = pose_matrix[..., :3, 3].unbind(-1)
    xy_sq = x ** 2 + y ** 2
    radius = torch.sqrt(xy_sq + z ** 2)
    # The first arctan2 argument is non-negative, so theta is already in [0, pi].
    theta = torch.arctan2(torch.sqrt(xy_sq), z)
    # arctan2 yields (-pi, pi]; wrap only the azimuth. Wrapping every component
    # would corrupt radii >= 2pi.
    azimuth = torch.remainder(torch.arctan2(y, x), 2 * torch.pi)
    return torch.stack([theta, azimuth, radius, torch.zeros_like(radius)], dim=-1)
bs = 6  # batch size
t1 = 3  # num of target views in each batch, can be arbitrary number
t2 = 5  # num of reference views in each batch, can be arbitrary number
l = 10  # len of token
d = 16  # dim of token feature; 4DoF CaPE needs d % (2 * pose_dim) == 0, i.e. d % 8 == 0
assert d % 8 == 0
# random init query and key, flattened to [bs, t * l, d]
f1 = torch.rand(bs, t1, l, d)  # query
f2 = torch.rand(bs, t2, l, d)  # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')
# random init 4DoF pose [bs, t, 4]: theta, azimuth, radius, 0
p1 = random_4dof_pose(bs, t1)  # [bs, t1, 4]
p2 = random_4dof_pose(bs, t2)  # [bs, t2, 4]
# sanity check: 4DoF -> 4x4 matrix -> 4DoF round trip
p1_matrix = pose_4dof2matrix(p1)
p1_4dof = pose_matrix24dof(p1_matrix)
assert torch.allclose(p1, p1_4dof)
p_delta_4dof = random_4dof_pose(bs, 1)  # [bs, 1, 4]
# the same delta is added to every view of p1 and p2 within a batch
p1_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b t m', t=t1)
p2_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b t m', t=t2)
# run attention with CaPE 4DoF
cape_4dof = CaPE_4DoF()
att = cape_4dof.attn_with_CaPE(f1, f2, p1, p2)
att_delta = cape_4dof.attn_with_CaPE(f1, f2, p1 + p1_delta_4dof, p2 + p2_delta_4dof)
# condition: att scores must be unaffected by the shared delta_p.
# Scores can cross zero here, and with the default atol=1e-8 a purely relative
# check fails on float32 round-off, so an explicit atol is given.
assert torch.allclose(att, att_delta, rtol=1e-3, atol=1e-5)
print("4DoF CaPE Verified")
# print("You should get assertion error because 4DoF CaPE cannot handle 6DoF jitter")
# # att_delta_6dof, it cannot handle 6dof jitter
# p_delta_6dof = random_6dof_pose(bs, 1) # [bs, 1, 4, 4] any delta transformation in 6DoF
# # delta p is identical to p1 and p2 in each batch
# p1_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t1//1)
# p2_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t2//1)
# # 4dof pose to 4x4 matrix
# p1_matrix = pose_4dof2matrix(p1)
# p2_matrix = pose_4dof2matrix(p2)
# att_delta_6dof = cape_4dof.attn_with_CaPE(f1, f2, pose_matrix24dof(p1_matrix@p1_delta_6dof), pose_matrix24dof(p2_matrix@p2_delta_6dof))
# # condition: att score should be the same i.e. non effect from any delta_p
# assert torch.allclose(att, att_delta_6dof, 1e-3)