VisionLanguageGroup's picture
clean up
86072ea
"""Transformer class."""
import logging
import math
from collections import OrderedDict
from pathlib import Path
from typing import Literal, Tuple
import torch
import torch.nn.functional as F
import yaml
from torch import nn
import sys, os
from .utils import blockwise_causal_norm
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _pos_embed_fourier1d_init(
cutoff: float = 256, n: int = 32, cutoff_start: float = 1
):
return (
torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n))
.unsqueeze(0)
.unsqueeze(0)
)
def _rope_pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
# Maximum initial frequency is 1
return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate pairs of scalars as 2d vectors by pi/2."""
x = x.unflatten(-1, (-1, 2))
x1, x2 = x.unbind(dim=-1)
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
class RotaryPositionalEncoding(nn.Module):
    """Rotary positional encoding (RoPE) over several coordinate axes.

    Coordinate axis ``i`` owns ``n_pos[i]`` consecutive channels of the
    per-head feature dimension. Those channels are rotated pairwise by angles
    ``0.5 * pi * coord_i * freq_i``. ``freq_i`` is a learnable vector of
    ``n_pos[i] // 2`` frequencies, initialised from ``cutoffs[i]``.
    """

    def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
        super().__init__()
        assert len(cutoffs) == len(n_pos)
        if any(n % 2 != 0 for n in n_pos):
            raise ValueError("n_pos must be even")
        self._n_dim = len(cutoffs)
        per_axis_freqs = [
            nn.Parameter(_rope_pos_embed_fourier1d_init(c, n // 2))
            for c, n in zip(cutoffs, n_pos)
        ]
        self.freqs = nn.ParameterList(per_axis_freqs)

    def get_co_si(self, coords: torch.Tensor):
        """Return cosines and sines of the rotation angles.

        Args:
            coords: tensor of shape ``(B, N, D)`` with ``D == len(self.freqs)``.

        Returns:
            ``(co, si)``, each of shape ``(B, N, sum(n_pos) // 2)``. Each is the
            per-axis results concatenated in axis order.
        """
        _b, _n, d = coords.shape
        assert d == len(self.freqs)
        cos_parts = []
        sin_parts = []
        for axis_coord, freq in zip(coords.moveaxis(-1, 0), self.freqs):
            angle = 0.5 * math.pi * axis_coord.unsqueeze(-1) * freq
            # NOTE(review): freq has shape (1, 1, k), so len(freq) == 1 and this
            # normalisation is a no-op. Kept as-is to preserve trained behaviour.
            norm = math.sqrt(len(freq))
            cos_parts.append(torch.cos(angle) / norm)
            sin_parts.append(torch.sin(angle) / norm)
        return torch.cat(cos_parts, dim=-1), torch.cat(sin_parts, dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
        """Rotate queries and keys by coordinate-dependent angles.

        Args:
            q: tensor of shape ``(B, H, N, C)`` with ``C == sum(n_pos)``.
            k: tensor with the same shape as ``q``.
            coords: tensor of shape ``(B, N, D)`` with ``D`` equal to the number of axes.

        Returns:
            The rotated ``(q, k)``, with unchanged shapes.

        Raises:
            ValueError: if ``coords`` has the wrong number of axes.
        """
        _b, _n, d = coords.shape
        _b, _h, _n, _c = q.shape
        if d != self._n_dim:
            raise ValueError(f"coords must have {self._n_dim} dimensions, got {d}")
        cos_half, sin_half = self.get_co_si(coords)
        # (B, N, C/2) -> (B, 1, N, C): broadcast over heads. Both members of a
        # channel pair share the same angle.
        cos_full = cos_half.unsqueeze(1).repeat_interleave(2, dim=-1)
        sin_full = sin_half.unsqueeze(1).repeat_interleave(2, dim=-1)
        q_rot = q * cos_full + _rotate_half(q) * sin_full
        k_rot = k * cos_full + _rotate_half(k) * sin_full
        return q_rot, k_rot
class FeedForward(nn.Module):
    """Two-layer GELU MLP: ``d_model -> int(d_model * expand) -> d_model``.

    ``bias`` only controls the output layer ``fc2``. The hidden layer ``fc1``
    always has a bias.
    """

    def __init__(self, d_model, expand: float = 2, bias: bool = True):
        super().__init__()
        hidden_dim = int(d_model * expand)
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model, bias=bias)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of continuous coordinates with learnable frequencies.

    Each coordinate dimension ``i`` gets a learnable frequency vector of
    ``n_pos[i] // 2`` entries. The entries are initialised log-spaced from
    ``1 / cutoffs_start[i]`` down to ``1 / cutoffs[i]``. Each dimension
    contributes its ``sin`` features followed by its ``cos`` features, i.e.
    ``2 * (n_pos[i] // 2)`` channels (an odd ``n_pos[i]`` yields one channel
    fewer).

    Args:
        cutoffs: per dimension, the lowest initial frequency is ``1 / cutoffs[i]``.
        n_pos: number of output channels per dimension.
        cutoffs_start: per dimension, the highest initial frequency is
            ``1 / cutoffs_start[i]``. Defaults to 1 for every dimension.
    """

    def __init__(
        self,
        cutoffs: Tuple[float] = (256,),
        n_pos: Tuple[int] = (32,),
        cutoffs_start=None,
    ):
        super().__init__()
        if cutoffs_start is None:
            cutoffs_start = (1,) * len(cutoffs)
        # zip() below would silently drop dimensions on a length mismatch.
        assert len(cutoffs) == len(n_pos) == len(cutoffs_start)
        self.freqs = nn.ParameterList([
            # cutoff_start must be forwarded, otherwise ``cutoffs_start`` is
            # ignored and every dimension starts at frequency 1.
            nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2, cutoff_start))
            for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
        ])

    def forward(self, coords: torch.Tensor):
        """Embed coordinates.

        Args:
            coords: tensor of shape ``(B, N, D)`` with ``D == len(cutoffs)``.

        Returns:
            Tensor of shape ``(B, N, sum_i 2 * (n_pos[i] // 2))``.
        """
        _B, _N, D = coords.shape
        assert D == len(self.freqs)
        parts = []
        for x, freq in zip(coords.moveaxis(-1, 0), self.freqs):
            phase = 0.5 * math.pi * x.unsqueeze(-1) * freq
            # NOTE(review): freq has shape (1, 1, k), so len(freq) == 1 and this
            # division is a no-op. sqrt(k) may have been intended, but changing
            # it would alter the outputs of trained models.
            parts.append(
                torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
                / math.sqrt(len(freq))
            )
        return torch.cat(parts, dim=-1)
