# NOTE: removed stray web-page scrape artifact ("Spaces: Running on Zero")
# that was not valid Python.
| """Transformer class.""" | |
| import logging | |
| import math | |
| from collections import OrderedDict | |
| from pathlib import Path | |
| from typing import Literal, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| import yaml | |
| from torch import nn | |
| import sys, os | |
| from .utils import blockwise_causal_norm | |
| logger = logging.getLogger(__name__) | |
| def _pos_embed_fourier1d_init( | |
| cutoff: float = 256, n: int = 32, cutoff_start: float = 1 | |
| ): | |
| return ( | |
| torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n)) | |
| .unsqueeze(0) | |
| .unsqueeze(0) | |
| ) | |
| def _rope_pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32): | |
| # Maximum initial frequency is 1 | |
| return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0) | |
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |
| """Rotate pairs of scalars as 2d vectors by pi/2.""" | |
| x = x.unflatten(-1, (-1, 2)) | |
| x1, x2 = x.unbind(dim=-1) | |
| return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) | |
class RotaryPositionalEncoding(nn.Module):
    """Rotary positional encoding (RoPE) over several coordinate axes.

    Each axis owns a learnable bank of Fourier frequencies; query/key channel
    pairs are rotated by an angle proportional to the coordinate value.
    """

    def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
        super().__init__()
        assert len(cutoffs) == len(n_pos)
        if any(n % 2 != 0 for n in n_pos):
            raise ValueError("n_pos must be even")
        self._n_dim = len(cutoffs)
        # one (1, 1, n//2) frequency bank per coordinate axis
        self.freqs = nn.ParameterList([
            nn.Parameter(_rope_pos_embed_fourier1d_init(cutoff, n // 2))
            for cutoff, n in zip(cutoffs, n_pos)
        ])

    def get_co_si(self, coords: torch.Tensor):
        """Cos/sin tables for coords (B, N, D), concatenated over all axes."""
        _B, _N, D = coords.shape
        assert D == len(self.freqs)
        # NOTE(review): len(freq) is the size of dim 0 (always 1), so the
        # 1/sqrt(len(freq)) factor is a no-op; freq.shape[-1] may have been
        # intended — confirm before changing.
        phases = [
            (0.5 * math.pi * axis.unsqueeze(-1) * freq, math.sqrt(len(freq)))
            for axis, freq in zip(coords.moveaxis(-1, 0), self.freqs)
        ]
        co = torch.cat(tuple(p.cos() / s for p, s in phases), axis=-1)
        si = torch.cat(tuple(p.sin() / s for p, s in phases), axis=-1)
        return co, si

    def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
        """Rotate q and k (B, H, N, C) according to coords (B, N, D)."""
        _B, _N, D = coords.shape
        _B, _H, _N, _C = q.shape
        if D != self._n_dim:
            raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
        co, si = self.get_co_si(coords)
        # broadcast over heads and duplicate channels to match (even, odd) pairs
        co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
        si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
        q_rot = q * co + _rotate_half(q) * si
        k_rot = k * co + _rotate_half(k) * si
        return q_rot, k_rot
class FeedForward(nn.Module):
    """Two-layer MLP with GELU: d_model -> int(d_model * expand) -> d_model."""

    def __init__(self, d_model, expand: float = 2, bias: bool = True):
        super().__init__()
        hidden = int(d_model * expand)
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model, bias=bias)
        self.act = nn.GELU()

    def forward(self, x):
        """Apply fc1 -> GELU -> fc2."""
        h = self.act(self.fc1(x))
        return self.fc2(h)
class PositionalEncoding(nn.Module):
    """Fourier-feature positional encoding for continuous coordinates.

    Each coordinate axis d gets n_pos[d] // 2 geometrically spaced frequencies
    between 1/cutoffs_start[d] and 1/cutoffs[d]; the embedding concatenates
    sin and cos of the scaled coordinates over all axes.
    """

    def __init__(
        self,
        cutoffs: Tuple[float] = (256,),
        n_pos: Tuple[int] = (32,),
        cutoffs_start=None,
    ):
        super().__init__()
        if cutoffs_start is None:
            cutoffs_start = (1,) * len(cutoffs)
        assert len(cutoffs) == len(n_pos)
        # Fix: forward cutoff_start to the initializer. It was previously
        # zipped but never passed, so cutoffs_start was silently ignored
        # (e.g. the feature embedding's cutoffs_start=(0.01,) had no effect).
        self.freqs = nn.ParameterList([
            nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2, cutoff_start))
            for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
        ])

    def forward(self, coords: torch.Tensor):
        """Embed coords of shape (B, N, D) into (B, N, sum(n_pos))."""
        _B, _N, D = coords.shape
        assert D == len(self.freqs)
        embed = torch.cat(
            tuple(
                torch.cat(
                    (
                        torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq),
                        torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq),
                    ),
                    axis=-1,
                )
                # NOTE(review): len(freq) is the dim-0 size (always 1), so this
                # normalization is a no-op; freq.shape[-1] may have been meant.
                / math.sqrt(len(freq))
                for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
            ),
            axis=-1,
        )
        return embed
