import numpy as np
import einops
import torch
from scipy.spatial.transform import Rotation as R
########################## 6DoF CaPE ####################################
class CaPE_6DoF:
    def cape_embed(self, f, P):
        """
        Apply CaPE on a feature.
        :param f: feature vector of shape [..., d], with d divisible by 4
        :param P: 4x4 transformation matrix
        :return: feature f rotated by pose P, i.e. each 4-vector chunk of f multiplied by P
        """
        f = einops.rearrange(f, '... (d k) -> ... d k', k=4)
        return einops.rearrange(f @ P, '... d k -> ... (d k)', k=4)

    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Attention dot product with CaPE pose encoding.
        # query = cape_embed(query, p_out_inv)  # query: f_q @ (p_out)^(-T)
        # key = cape_embed(key, p_in)           # key:   f_k @ p_in
        :param f1: query features, [b, (t1 l), d]
        :param f2: key features, [b, (t2 l), d]
        :param p1: query poses, [b, t1, 4, 4]
        :param p2: key poses, [b, t2, 4, 4]
        :return: attention score q @ k^T, [b, (t1 l), (t2 l)]
        """
        l = f1.shape[1] // p1.shape[1]
        assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1]
        p1_invT = einops.repeat(torch.inverse(p1).permute(0, 1, 3, 2), 'b t m n -> b (t l) m n', l=l)  # [b, t1*l, 4, 4]
        query = self.cape_embed(f1, p1_invT)  # [b, t1*l, d] query: f1 @ (p1)^(-T), transpose of the last two dims
        p2_copy = einops.repeat(p2, 'b t m n -> b (t l) m n', l=l)  # [b, t2*l, 4, 4]
        key = self.cape_embed(f2, p2_copy)  # [b, t2*l, d] key: f2 @ p2
        att = query @ key.permute(0, 2, 1)  # [b, t1*l, t2*l] attention: query @ key^T
        return att
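# A minimal illustration of cape_embed: the feature is split into 4-vector chunks and each
# chunk is right-multiplied by P, so the identity pose leaves the feature unchanged.
# The `_demo_*` names below are ad-hoc and used only for illustration.
_demo_cape6 = CaPE_6DoF()
_demo_f = torch.rand(2, 8)  # [..., d] with d divisible by 4
assert torch.allclose(_demo_cape6.cape_embed(_demo_f, torch.eye(4)), _demo_f)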
################### 6DoF Verification ###################################
def euler_to_matrix(alpha, beta, gamma, x, y, z):
    # build a 4x4 homogeneous transform from Euler angles (in degrees, 'xyz' order) and a translation
    r = R.from_euler('xyz', [alpha, beta, gamma], degrees=True)
    t = np.array([[x], [y], [z]])
    rot_matrix = r.as_matrix()
    rot_matrix = np.concatenate([rot_matrix, t], axis=-1)
    rot_matrix = np.concatenate([rot_matrix, [[0, 0, 0, 1]]], axis=0)
    return rot_matrix
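# A quick sanity check (illustrative only): zero angles and zero translation should give
# the 4x4 identity.
assert np.allclose(euler_to_matrix(0, 0, 0, 0, 0, 0), np.eye(4))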
def random_6dof_pose(B, T):
    pose_euler = torch.rand([B, T, 6]).numpy()  # Euler angles + translation
    pose_matrix = []
    for b in range(B):
        p = []
        for t in range(T):
            p.append(torch.from_numpy(euler_to_matrix(*pose_euler[b, t])))
        pose_matrix.append(torch.stack(p))
    pose_matrix = torch.stack(pose_matrix)
    return pose_matrix.float()
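# Illustrative check: the sampled poses should be valid homogeneous transforms, i.e. an
# orthonormal rotation block and a [0, 0, 0, 1] bottom row. `_demo_p` is an ad-hoc name.
_demo_p = random_6dof_pose(2, 3)  # [2, 3, 4, 4]
_demo_R = _demo_p[..., :3, :3]
assert torch.allclose(_demo_R @ _demo_R.transpose(-1, -2), torch.eye(3), atol=1e-5)
assert torch.allclose(_demo_p[..., 3, :], torch.tensor([0., 0., 0., 1.]))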
bs = 6  # batch size
t1 = 3  # number of target views per batch (can be arbitrary)
t2 = 5  # number of reference views per batch (can be arbitrary)
l = 10  # number of tokens per view
d = 16  # token feature dim; must be divisible by 4 for 6DoF CaPE
assert d % 4 == 0
# random init query and key
f1 = torch.rand(bs, t1, l, d) # query
f2 = torch.rand(bs, t2, l, d) # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')
# random init pose p1, p2, delta_p, [bs, t, 4, 4]
p1 = random_6dof_pose(bs, t1) # [bs, t1, 4, 4]
p2 = random_6dof_pose(bs, t2) # [bs, t2, 4, 4]
p_delta = random_6dof_pose(bs, 1) # [bs, 1, 4, 4]
# the same delta pose is applied to every view of p1 and p2 within each batch
p1_delta = einops.repeat(p_delta, 'b 1 m n -> b (1 t) m n', t=t1)
p2_delta = einops.repeat(p_delta, 'b 1 m n -> b (1 t) m n', t=t2)
# run attention with CaPE 6DoF
cape_6dof = CaPE_6DoF()
# att
att = cape_6dof.attn_with_CaPE(f1, f2, p1, p2)
# att_delta
att_delta = cape_6dof.attn_with_CaPE(f1, f2, p1@p1_delta, p2@p2_delta)
# condition: the attention scores must match, i.e. a shared delta_p has no effect
assert torch.allclose(att, att_delta, rtol=1e-3)
print("6DoF CaPE Verified")
########################## 4DoF CaPE ####################################
class CaPE_4DoF:
    def rotate_every_two(self, x):
        # rotate channel pairs: (x1, x2) -> (-x2, x1), as in rotary position embedding
        x = einops.rearrange(x, '... (d j) -> ... d j', j=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return einops.rearrange(x, '... d j -> ... (d j)')

    def cape(self, x, p):
        # tile the pose vector over the feature dim so each channel pair shares one angle
        d, l, n = x.shape[-1], p.shape[-2], p.shape[-1]
        assert d % (2 * n) == 0
        m = einops.repeat(p, 'b l n -> b l (n k)', k=d // n)
        return m

    def cape_embed(self, qq, kk, p1, p2):
        """
        Embed camera position encoding into query and key features.
        :param qq: query feature map [b, l_q, feature_dim]
        :param kk: key feature map [b, l_k, feature_dim]
        :param p1: query pose [b, l_q, pose_dim]
        :param p2: key pose [b, l_k, pose_dim]
        :return: pose-embedded query and key, [b, l_q, feature_dim] and [b, l_k, feature_dim]
        """
        assert p1.shape[-1] == p2.shape[-1]
        assert qq.shape[-1] == kk.shape[-1]
        assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0]
        assert p1.shape[1] == qq.shape[1]
        assert p2.shape[1] == kk.shape[1]
        m1 = self.cape(qq, p1)
        m2 = self.cape(kk, p2)
        q = (qq * m1.cos()) + (self.rotate_every_two(qq) * m1.sin())
        k = (kk * m2.cos()) + (self.rotate_every_two(kk) * m2.sin())
        return q, k

    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Attention dot product with CaPE pose encoding.
        # RoPE-style: each channel pair of q / k is rotated by its pose angles,
        # so q @ k^T depends only on the pose difference between query and key.
        :param f1: query features, [b, (t1 l), d]
        :param f2: key features, [b, (t2 l), d]
        :param p1: query poses, [b, t1, 4]
        :param p2: key poses, [b, t2, 4]
        :return: attention score q @ k^T, [b, (t1 l), (t2 l)]
        """
        l = f1.shape[1] // p1.shape[1]
        assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1]
        p1_reshape = einops.repeat(p1, 'b t m -> b (t l) m', l=l)  # [b, t1*l, 4]
        p2_reshape = einops.repeat(p2, 'b t m -> b (t l) m', l=l)  # [b, t2*l, 4]
        query, key = self.cape_embed(f1, f2, p1_reshape, p2_reshape)
        att = query @ key.permute(0, 2, 1)  # [b, t1*l, t2*l] attention: query @ key^T
        return att
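# Illustrative check of the 4DoF embedding, with ad-hoc `_demo_*` tensors: the cos/sin
# embedding rotates each channel pair by its pose angles, so it preserves the feature norm,
# and q @ k^T only depends on pose differences, so adding the same offset to both query and
# key poses leaves the scores unchanged.
_demo_cape4 = CaPE_4DoF()
_demo_f = torch.rand(1, 2, 8)             # [b, l, d] with d % (2 * pose_dim) == 0
_demo_p = torch.rand(1, 2, 4) * torch.pi  # [b, l, 4]
_demo_q, _demo_k = _demo_cape4.cape_embed(_demo_f, _demo_f, _demo_p, _demo_p)
assert torch.allclose(_demo_q.norm(dim=-1), _demo_f.norm(dim=-1), atol=1e-5)
_demo_c = torch.rand(1, 1, 4)             # shared offset, broadcast over tokens
_demo_q2, _demo_k2 = _demo_cape4.cape_embed(_demo_f, _demo_f, _demo_p + _demo_c, _demo_p + _demo_c)
assert torch.allclose(_demo_q @ _demo_k.transpose(-1, -2),
                      _demo_q2 @ _demo_k2.transpose(-1, -2), atol=1e-4)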
################### 4DoF Verification ###################################
def random_4dof_pose(B, T):
    # theta in [0, pi], azimuth in [0, 2*pi], radius in [0, pi], last entry 0 (angles in radians)
    pose = torch.zeros([B, T, 4])
    pose[:, :, :3] = torch.rand([B, T, 3])
    pose[:, :, 0] *= torch.pi
    pose[:, :, 1] *= (2 * torch.pi)
    pose[:, :, 2] *= torch.pi
    return pose.float()
def look_at(origin, target, up):
    # build 4x4 matrices whose rotation columns are (right, up, forward) and whose last
    # column carries the target point as the translation
    forward = (target - origin)
    forward = forward / torch.linalg.norm(forward, dim=-1, keepdim=True)
    right = torch.linalg.cross(forward, up)
    right = right / torch.linalg.norm(right, dim=-1, keepdim=True)
    new_up = torch.linalg.cross(forward, right)
    new_up = new_up / torch.linalg.norm(new_up, dim=-1, keepdim=True)
    rotation_matrix = torch.stack((right, new_up, forward, target), dim=-1)
    bottom_row = torch.tensor([[0., 0., 0., 1.]]).repeat(rotation_matrix.shape[0], rotation_matrix.shape[1], 1, 1)
    matrix = torch.cat([rotation_matrix, bottom_row], dim=-2)
    return matrix
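# Illustrative check with an ad-hoc target: the rotation block of look_at should be
# orthonormal, and the last column should equal the target point.
_demo_target = torch.tensor([[[1., 2., 3.]]])  # [1, 1, 3]
_demo_m = look_at(torch.zeros(1, 1, 3), _demo_target, torch.tensor([[[0., 0., 1.]]]))
_demo_rot = _demo_m[..., :3, :3]
assert torch.allclose(_demo_rot.transpose(-1, -2) @ _demo_rot, torch.eye(3), atol=1e-5)
assert torch.allclose(_demo_m[..., :3, 3], _demo_target)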
def pose_4dof2matrix(pose_4dof):
    """
    Convert (theta, azimuth, radius, 0) poses to 4x4 matrices via a look-at construction.
    :param pose_4dof: [b, t, 4]
    :return: pose 4x4 matrix: [b, t, 4, 4]
    """
    theta = pose_4dof[:, :, 0]
    azimuth = pose_4dof[:, :, 1]
    radius = pose_4dof[:, :, 2]
    xyz = torch.stack([torch.sin(theta) * torch.cos(azimuth),
                       torch.sin(theta) * torch.sin(azimuth),
                       torch.cos(theta)], dim=-1) * radius.unsqueeze(-1)
    origin = torch.zeros_like(xyz)
    up = torch.zeros_like(xyz)
    up[:, :, 2] = 1
    pose = look_at(origin, xyz, up)
    return pose
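# Illustrative check with ad-hoc `_demo_*` tensors: the translation column of the returned
# matrix should be the spherical point (sin(theta)cos(az), sin(theta)sin(az), cos(theta)) * radius.
_demo_p4 = random_4dof_pose(2, 3)       # [2, 3, 4]
_demo_mat = pose_4dof2matrix(_demo_p4)  # [2, 3, 4, 4]
_demo_xyz = torch.stack([torch.sin(_demo_p4[..., 0]) * torch.cos(_demo_p4[..., 1]),
                         torch.sin(_demo_p4[..., 0]) * torch.sin(_demo_p4[..., 1]),
                         torch.cos(_demo_p4[..., 0])], dim=-1) * _demo_p4[..., 2:3]
assert torch.allclose(_demo_mat[..., :3, 3], _demo_xyz, atol=1e-6)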
def pose_matrix24dof(pose_matrix):
    """
    :param pose_matrix: [b, t, 4, 4]
    :return: pose_4dof: [b, t, 4] (theta, azimuth, radius, 0), looking at the origin
    """
    xyz = pose_matrix[..., :3, 3]
    xy = xyz[..., 0] ** 2 + xyz[..., 1] ** 2
    radius = torch.sqrt(xy + xyz[..., 2] ** 2)
    theta = torch.arctan2(torch.sqrt(xy), xyz[..., 2])  # elevation angle measured from the z-axis
    azimuth = torch.arctan2(xyz[..., 1], xyz[..., 0])
    pose = torch.stack([theta, azimuth, radius, torch.zeros_like(radius)], dim=-1)
    # map angles to [0, 2*pi)
    pose %= (2 * torch.pi)
    return pose
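# Illustrative check with ad-hoc `_demo_*` tensors: the recovered radius should simply be
# the norm of the translation column of the pose matrix.
_demo_m4 = pose_4dof2matrix(random_4dof_pose(2, 3))
_demo_back = pose_matrix24dof(_demo_m4)
assert torch.allclose(_demo_back[..., 2], torch.linalg.norm(_demo_m4[..., :3, 3], dim=-1), atol=1e-5)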
bs = 6  # batch size
t1 = 3  # number of target views per batch (can be arbitrary)
t2 = 5  # number of reference views per batch (can be arbitrary)
l = 10  # number of tokens per view
d = 16  # token feature dim; must be divisible by 2 * pose_dim = 8 for 4DoF CaPE
# random init query and key
f1 = torch.rand(bs, t1, l, d) # query
f2 = torch.rand(bs, t2, l, d) # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')
# random init 4DoF poses [bs, t, 4]: theta, azimuth, radius, 0
p1 = random_4dof_pose(bs, t1) # [bs, t1, 4]
p2 = random_4dof_pose(bs, t2) # [bs, t2, 4]
p1_matrix = pose_4dof2matrix(p1)
p1_4dof = pose_matrix24dof(p1_matrix)
assert torch.allclose(p1, p1_4dof)
p_delta_4dof = random_4dof_pose(bs, 1)
# the same delta pose is applied to every view of p1 and p2 within each batch
p1_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b (1 t) m', t=t1)
p2_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b (1 t) m', t=t2)
# run attention with CaPE 4DoF
cape_4dof = CaPE_4DoF()
# att
att = cape_4dof.attn_with_CaPE(f1, f2, p1, p2)
# att_delta
att_delta = cape_4dof.attn_with_CaPE(f1, f2, p1+p1_delta_4dof, p2+p2_delta_4dof)
# condition: the attention scores must match, i.e. a shared delta_p has no effect
assert torch.allclose(att, att_delta, rtol=1e-3)
print("4DoF CaPE Verified")
# print("You should get assertion error because 4DoF CaPE cannot handle 6DoF jitter")
# # att_delta_6dof, it cannot handle 6dof jitter
# p_delta_6dof = random_6dof_pose(bs, 1) # [bs, 1, 4, 4] any delta transformation in 6DoF
# # delta p is identical to p1 and p2 in each batch
# p1_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t1//1)
# p2_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t2//1)
# # 4dof pose to 4x4 matrix
# p1_matrix = pose_4dof2matrix(p1)
# p2_matrix = pose_4dof2matrix(p2)
# att_delta_6dof = cape_4dof.attn_with_CaPE(f1, f2, pose_matrix24dof(p1_matrix@p1_delta_6dof), pose_matrix24dof(p2_matrix@p2_delta_6dof))
# # condition: att score should be the same i.e. non effect from any delta_p
# assert torch.allclose(att, att_delta_6dof, 1e-3)