from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_decomposer import DecomposerConfig


def pairwise_cosine(x: torch.Tensor) -> torch.Tensor:
    """
    Cosine similarity of every row against every other row.

    x : [B,d] or [N,B,d]
    returns a square similarity matrix:
        [B,B] or [N,B,B]
    """
    x = F.normalize(x, p=2, dim=-1)
    return torch.matmul(x, x.transpose(-1, -2))

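# Illustrative shape check for `pairwise_cosine` (sizes are assumptions, not
# pipeline data); kept as a comment so importing the module stays side-effect
# free:
#
#   x = torch.randn(4, 16)                      # 4 embeddings of dim 16
#   pairwise_cosine(x).shape                    # torch.Size([4, 4]), diag == 1
#   pairwise_cosine(x.expand(3, 4, 16)).shape   # torch.Size([3, 4, 4])

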
def cross_cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Cosine similarity of every row of `a` against every row of `b`.

    a : [M,d] or [N,M,d]
    b : [K,d] (reference set - no extra leading axis)
    returns:
        [M,K] or [N,M,K]
    """
    a_n = F.normalize(a, p=2, dim=-1)
    b_n = F.normalize(b, p=2, dim=-1)

    if a.ndim == 2:
        return a_n @ b_n.T

    if a.ndim == 3:
        return torch.einsum("nmd,kd->nmk", a_n, b_n)

    raise ValueError("cross_cosine: unexpected tensor rank.")

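# Sketch of the two supported ranks (shapes are assumptions):
#
#   a, b = torch.randn(5, 16), torch.randn(7, 16)
#   cross_cosine(a, b).shape                    # torch.Size([5, 7])
#   cross_cosine(a.expand(3, 5, 16), b).shape   # torch.Size([3, 5, 7])

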
def _drop_diag(M: torch.Tensor) -> torch.Tensor:
    """
    Remove the main diagonal from each similarity matrix.

    Works for 2-D [B,B] or 3-D [N,B,B] tensors; each row keeps its B-1
    off-diagonal entries, giving [B,B-1] or [N,B,B-1].
    """
    if M.ndim == 2:
        n = M.size(0)
        mask = torch.eye(n, dtype=torch.bool, device=M.device)
        return M.masked_select(~mask).view(n, n - 1)

    if M.ndim == 3:
        n = M.size(1)
        mask = torch.eye(n, dtype=torch.bool, device=M.device).unsqueeze(0)
        return M.masked_select(~mask).view(M.size(0), n, n - 1)

    raise ValueError("_drop_diag expects a 2- or 3-D tensor.")

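# For example (a sketch, not pipeline data), a [3,3] self-similarity matrix
# loses its self-match column:
#
#   _drop_diag(pairwise_cosine(torch.randn(3, 8))).shape  # torch.Size([3, 2])

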
def rowwise_pearson(ref: torch.Tensor,
                    pred: torch.Tensor,
                    *,
                    rm_diag: bool = True) -> torch.Tensor:
    """
    Pearson correlation row by row; supports 2-D or 3-D inputs whose trailing
    dimensions match (`pred` may carry an extra leading axis that broadcasts
    against `ref`).

    Returns the mean correlation error (0 → perfect).
    """
    if rm_diag:
        ref = _drop_diag(ref)
        pred = _drop_diag(pred)

    # Center each row and L2-normalize it; the dot product of two rows
    # treated this way equals their Pearson correlation coefficient.
    ref_z = F.normalize(ref - ref.mean(-1, keepdim=True), p=2, dim=-1)
    pred_z = F.normalize(pred - pred.mean(-1, keepdim=True), p=2, dim=-1)
    loss = 1 - (ref_z * pred_z).sum(-1).mean(-1)
    if loss.ndim == 0:
        loss = loss.unsqueeze(0)
    return loss

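# Sanity sketch (assumed shapes): identical similarity matrices give zero
# error, perfectly anti-correlated ones give 2:
#
#   s = pairwise_cosine(torch.randn(6, 16))
#   rowwise_pearson(s, s)     # ~tensor([0.]), shape [1]
#   rowwise_pearson(s, -s)    # ~tensor([2.])

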
def similarity_mse(ref: torch.Tensor,
                   pred: torch.Tensor,
                   *,
                   rm_diag: bool = True) -> torch.Tensor:
    """MSE between similarity matrices; same input contract as `rowwise_pearson`."""
    if rm_diag:
        ref, pred = _drop_diag(ref), _drop_diag(pred)

    if pred.ndim == 2:
        loss = F.mse_loss(pred, ref).unsqueeze(0)
    elif pred.ndim == 3:
        # `ref` may lack the leading axis; expand it to match `pred`.
        loss = F.mse_loss(pred,
                          ref.expand_as(pred),
                          reduction="none").reshape(pred.size(0), -1).mean(-1)
    else:
        raise ValueError("similarity_mse expects a 2- or 3-D `pred`.")

    return loss

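# Both loss functions return one value per leading slice of `pred`, e.g.
# (assumed shapes):
#
#   s2, s3 = torch.rand(6, 6), torch.rand(4, 6, 6)
#   similarity_mse(s2, s3).shape   # torch.Size([4])
#   similarity_mse(s2, s2).shape   # torch.Size([1])

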
def sim_loss(pred: torch.Tensor,
             targ: torch.Tensor,
             ref: Optional[torch.Tensor],
             k_vals: Optional[List[int]],
             loss_type: str = "pearson") -> torch.Tensor:
    """
    Returns a stacked tensor of losses:
        len = 1 + len(k_vals)
    If `ref` is given we compute cross-similarities pred↔ref / targ↔ref,
    otherwise self-similarities pred↔pred / targ↔targ. Each k in `k_vals`
    additionally scores only the k columns most similar to each row of the
    target similarity matrix.
    """
    loss_fn = rowwise_pearson if loss_type == "pearson" else similarity_mse

    if ref is None:
        p_sim, t_sim = pairwise_cosine(pred), pairwise_cosine(targ)
        rm_diag = True
    else:
        p_sim, t_sim = cross_cosine(pred, ref), cross_cosine(targ, ref)
        rm_diag = False

    losses = [loss_fn(t_sim, p_sim, rm_diag=rm_diag)]

    if k_vals:
        ranks = t_sim.argsort(-1, descending=True)
        # With self-similarities the best match is the row itself; skip it.
        start = 1 if rm_diag else 0
        for k in k_vals:
            idx = ranks[..., start:start + k]
            t_k = torch.gather(t_sim, -1, idx)
            if p_sim.ndim == 2:
                p_k = torch.gather(p_sim, -1, idx)
            elif p_sim.ndim == 3:
                # Tile the 2-D index across the leading axis of p_sim.
                p_k = torch.gather(p_sim, -1, idx.repeat(p_sim.size(0), 1, 1))
            losses.append(loss_fn(t_k, p_k, rm_diag=False))

    return torch.stack(losses, 1)

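# Shape sketch (assumed sizes): with a leading axis of N prediction sets and
# two k values, the stacked result is [N, 1 + 2]:
#
#   pred = torch.randn(3, 8, 16)   # N=3 candidate sets of 8 embeddings
#   targ = torch.randn(8, 16)
#   sim_loss(pred, targ, ref=None, k_vals=[2, 4]).shape  # torch.Size([3, 3])

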
class FeedForward(nn.Module):
    """SwiGLU-style block: a SiLU-gated linear unit followed by a projection."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.fc1(x).chunk(2, -1)
        return self.fc2(F.silu(x1) * x2)

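# The forward pass computes, writing the two halves of fc1 as a gate weight
# W_g and a value weight W_v (notation only, not extra parameters):
#
#   y = W_2 (silu(W_g x + b_g) * (W_v x + b_v)) + b_2
#
# i.e. the SiLU branch gates the linear branch element-wise before the final
# projection.

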
class FeedForwardLayer(nn.Module):
    """`FeedForward` wrapped with dropout, a (projected) residual skip
    connection, and optional LayerNorm (disabled when `ln_eps` is None)."""

    def __init__(self,
                 d_in: int,
                 d_out: int,
                 *,
                 dropout: float = 0.1,
                 ln_eps: Optional[float] = 1e-12):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        # Project the skip path only when the block changes width.
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_out, eps=ln_eps) if ln_eps else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.ff(self.drop(x)) + self.skip(x))

class OutputLinear(nn.Module):
    """A shared feed-forward stack followed by one linear head per output
    size; the selected heads are fused into a single `F.linear` call."""

    def __init__(self,
                 input_size: int,
                 n_head_layers: int,
                 n_output: int,
                 output_sizes: List[int],
                 dropout: float = 0.1,
                 ln_eps: Optional[float] = 1e-12):
        super().__init__()
        self.n_output = n_output
        # The last shared layer skips LayerNorm so the heads see raw features.
        ff_layers = [FeedForwardLayer(input_size, input_size, dropout=dropout,
                                      ln_eps=None if i == n_head_layers - 1 else ln_eps)
                     for i in range(n_head_layers)]
        self.ff = nn.Sequential(*ff_layers)
        self.layers = nn.ModuleDict({str(d): nn.Linear(input_size, d * n_output)
                                     for d in output_sizes})

    def forward(self, inputs: torch.Tensor, sizes: List[int]) -> Dict[int, torch.Tensor]:
        inputs = self.ff(inputs)
        # Concatenate the requested heads so one matmul serves them all.
        weights = torch.cat([self.layers[str(d)].weight for d in sizes])
        biases = torch.cat([self.layers[str(d)].bias for d in sizes])
        outputs = F.linear(inputs, weights, biases)

        # Slice the fused output back into one [*, n_output, d] block per size.
        output_dict: Dict[int, torch.Tensor] = {}
        current = 0
        slice_sizes = [d * self.n_output for d in sizes]
        for size in slice_sizes:
            p = outputs[:, :, current:current + size]
            p = p.view(p.size(0), p.size(1), self.n_output, size // self.n_output)
            output_dict[size // self.n_output] = p
            current += size
        return output_dict

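# Shape sketch (assumed sizes): two output sizes, two blocks per size, and
# inputs of [n_comp, batch, input_size]:
#
#   head = OutputLinear(64, n_head_layers=2, n_output=2, output_sizes=[16, 32])
#   outs = head(torch.randn(3, 5, 64), [16, 32])
#   outs[16].shape   # torch.Size([3, 5, 2, 16])
#   outs[32].shape   # torch.Size([3, 5, 2, 32])

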
def get_compression_heads(d_in: int,
                          comp_sizes: List[int],
                          n_layers: int,
                          add_input_identity: bool = False) -> nn.ModuleDict:
    """One compression MLP per target size: `n_layers` FeedForwardLayers kept
    at width `d_in`, with the last layer projecting down to `d` and dropping
    its LayerNorm. Optionally registers an identity head keyed by `d_in` so
    the uncompressed input can be routed through the same interface."""
    compression_heads = nn.ModuleDict({})
    for d in comp_sizes:
        enc_layers = []
        for i in range(n_layers):
            last = i == n_layers - 1
            enc_layers.append(
                FeedForwardLayer(
                    d_in,
                    d if last else d_in,
                    dropout=0.0,
                    ln_eps=None if last else 1e-12,
                )
            )
        compression_heads[str(d)] = nn.Sequential(*enc_layers)
    if add_input_identity:
        compression_heads[str(d_in)] = nn.Identity()

    return compression_heads

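# Usage sketch (assumed sizes):
#
#   heads = get_compression_heads(64, [16, 32], n_layers=2,
#                                 add_input_identity=True)
#   sorted(heads.keys())                    # ['16', '32', '64']
#   heads["16"](torch.randn(5, 64)).shape   # torch.Size([5, 16])

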
@dataclass
class DecomposerOutput(ModelOutput):
    # `loss` is None when the forward pass is run with compute_loss=False.
    loss: Optional[torch.FloatTensor] = None
    loss_terms: Optional[Dict[str, torch.Tensor]] = None
    decomp: Optional[Dict[int, torch.FloatTensor]] = None
    ref_idxs: Optional[torch.LongTensor] = None

class DecomposerModel(PreTrainedModel):
    """Maps an embedding to `n_output` building-block embeddings for every
    requested `output_size`. All loops are left intact for clarity."""

    config_class = DecomposerConfig

    def __init__(self, config: DecomposerConfig):
        super().__init__(config)

        # One compressor per compressed input size, plus an identity head so
        # the full-size input flows through the same interface.
        self.compression_heads = get_compression_heads(config.input_size,
                                                       config.comp_sizes,
                                                       config.n_comp_layers,
                                                       add_input_identity=True)

        # Per-size projections into the shared trunk width.
        self.in_proj = nn.ModuleDict({
            str(d): FeedForwardLayer(d, config.shared_dim,
                                     dropout=config.dropout,
                                     ln_eps=config.layer_norm_eps)
            for d in config.comp_sizes
        })

        def blk() -> FeedForwardLayer:
            return FeedForwardLayer(config.shared_dim,
                                    config.shared_dim,
                                    dropout=config.dropout,
                                    ln_eps=config.layer_norm_eps)

        self.trunk = nn.Sequential(*[blk() for _ in range(config.n_shared_layers)])

        self.out_proj = OutputLinear(config.shared_dim,
                                     config.n_head_layers,
                                     config.n_output,
                                     config.output_sizes,
                                     config.dropout,
                                     config.layer_norm_eps)

        # Reference embedding tables holding the building-block targets.
        self.ref_emb = nn.ModuleDict({
            str(d): nn.Embedding(config.n_refs_total, d)
            for d in (config.output_sizes if config.n_refs_total else [])
        })

        self.post_init()

    def compress(self,
                 inputs: torch.Tensor,
                 comp_sizes: List[int]) -> Dict[int, torch.Tensor]:
        return {d: self.compression_heads[str(d)](inputs) for d in comp_sizes}

    def decompose(self,
                  inputs: Dict[int, torch.Tensor],
                  output_sizes: List[int]) -> Dict[int, torch.Tensor]:
        hiddens = []
        # Iterate in config order so the leading axis of the stack stays
        # aligned with `config.comp_sizes` downstream.
        for input_size in self.config.comp_sizes:
            if input_size not in inputs:
                continue
            h = self.in_proj[str(input_size)](inputs[input_size])
            hiddens.append(h)

        hiddens = torch.stack(hiddens, dim=0)  # [n_comp, B, shared_dim]
        hiddens = self.trunk(hiddens)

        return self.out_proj(hiddens, output_sizes)

    def load_targets(self,
                     bb1_ids: torch.LongTensor,
                     bb2_ids: torch.LongTensor) -> Dict[int, torch.Tensor]:
        targets = {}
        for size in self.config.output_sizes:
            embedding = self.ref_emb[str(size)]
            # [B, 2, size]: one target row per building block.
            targets[size] = torch.stack([embedding(bb1_ids), embedding(bb2_ids)], dim=1)
        return targets

    def compute_loss(self,
                     inputs: Dict[int, torch.Tensor],
                     preds: Dict[int, torch.Tensor],
                     targets: Dict[int, torch.Tensor],
                     ref_idxs: Optional[torch.LongTensor] = None):
        device = next(iter(preds.values())).device
        loss_terms: Dict[str, torch.Tensor] = {}
        loss_total = torch.zeros((), device=device)
        cfg = self.config
        for out_size in cfg.output_sizes:
            p = preds[out_size]    # [n_comp, B, n_output, out_size]
            t = targets[out_size]  # [B, n_output, out_size]

            # Direct per-vector cosine loss, one term per input size.
            if cfg.cosine_weight > 0:
                cos = 1 - F.cosine_similarity(p, t, dim=-1).view(p.size(0), -1).mean(-1)
                loss_total += cfg.cosine_weight * cos.sum()
                for i, in_size in enumerate(cfg.comp_sizes):
                    loss_terms[f"{in_size}->{out_size}_cos"] = cos[i]

            # Direct per-vector MSE loss.
            if cfg.mse_weight > 0:
                mse = F.mse_loss(p, t.expand_as(p), reduction="none").view(p.size(0), -1).mean(-1)
                loss_total += cfg.mse_weight * mse.sum()
                for i, in_size in enumerate(cfg.comp_sizes):
                    loss_terms[f"{in_size}->{out_size}_mse"] = mse[i]

            # Similarity-structure loss against a sampled reference set.
            if cfg.corr_weight:
                flat_p = p.flatten(1, 2)  # [n_comp, B * n_output, out_size]
                flat_t = t.flatten(0, 1)  # [B * n_output, out_size]

                with torch.no_grad():
                    ref = self.ref_emb[str(out_size)](ref_idxs)

                ref_corr = sim_loss(flat_p, flat_t, ref,
                                    cfg.corr_k_vals, cfg.corr_loss_type).mean(-1)
                loss_total += cfg.corr_weight * ref_corr.sum()
                for i, in_size in enumerate(cfg.comp_sizes):
                    loss_terms[f"{in_size}->{out_size}_corr_ref"] = ref_corr[i]

        return loss_total, loss_terms

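    # `loss_terms` is keyed per (input size -> output size) pair; e.g. with
    # hypothetical comp_sizes=[32, 64] and output_sizes=[16] it would hold
    # "32->16_cos", "64->16_cos", "32->16_mse", ..., "64->16_corr_ref".
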
    def forward(self,
                embedding: torch.Tensor,
                bb1_id: torch.LongTensor,
                bb2_id: torch.LongTensor,
                *,
                ref_idxs: Optional[torch.LongTensor] = None,
                return_preds: bool = False,
                compute_loss: bool = True,
                return_dict: bool = True) -> DecomposerOutput:
        # NOTE: `return_preds` and `return_dict` are accepted for API
        # compatibility but currently unused; the full DecomposerOutput
        # (including `decomp`) is always returned.
        cfg = self.config
        device = embedding.device
        targets = self.load_targets(bb1_id, bb2_id)

        # Sample a fresh reference batch when the caller did not pin one.
        if cfg.corr_weight and cfg.n_refs_total and ref_idxs is None:
            ref_idxs = torch.randint(cfg.n_refs_total,
                                     (cfg.n_refs_batch,),
                                     device=device)

        # Keep the compression heads out of the gradient path.
        with torch.no_grad():
            compressed_inputs = self.compress(embedding, cfg.comp_sizes)

        # Use the raw embedding directly for the full input size (the
        # identity compression head returns it unchanged anyway).
        if cfg.input_size in cfg.comp_sizes:
            compressed_inputs[cfg.input_size] = embedding

        preds = self.decompose(compressed_inputs, cfg.output_sizes)

        loss_total = None
        loss_terms: Dict[str, torch.Tensor] = {}
        if compute_loss:
            loss_total, loss_terms = self.compute_loss(compressed_inputs, preds, targets, ref_idxs)

        # [n_comp, B, n_output, d] -> [B, n_comp, n_output, d]
        decomp = {k: v.permute(1, 0, 2, 3) for k, v in preds.items()}

        return DecomposerOutput(loss=loss_total,
                                loss_terms=loss_terms,
                                decomp=decomp,
                                ref_idxs=ref_idxs)

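
if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch only. The sizes below are
    # made up, and it assumes DecomposerConfig accepts these fields as
    # constructor keyword arguments (they are the fields this module reads).
    # Because of the relative import above, run it as a module, e.g.
    # `python -m <package>.modeling_decomposer`.
    cfg = DecomposerConfig(
        input_size=64, comp_sizes=[32, 64], n_comp_layers=2,
        shared_dim=64, n_shared_layers=2, n_head_layers=2,
        n_output=2, output_sizes=[16, 32],
        dropout=0.1, layer_norm_eps=1e-12,
        n_refs_total=100, n_refs_batch=8,
        cosine_weight=1.0, mse_weight=1.0, corr_weight=1.0,
        corr_k_vals=[4], corr_loss_type="pearson",
    )
    model = DecomposerModel(cfg)
    out = model(embedding=torch.randn(5, 64),
                bb1_id=torch.randint(cfg.n_refs_total, (5,)),
                bb2_id=torch.randint(cfg.n_refs_total, (5,)))
    # Expect a scalar loss and decomp shapes [B, n_comp, n_output, size].
    print(out.loss)
    print({k: tuple(v.shape) for k, v in out.decomp.items()})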