| def _bin_init_exp(cutoff: float, n: int): | |
| return torch.exp(torch.linspace(0, math.log(cutoff + 1), n)) | |
| def _bin_init_linear(cutoff: float, n: int): | |
| return torch.linspace(-cutoff, cutoff, n) | |
class RelativePositionalBias(nn.Module):
    """Learned attention bias indexed by binned spatio-temporal distance.

    Pairwise spatial distances are bucketed into exponentially spaced bins and
    signed temporal offsets into linear bins; every (temporal, spatial) bin
    pair owns one learnable bias value per attention head.
    """

    def __init__(
        self,
        n_head: int,
        cutoff_spatial: float,
        cutoff_temporal: float,
        n_spatial: int = 32,
        n_temporal: int = 16,
    ):
        super().__init__()
        self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
        self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1)
        # buffers so the bin edges move with the module across devices
        self.register_buffer("spatial_bins", self._spatial_bins)
        self.register_buffer("temporal_bins", self._temporal_bins)
        self.n_spatial = n_spatial
        self.n_head = n_head
        n_bins = (2 * n_temporal + 1) * n_spatial
        # uniform init in [-0.5, 0.5)
        self.bias = nn.Parameter(-0.5 + torch.rand(n_bins, n_head))

    def forward(self, coords: torch.Tensor):
        """coords: (B, N, 1 + spatial), time in channel 0 -> bias (B, n_head, N, N)."""
        _B, _N, _D = coords.shape
        t = coords[..., 0]
        yx = coords[..., 1:]
        # signed pairwise time offsets and euclidean spatial distances
        temporal_dist = t.unsqueeze(-1) - t.unsqueeze(-2)
        spatial_dist = torch.cdist(yx, yx)
        spatial_idx = torch.bucketize(spatial_dist, self.spatial_bins).clamp_(
            max=len(self.spatial_bins) - 1
        )
        temporal_idx = torch.bucketize(temporal_dist, self.temporal_bins).clamp_(
            max=len(self.temporal_bins) - 1
        )
        # flat index into the (temporal x spatial) bias table
        flat = spatial_idx.flatten() + temporal_idx.flatten() * self.n_spatial
        bias = self.bias.index_select(0, flat)
        bias = bias.view((*spatial_idx.shape, self.n_head))
        # (B, N, N, n_head) -> (B, n_head, N, N)
        return bias.transpose(-1, 1)