def _bin_init_exp(cutoff: float, n: int):
return torch.exp(torch.linspace(0, math.log(cutoff + 1), n))
def _bin_init_linear(cutoff: float, n: int):
return torch.linspace(-cutoff, cutoff, n)
class RelativePositionalBias(nn.Module):
    """Learned per-head additive attention bias indexed by binned relative position.

    For every token pair:
    - The spatial distance (Euclidean, over ``coords[..., 1:]``) is bucketized
      into ``n_spatial`` exponentially spaced bins.
    - The temporal offset ``t_i - t_j`` (from ``coords[..., 0]``) is bucketized
      into ``2 * n_temporal + 1`` linear bins over ``[-cutoff_temporal, cutoff_temporal]``.

    Values past the last boundary are clamped into the last bin. Each
    (temporal, spatial) bin pair owns one learnable value per head.
    """

    def __init__(
        self,
        n_head: int,
        cutoff_spatial: float,
        cutoff_temporal: float,
        n_spatial: int = 32,
        n_temporal: int = 16,
    ):
        super().__init__()
        n_temporal_bins = 2 * n_temporal + 1
        self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
        self._temporal_bins = _bin_init_linear(cutoff_temporal, n_temporal_bins)
        # Buffers follow the module across devices; forward() only uses these.
        self.register_buffer("spatial_bins", self._spatial_bins)
        self.register_buffer("temporal_bins", self._temporal_bins)
        self.n_spatial = n_spatial
        self.n_head = n_head
        # Bias table: row = temporal_bin * n_spatial + spatial_bin, column = head,
        # initialised uniformly in [-0.5, 0.5).
        self.bias = nn.Parameter(
            torch.rand(n_temporal_bins * n_spatial, n_head) - 0.5
        )

    def forward(self, coords: torch.Tensor):
        """Look up the bias for every token pair.

        Args:
            coords: tensor of shape ``(B, N, D)``; column 0 is time, the rest are spatial.

        Returns:
            Tensor of shape ``(B, n_head, N, N)``.
        """
        _B, _N, _D = coords.shape
        times = coords[..., 0]
        positions = coords[..., 1:]
        dt = times.unsqueeze(-1) - times.unsqueeze(-2)
        dr = torch.cdist(positions, positions)
        last_spatial = len(self.spatial_bins) - 1
        last_temporal = len(self.temporal_bins) - 1
        s_idx = torch.bucketize(dr, self.spatial_bins).clamp(max=last_spatial)
        t_idx = torch.bucketize(dt, self.temporal_bins).clamp(max=last_temporal)
        rows = t_idx.flatten() * self.n_spatial + s_idx.flatten()
        table = self.bias.index_select(0, rows).view(*s_idx.shape, self.n_head)
        # NOTE: transpose(-1, 1) on (B, N, N, H) moves the head axis to dim 1.
        # It also swaps the two token axes, so out[b, h, i, j] holds the bias
        # for the pair (j, i).
        return table.transpose(-1, 1)
