| """ |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation |
| |
| Official implementation of the paper: |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis |
| Licensed under a modified MIT license |
| """ |
|
|
| |
| |
|
|
| |
| |
|
|
| import math |
| from typing import Any, Optional, Tuple |
|
|
| import numpy as np |
|
|
| import torch |
| from torch import nn |
|
|
| |
| |
| |
| |
|
|
|
|
def init_t_xy(end_x: int, end_y: int):
    """Return the (x, y) coordinates of every cell of a row-major grid.

    Cells are enumerated 0 .. end_x * end_y - 1. The x coordinate cycles
    fastest (index modulo ``end_x``) and the y coordinate is the floored
    quotient by ``end_x``.

    Args:
        end_x: Number of columns (size of the fast axis).
        end_y: Number of rows (size of the slow axis).

    Returns:
        ``(t_x, t_y)``: two 1-D float32 tensors of length ``end_x * end_y``.
    """
    flat_positions = torch.arange(end_x * end_y, dtype=torch.float32)
    # Column index inside the current row.
    x_coords = torch.remainder(flat_positions, end_x)
    # Row index; floor rounding keeps the float values integral.
    y_coords = torch.div(flat_positions, end_x, rounding_mode="floor")
    return x_coords.float(), y_coords.float()
|
|
|
|
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """Precompute the complex rotations for 2-D axial rotary encoding.

    Each grid cell produced by ``init_t_xy(end_x, end_y)`` gets:
    - ``dim // 4`` rotation angles from its x coordinate, then
    - ``dim // 4`` angles from its y coordinate.

    Both axes use the same inverse-frequency schedule
    ``theta ** -(k / dim)`` for ``k`` in ``0, 4, 8, ...``.

    Args:
        dim: Feature dimension being encoded. Presumably divisible by 4, so
            the result has ``dim // 2`` complex entries per position; confirm
            with callers.
        end_x: Grid width.
        end_y: Grid height.
        theta: Base of the frequency schedule.

    Returns:
        Complex tensor of shape ``(end_x * end_y, 2 * (dim // 4))`` holding
        unit-magnitude rotations. The x-axis half comes first, then the
        y-axis half.
    """
    # The x and y schedules are identical, so compute the inverse frequencies once.
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    t_x, t_y = init_t_xy(end_x, end_y)
    angles = [torch.outer(positions, inv_freq) for positions in (t_x, t_y)]
    rotations = [torch.polar(torch.ones_like(a), a) for a in angles]
    return torch.cat(rotations, dim=-1)
|
|
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must match the last two dimensions of ``x``. Every leading
    dimension of the returned view has size 1, so the result broadcasts
    over the leading (e.g. batch/head) axes of ``x``.

    Args:
        freqs_cis: Tensor of shape ``(x.shape[-2], x.shape[-1])``.
        x: Tensor with at least two dimensions.

    Returns:
        A view of ``freqs_cis`` with the same rank as ``x`` and shape
        ``(1, ..., 1, x.shape[-2], x.shape[-1])``.

    Raises:
        ValueError: if ``x`` has fewer than two dimensions, or ``freqs_cis``
            does not match the trailing two dimensions of ``x``.
    """
    ndim = x.ndim
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    # ``view`` would then silently accept a mis-shaped freqs_cis with the
    # same number of elements.
    if ndim < 2:
        raise ValueError(
            f"x must have at least 2 dimensions, got shape {tuple(x.shape)}"
        )
    expected = (x.shape[-2], x.shape[-1])
    if tuple(freqs_cis.shape) != expected:
        raise ValueError(
            f"freqs_cis shape {tuple(freqs_cis.shape)} does not match the "
            f"trailing dimensions {expected} of x"
        )
    shape = [1] * (ndim - 2) + list(expected)
    return freqs_cis.view(*shape)
|
|
|
|
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """Apply rotary position encoding to query and key tensors.

    The last dimension of ``xq``/``xk`` is read as interleaved (real, imag)
    pairs. Each pair is viewed as a complex number and multiplied by the
    matching complex rotation in ``freqs_cis``.

    Args:
        xq: Query tensor ``(..., L_q, D)`` with ``D`` even; any rank >= 2.
        xk: Key tensor ``(..., L_k, D)``. If ``L_k == 0`` it is returned
            unchanged. Without ``repeat_freqs_k`` it must broadcast with the
            query rotations, i.e. ``L_k == L_q``.
        freqs_cis: Complex tensor of shape ``(L_q, D // 2)``.
        repeat_freqs_k: If True, tile ``freqs_cis`` ``L_k // L_q`` times along
            the sequence axis so it covers longer keys. ``L_k`` is expected
            to be a multiple of ``L_q``.

    Returns:
        ``(xq_out, xk_out)``. Each output has the shape and dtype of its
        input. ``xk_out`` is ``xk`` itself when there are no keys.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = (
        torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        if xk.shape[-2] != 0
        else None
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # flatten(-2) folds the trailing (D // 2, 2) real/imag pair back into D for
    # any input rank. (flatten(3) only did so for 4-D inputs.)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    if xk_ is None:
        # Empty key sequence: nothing to rotate, pass the keys through.
        return xq_out.type_as(xq).to(xq.device), xk

    if repeat_freqs_k:
        r = xk_.shape[-2] // xq_.shape[-2]
        lead_dims = freqs_cis.ndim - 2
        if freqs_cis.is_cuda:
            freqs_cis = freqs_cis.repeat(*([1] * lead_dims), r, 1)
        else:
            # Equivalent tiling via expand + flatten: repetitions outer, positions
            # inner, same order as ``repeat``.
            # NOTE(review): this branch presumably exists because ``repeat`` on
            # complex tensors is not supported on every non-CUDA backend; confirm.
            freqs_cis = (
                freqs_cis.unsqueeze(-3)
                .expand(*([-1] * lead_dims), r, -1, -1)
                .flatten(-3, -2)
            )
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
|