class RelativePositionalAttention(nn.Module):
    """Multi-head attention whose logits incorporate relative positions.

    Depending on `mode`, the attention adds a learned relative positional bias
    ("bias"), rotates q/k with rotary embeddings ("rope"), or uses no
    positional term ("none"/None/False). Independently, `attn_dist_mode`
    adds a soft distance-decay term to the logits, and pairs farther apart
    than `cutoff_spatial` are masked out entirely.
    """

    def __init__(
        self,
        coord_dim: int,
        embed_dim: int,
        n_head: int,
        cutoff_spatial: float = 256,
        cutoff_temporal: float = 16,
        n_spatial: int = 32,
        n_temporal: int = 16,
        dropout: float = 0.0,
        mode: Literal["bias", "rope", "none"] = "bias",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        if embed_dim % (2 * n_head) != 0:
            raise ValueError(
                f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
            )
        self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout
        self.n_head = n_head
        self.embed_dim = embed_dim
        self.cutoff_spatial = cutoff_spatial
        self.attn_dist_mode = attn_dist_mode
        if mode == "bias" or mode is True:
            self.pos_bias = RelativePositionalBias(
                n_head=n_head,
                cutoff_spatial=cutoff_spatial,
                cutoff_temporal=cutoff_temporal,
                n_spatial=n_spatial,
                n_temporal=n_temporal,
            )
        elif mode == "rope":
            # per-head channels assigned to each spatial axis (an even count);
            # the remainder goes to the temporal axis
            n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
            self.rot_pos_enc = RotaryPositionalEncoding(
                cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
                n_pos=(embed_dim // n_head - coord_dim * n_split,)
                + (n_split,) * coord_dim,
            )
        elif mode == "none":
            pass
        elif mode is None or mode is False:
            logger.warning(
                "attn_positional_bias is not set (None or False), no positional bias."
            )
        else:
            raise ValueError(f"Unknown mode {mode}")
        self._mode = mode

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Attend query -> key/value; coords=None disables all positional terms.

        Args:
            query, key, value: (B, N, embed_dim)
            coords: (B, N, 1 + coord_dim) with time in channel 0, or None.
            padding_mask: optional (B, N) bool mask of padded entries.
        """
        B, N, D = query.size()
        q = self.q_pro(query)
        k = self.k_pro(key)
        v = self.v_pro(value)
        # (B, N, D) -> (B, n_head, N, D // n_head)
        k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        attn_mask = torch.zeros(
            (B, self.n_head, N, N), device=query.device, dtype=q.dtype
        )
        # large negative instead of -inf so fully masked rows don't produce NaNs
        attn_ignore_val = -1e3
        if coords is not None:
            # Fix: the spatial cutoff mask used to be built before the
            # `coords is not None` check, crashing whenever positional
            # information was disabled by passing coords=None.
            yx = coords[..., 1:]
            spatial_dist = torch.cdist(yx, yx)
            spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
            attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
            if self._mode == "bias":
                attn_mask = attn_mask + self.pos_bias(coords)
            elif self._mode == "rope":
                q, k = self.rot_pos_enc(q, k, coords)
            if self.attn_dist_mode == "v0":
                # distance decay over the full (time + space) coordinates
                dist = torch.cdist(coords, coords, p=2)
                attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
            elif self.attn_dist_mode == "v1":
                attn_mask += torch.exp(
                    -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
                )
            else:
                raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
        if padding_mask is not None:
            # mask out any pair that involves a padded element
            ignore_mask = torch.logical_or(
                padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
            ).unsqueeze(1)
            attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
        )
        y = y.transpose(1, 2).contiguous().view(B, N, D)
        y = self.proj(y)
        return y
class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer with relative positional self-attention."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        cutoff_spatial: int = 256,
        window: int = 16,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Self-attention then MLP, each behind a pre-norm residual connection."""
        x = self.norm1(x)
        # coords=None disables any positional term inside the attention
        attn_out = self.attn(
            x,
            x,
            x,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )
        x = x + attn_out
        return x + self.mlp(self.norm2(x))
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: cross-attention followed by an MLP."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        window: int = 16,
        cutoff_spatial: int = 256,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Cross-attend queries x against memory y, then apply the MLP block."""
        x = self.norm1(x)
        y = self.norm2(y)
        # coords=None disables any positional term inside the attention
        cross = self.attn(
            x,
            y,
            y,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )
        x = x + cross
        return x + self.mlp(self.norm3(x))
class TrackingTransformer(torch.nn.Module):
    """Encoder/decoder transformer producing pairwise association logits.

    forward() Fourier-embeds (time + spatial) coordinates, optionally
    concatenates embedded per-detection features, runs an encoder stack
    (self-attention) and a decoder stack (cross-attention), and returns the
    outer product of the two resulting token sets as an (N x N) matrix of
    association logits.
    """

    def __init__(
        self,
        coord_dim: int = 3,
        feat_dim: int = 0,
        d_model: int = 128,
        nhead: int = 4,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 4,
        dropout: float = 0.1,
        pos_embed_per_dim: int = 32,
        feat_embed_per_dim: int = 1,
        window: int = 6,
        spatial_pos_cutoff: int = 256,
        attn_positional_bias: Literal["bias", "rope", "none"] = "rope",
        attn_positional_bias_n_spatial: int = 16,
        causal_norm: Literal[
            "none", "linear", "softmax", "quiet_softmax"
        ] = "quiet_softmax",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        # Stored verbatim so save()/from_folder() can round-trip the model.
        self.config = dict(
            coord_dim=coord_dim,
            feat_dim=feat_dim,
            pos_embed_per_dim=pos_embed_per_dim,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            window=window,
            dropout=dropout,
            attn_positional_bias=attn_positional_bias,
            attn_positional_bias_n_spatial=attn_positional_bias_n_spatial,
            spatial_pos_cutoff=spatial_pos_cutoff,
            feat_embed_per_dim=feat_embed_per_dim,
            causal_norm=causal_norm,
            attn_dist_mode=attn_dist_mode,
        )
        # project concatenated (positional + feature) embedding to d_model
        self.proj = nn.Linear(
            (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model
        )
        self.norm = nn.LayerNorm(d_model)
        self.encoder = nn.ModuleList([
            EncoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_encoder_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_decoder_layers)
        ])
        self.head_x = FeedForward(d_model)
        self.head_y = FeedForward(d_model)
        if feat_embed_per_dim > 1:
            # Fourier-embed each scalar feature dimension as well.
            self.feat_embed = PositionalEncoding(
                cutoffs=(1000,) * feat_dim,
                n_pos=(feat_embed_per_dim,) * feat_dim,
                cutoffs_start=(0.01,) * feat_dim,
            )
        else:
            self.feat_embed = nn.Identity()
        self.pos_embed = PositionalEncoding(
            cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
            n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
        )

    def forward(self, coords, features=None, padding_mask=None, attn_feat=None):
        """Return association logits A of shape (B, N, N).

        Args:
            coords: (B, N, 1 + coord_dim) with time in channel 0.
            features: optional (B, N, feat_dim) per-detection features.
            padding_mask: optional (B, N) bool mask of padded entries.
            attn_feat: optional (B, N, d_model) embedding added after projection.
        """
        assert coords.ndim == 3 and coords.shape[-1] in (3, 4)
        _B, _N, _D = coords.shape
        if padding_mask is not None:
            # overwrite padded coords so they cannot win the minimum below
            coords = coords.clone()
            coords[padding_mask] = coords.max()
        # remove temporal offset
        # NOTE(review): min_time broadcasts over the last axis, so it is also
        # subtracted from the spatial channels — confirm this is intended.
        min_time = coords[:, :, :1].min(dim=1, keepdim=True).values
        coords = coords - min_time
        pos = self.pos_embed(coords)
        if features is None or features.numel() == 0:
            # no features given: use the positional embedding alone
            features = pos
        else:
            features = self.feat_embed(features)
            features = torch.cat((pos, features), axis=-1)
        features = self.proj(features)
        if attn_feat is not None:
            # add externally provided attention embedding
            features = features + attn_feat
        features = self.norm(features)
        # encoder: self-attention over all detections
        x = features
        for enc in self.encoder:
            x = enc(x, coords=coords, padding_mask=padding_mask)
        # decoder: cross-attention of fresh embeddings against encoder output
        y = features
        for dec in self.decoder:
            y = dec(y, x, coords=coords, padding_mask=padding_mask)
        x = self.head_x(x)
        y = self.head_y(y)
        # outer product of the two token sets -> association logits
        A = torch.einsum("bnd,bmd->bnm", x, y)
        return A

    def normalize_output(
        self,
        A: torch.FloatTensor,
        timepoints: torch.LongTensor,
        coords: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply (parental) softmax, or elementwise sigmoid.

        Args:
            A: association logits of shape (B, N, N)
            timepoints: tensor of shape (B, N)
            coords: tensor of shape (B, N, 1 + coord_dim)
        """
        assert A.ndim == 3
        assert timepoints.ndim == 2
        assert coords.ndim == 3
        assert coords.shape[2] == 1 + self.config["coord_dim"]
        # pairwise spatial distances; pairs beyond the cutoff are invalid
        dist = torch.cdist(coords[:, :, 1:], coords[:, :, 1:])
        invalid = dist > self.config["spatial_pos_cutoff"]
        if self.config["causal_norm"] == "none":
            # elementwise sigmoid; spatially distant entries are zeroed
            A = torch.sigmoid(A)
            A[invalid] = 0
            return A
        return torch.stack([
            blockwise_causal_norm(
                _A, _t, mode=self.config["causal_norm"], mask_invalid=_m
            )
            for _A, _t, _m in zip(A, timepoints, invalid)
        ])

    def save(self, folder):
        """Write config.yaml and model.pt into `folder` (created if needed)."""
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)
        # Fix: use a context manager so the config file handle is closed.
        with open(folder / "config.yaml", "w") as f:
            yaml.safe_dump(self.config, f)
        torch.save(self.state_dict(), folder / "model.pt")

    # Fix: the @classmethod decorator was missing, so
    # TrackingTransformer.from_folder(path) received `path` as `cls` and failed.
    @classmethod
    def from_folder(
        cls, folder, map_location=None, checkpoint_path: str = "model.pt"
    ):
        """Instantiate from a folder written by save() (or a lightning checkpoint)."""
        folder = Path(folder)
        # config.yaml is written with safe_dump, so safe_load is sufficient
        # (and avoids executing arbitrary YAML tags).
        with open(folder / "config.yaml") as f:
            config = yaml.safe_load(f)
        model = cls(**config)
        fpath = folder / checkpoint_path
        logger.info(f"Loading model state from {fpath}")
        state = torch.load(fpath, map_location=map_location, weights_only=True)
        # if state is a checkpoint, we have to extract state_dict:
        # lightning nests the weights under "state_dict" with a "model." prefix
        if "state_dict" in state:
            state = state["state_dict"]
            state = OrderedDict(
                (k[6:], v) for k, v in state.items() if k.startswith("model.")
            )
        model.load_state_dict(state)
        return model

    # Fix: @classmethod decorator was missing here as well.
    @classmethod
    def from_cfg(cls, cfg_path):
        """Instantiate a fresh (untrained) model from a YAML config file."""
        cfg_path = Path(cfg_path)
        with open(cfg_path) as f:
            config = yaml.safe_load(f)
        return cls(**config)