|
|
import torch
|
|
|
from einops import rearrange
|
|
|
|
|
|
def generate_angles_2d(H, W, D, freq=None):
    """Build a grid of rotary-embedding angles for an H x W feature map.

    Each grid cell (i, j) gets the scalar position ``y_i * x_j``, where ``y``
    and ``x`` are ``H`` and ``W`` points evenly spaced over [-1, 1]. That
    scalar is multiplied by every entry of ``freq`` to give the angles.

    Args:
        H: Number of grid rows.
        W: Number of grid columns.
        D: Feature size the angles are meant for. Only used to build the
            default ``freq``: ``D // 2`` values ``10000 ** (-2 * i / D)``.
        freq: Optional 1-D frequencies (tensor or sequence). Non-floating
            input is converted to the default float dtype. The grid is
            created on ``freq``'s device and in its dtype.

    Returns:
        Tensor of shape ``(H, W, len(freq))`` (``(H, W, D // 2)`` with the
        default frequencies) holding ``y_i * x_j * freq[k]``.
    """
    if freq is None:
        # Classic RoPE spectrum, computed in float64 and then cast to the
        # default dtype (what torch.tensor does with a list of Python floats).
        exponents = -2 * torch.arange(D // 2, dtype=torch.float64) / D
        freq = torch.pow(10000.0, exponents).to(torch.get_default_dtype())
    else:
        freq = torch.as_tensor(freq)
        if not freq.is_floating_point():
            freq = freq.to(torch.get_default_dtype())

    # Build the grid where ``freq`` lives, so a CUDA or float64 ``freq`` is
    # never combined with a CPU / float32 grid.
    ys = torch.linspace(-1, 1, steps=H, dtype=freq.dtype, device=freq.device)
    xs = torch.linspace(-1, 1, steps=W, dtype=freq.dtype, device=freq.device)
    # NOTE(review): the product y * x is not a unique position code. Cells
    # (i, j) and (H-1-i, W-1-j) get the same value, and so do (i, j) and
    # (j, i) when H == W. Kept as-is for compatibility with existing callers.
    pos = torch.outer(ys, xs)

    # Broadcast (H, W, 1) * (K,) -> (H, W, K); same values as an outer einsum.
    return pos.unsqueeze(-1) * freq
|
|
|
|
|
|
def apply_angles_2d(x, f):
    """Rotate consecutive feature pairs of ``x`` by the angles in ``f``.

    The last dimension of ``x`` is read as interleaved pairs
    ``(x[..., 2k], x[..., 2k + 1])``. Each pair is treated as the complex
    number ``re + i*im`` and multiplied by ``exp(i * f[..., k])``, i.e.
    rotated in its 2-D plane.

    Args:
        x: Tensor whose last dimension has even size ``2 * K``. Any number of
            leading dimensions is accepted, e.g. ``(B, heads, H, W, 2K)``.
        f: Angles broadcastable to ``x.shape[:-1] + (K,)``, e.g. an
            ``(H, W, K)`` grid of angles.

    Returns:
        Tensor of the broadcast shape (``x.shape`` when ``f`` only covers
        trailing dims) with the rotated pairs re-interleaved in the original
        order. The dtype follows PyTorch type promotion between ``x`` and ``f``.

    Raises:
        RuntimeError: if the last dimension of ``x`` is odd.
    """
    # Explicit pair count (not -1) so tensors with zero elements reshape too.
    half = x.shape[-1] // 2
    pairs = x.reshape(*x.shape[:-1], half, 2)
    real, imag = pairs[..., 0], pairs[..., 1]
    cos, sin = f.cos(), f.sin()

    rot_real = real * cos - imag * sin
    rot_imag = real * sin + imag * cos

    # Stack to (..., K, 2), then flatten the last two dims so each rotated
    # pair lands back at positions (2k, 2k + 1).
    return torch.stack((rot_real, rot_imag), dim=-1).flatten(-2)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: rotate a (1, 8, 64, 64, 768) tensor with a 64x64 angle grid
    # and print the resulting shape. Guarded so that importing this module
    # does not allocate ~100 MB and print as a side effect.
    print(apply_angles_2d(torch.randn(1, 8, 64, 64, 768), generate_angles_2d(64, 64, 768)).shape)