class RelativePositionalAttention(nn.Module):
    """Multi-head attention whose additive mask encodes token geometry.

    Each token carries a coordinate row ``coords[b, i] = (t, *spatial)``.
    When ``coords`` is given, the float attention mask is built from:

    - A spatial cutoff: pairs whose spatial Euclidean distance exceeds
      ``cutoff_spatial`` receive ``-1e3``.
    - A positional term chosen by ``mode``:
      - ``"bias"`` (or the legacy alias ``True``) adds a learned
        ``RelativePositionalBias``.
      - ``"rope"`` rotates ``q``/``k`` with a ``RotaryPositionalEncoding``.
      - ``"none"``, ``None`` and ``False`` add nothing.
    - A distance term chosen by ``attn_dist_mode``:
      - ``"v0"`` adds ``exp(-0.1 * ||c_i - c_j||)`` over all coordinate columns.
      - ``"v1"`` adds ``exp(-5 * d_spatial / cutoff_spatial)``.

    With ``coords=None`` all three terms are skipped; only ``padding_mask``
    is applied.
    """

    def __init__(
        self,
        coord_dim: int,
        embed_dim: int,
        n_head: int,
        cutoff_spatial: float = 256,
        cutoff_temporal: float = 16,
        n_spatial: int = 32,
        n_temporal: int = 16,
        dropout: float = 0.0,
        mode: Literal["bias", "rope", "none"] = "bias",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        if not embed_dim % (2 * n_head) == 0:
            raise ValueError(
                f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
            )
        self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout
        self.n_head = n_head
        self.embed_dim = embed_dim
        self.cutoff_spatial = cutoff_spatial
        self.attn_dist_mode = attn_dist_mode
        if mode == "bias" or mode is True:
            self.pos_bias = RelativePositionalBias(
                n_head=n_head,
                cutoff_spatial=cutoff_spatial,
                cutoff_temporal=cutoff_temporal,
                n_spatial=n_spatial,
                n_temporal=n_temporal,
            )
        elif mode == "rope":
            # Each spatial axis gets n_split channels per head; the temporal
            # axis gets the remainder of the head dimension.
            n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
            self.rot_pos_enc = RotaryPositionalEncoding(
                cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
                n_pos=(embed_dim // n_head - coord_dim * n_split,)
                + (n_split,) * coord_dim,
            )
        elif mode == "none":
            pass
        elif mode is None or mode is False:
            logger.warning(
                "attn_positional_bias is not set (None or False), no positional bias."
            )
        else:
            raise ValueError(f"Unknown mode {mode}")
        # Normalise the alias True to "bias" so forward() actually applies the
        # pos_bias module constructed above.
        self._mode = "bias" if mode is True else mode

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape ``(B, N, D)``.
            key: must also have ``N`` tokens, because it is reshaped with
                ``query``'s ``N``.
            value: must also have ``N`` tokens, for the same reason.
            coords: ``(B, N, 1 + n_spatial)`` tensor, or ``None`` to disable
                all geometry terms.
            padding_mask: optional boolean ``(B, N)``; ``True`` marks tokens
                whose rows and columns are masked out.

        Returns:
            Tensor of shape ``(B, N, D)``.

        Raises:
            ValueError: unknown ``attn_dist_mode`` (only checked when ``coords`` is given).
        """
        B, N, D = query.size()
        q = self.q_pro(query)
        k = self.k_pro(key)
        v = self.v_pro(value)
        k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        attn_mask = torch.zeros(
            (B, self.n_head, N, N), device=query.device, dtype=q.dtype
        )
        # Finite (not -inf) so that softmax stays well-defined even for rows
        # where every key is masked.
        attn_ignore_val = -1e3
        if coords is not None:
            # Spatial cutoff. This needs coords, so it must live inside the
            # guard; coords=None previously crashed here.
            yx = coords[..., 1:]
            spatial_dist = torch.cdist(yx, yx)
            spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
            attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
            if self._mode == "bias":
                attn_mask = attn_mask + self.pos_bias(coords)
            elif self._mode == "rope":
                q, k = self.rot_pos_enc(q, k, coords)
            if self.attn_dist_mode == "v0":
                dist = torch.cdist(coords, coords, p=2)
                attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
            elif self.attn_dist_mode == "v1":
                attn_mask += torch.exp(
                    -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
                )
            else:
                raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
        if padding_mask is not None:
            # A pair is ignored if either of its tokens is padding.
            ignore_mask = torch.logical_or(
                padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
            ).unsqueeze(1)
            attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
        )
        y = y.transpose(1, 2).contiguous().view(B, N, D)
        y = self.proj(y)
        return y
class EncoderLayer(nn.Module):
    """Self-attention encoder layer with geometry-aware attention and an MLP.

    Note: ``forward`` reassigns the input to ``norm1(x)`` before the residual
    connection, so the residual is added to the *normalised* input rather
    than to the raw input.
    """

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        cutoff_spatial: int = 256,
        window: int = 16,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        # ``window`` serves as both the temporal cutoff and the temporal bin count.
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        # A falsy positional_bias hands coords=None to the attention block,
        # which is intended to disable the positional terms.
        attn_coords = coords if self.positional_bias else None
        normed = self.norm1(x)
        out = normed + self.attn(
            normed,
            normed,
            normed,
            coords=attn_coords,
            padding_mask=padding_mask,
        )
        return out + self.mlp(self.norm2(out))
class DecoderLayer(nn.Module):
    """Cross-attention decoder layer.

    Queries come from ``x``; keys and values come from ``y``. Both are
    layer-normalised first, and the residual is added to the normalised
    ``x``. An MLP follows.
    """

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        window: int = 16,
        cutoff_spatial: int = 256,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        # ``window`` serves as both the temporal cutoff and the temporal bin count.
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        # A falsy positional_bias hands coords=None to the attention block,
        # which is intended to disable the positional terms.
        attn_coords = coords if self.positional_bias else None
        queries = self.norm1(x)
        memory = self.norm2(y)
        out = queries + self.attn(
            queries,
            memory,
            memory,
            coords=attn_coords,
            padding_mask=padding_mask,
        )
        return out + self.mlp(self.norm3(out))
class TrackingTransformer(torch.nn.Module):
    """Encoder-decoder transformer producing pairwise association logits.

    Each of the ``N`` input objects is described by a coordinate row
    ``(t, *spatial)`` and, optionally, ``feat_dim`` extra features.
    ``forward`` returns a ``(B, N, N)`` logit matrix, and
    ``normalize_output`` turns it into scores. The constructor arguments are
    kept in ``self.config`` so that ``save`` / ``from_folder`` can rebuild
    the model.
    """

    def __init__(
        self,
        coord_dim: int = 3,
        feat_dim: int = 0,
        d_model: int = 128,
        nhead: int = 4,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 4,
        dropout: float = 0.1,
        pos_embed_per_dim: int = 32,
        feat_embed_per_dim: int = 1,
        window: int = 6,
        spatial_pos_cutoff: int = 256,
        attn_positional_bias: Literal["bias", "rope", "none"] = "rope",
        attn_positional_bias_n_spatial: int = 16,
        causal_norm: Literal[
            "none", "linear", "softmax", "quiet_softmax"
        ] = "quiet_softmax",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        # The config dict is:
        # - persisted by save(),
        # - passed back as cls(**config) by from_folder() / from_cfg(),
        # - read by normalize_output().
        self.config = dict(
            coord_dim=coord_dim,
            feat_dim=feat_dim,
            pos_embed_per_dim=pos_embed_per_dim,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            window=window,
            dropout=dropout,
            attn_positional_bias=attn_positional_bias,
            attn_positional_bias_n_spatial=attn_positional_bias_n_spatial,
            spatial_pos_cutoff=spatial_pos_cutoff,
            feat_embed_per_dim=feat_embed_per_dim,
            causal_norm=causal_norm,
            attn_dist_mode=attn_dist_mode,
        )
        # Input projection input = positional embedding of (t, *spatial)
        # concatenated with the embedded features.
        self.proj = nn.Linear(
            (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model
        )
        self.norm = nn.LayerNorm(d_model)
        self.encoder = nn.ModuleList([
            EncoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_encoder_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_decoder_layers)
        ])
        self.head_x = FeedForward(d_model)
        self.head_y = FeedForward(d_model)
        if feat_embed_per_dim > 1:
            self.feat_embed = PositionalEncoding(
                cutoffs=(1000,) * feat_dim,
                n_pos=(feat_embed_per_dim,) * feat_dim,
                cutoffs_start=(0.01,) * feat_dim,
            )
        else:
            self.feat_embed = nn.Identity()
        self.pos_embed = PositionalEncoding(
            cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
            n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
        )

    def forward(self, coords, features=None, padding_mask=None, attn_feat=None):
        """Compute association logits between all pairs of objects.

        Args:
            coords: tensor of shape ``(B, N, 3)`` or ``(B, N, 4)``; column 0 is time.
            features: optional ``(B, N, feat_dim)`` tensor. ``None`` or an
                empty tensor means positional input only.
            padding_mask: optional boolean ``(B, N)``; ``True`` marks padded objects.
            attn_feat: optional tensor added to the projected features. It must
                broadcast to ``(B, N, d_model)``.

        Returns:
            Tensor ``A`` of shape ``(B, N, N)`` with ``A[b, i, j] = <head_x(enc_i), head_y(dec_j)>``.
        """
        assert coords.ndim == 3 and coords.shape[-1] in (3, 4)
        _B, _N, _D = coords.shape
        # Overwrite padded coords with the global max so they cannot become
        # the minimum time below.
        if padding_mask is not None:
            coords = coords.clone()
            coords[padding_mask] = coords.max()
        # Remove the temporal offset so that each batch element starts at t = 0.
        min_time = coords[:, :, :1].min(dim=1, keepdims=True).values
        coords = coords - min_time
        pos = self.pos_embed(coords)
        if features is None or features.numel() == 0:
            features = pos
        else:
            features = self.feat_embed(features)
            features = torch.cat((pos, features), axis=-1)
        features = self.proj(features)
        if attn_feat is not None:
            features = features + attn_feat
        features = self.norm(features)
        x = features
        for enc in self.encoder:
            x = enc(x, coords=coords, padding_mask=padding_mask)
        # The decoder starts from the same embeddings and cross-attends to the
        # encoder output.
        y = features
        for dec in self.decoder:
            y = dec(y, x, coords=coords, padding_mask=padding_mask)
        x = self.head_x(x)
        y = self.head_y(y)
        # The pairwise dot product of the two heads is the association matrix (logits).
        A = torch.einsum("bnd,bmd->bnm", x, y)
        return A

    def normalize_output(
        self,
        A: torch.FloatTensor,
        timepoints: torch.LongTensor,
        coords: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply (parental) softmax, or elementwise sigmoid.

        If ``causal_norm == "none"``, this applies an elementwise sigmoid and
        zeroes pairs whose spatial distance exceeds ``spatial_pos_cutoff``.
        Otherwise, each batch element is passed to ``blockwise_causal_norm``
        with that distance mask as ``mask_invalid``.

        Args:
            A: Tensor of shape B, N, N
            timepoints: Tensor of shape B, N
            coords: Tensor of shape B, N, (time + n_spatial)
        """
        assert A.ndim == 3
        assert timepoints.ndim == 2
        assert coords.ndim == 3
        assert coords.shape[2] == 1 + self.config["coord_dim"]
        # Spatial distances between all pairs.
        dist = torch.cdist(coords[:, :, 1:], coords[:, :, 1:])
        invalid = dist > self.config["spatial_pos_cutoff"]
        if self.config["causal_norm"] == "none":
            # sigmoid returns a new tensor, so the caller's A is not modified here.
            A = torch.sigmoid(A)
            A[invalid] = 0
        else:
            return torch.stack([
                blockwise_causal_norm(
                    _A, _t, mode=self.config["causal_norm"], mask_invalid=_m
                )
                for _A, _t, _m in zip(A, timepoints, invalid)
            ])
        return A

    def save(self, folder):
        """Write ``config.yaml`` and ``model.pt`` (the plain state_dict) into ``folder``."""
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)
        # Context manager guarantees the config is flushed and the handle closed.
        with open(folder / "config.yaml", "w") as f:
            yaml.safe_dump(self.config, f)
        torch.save(self.state_dict(), folder / "model.pt")

    @classmethod
    def from_folder(
        cls, folder, map_location=None, checkpoint_path: str = "model.pt"
    ):
        """Rebuild a model from ``folder/config.yaml`` and load weights.

        ``checkpoint_path`` may hold either:
        - a plain state_dict, as written by ``save``, or
        - a checkpoint dict with a ``"state_dict"`` entry whose keys are
          prefixed with ``"model."``. Only those keys are kept, with the
          prefix stripped.
        """
        folder = Path(folder)
        with open(folder / "config.yaml") as f:
            # NOTE(review): FullLoader can build some Python objects; only load
            # trusted configs. save() writes plain YAML via safe_dump.
            config = yaml.load(f, Loader=yaml.FullLoader)
        model = cls(**config)
        fpath = folder / checkpoint_path
        logger.info(f"Loading model state from {fpath}")
        state = torch.load(fpath, map_location=map_location, weights_only=True)
        # If state is a checkpoint, we have to extract the state_dict and
        # strip the "model." prefix (6 characters).
        if "state_dict" in state:
            state = state["state_dict"]
            state = OrderedDict(
                (k[6:], v) for k, v in state.items() if k.startswith("model.")
            )
        model.load_state_dict(state)
        return model

    @classmethod
    def from_cfg(
        cls, cfg_path
    ):
        """Build a freshly initialised model from a YAML config file."""
        cfg_path = Path(cfg_path)
        with open(cfg_path) as f:
            # NOTE(review): FullLoader — only load trusted configs.
            config = yaml.load(f, Loader=yaml.FullLoader)
        model = cls(**config)
        return model