from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.model.config import ModelConfig
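
# Configuration surface (summary derived from the attribute reads below):
# this head consumes cfg.d_model plus the anchor_future_proposal_* knobs:
# hidden, horizon_scale, span_scale, max_horizon, max_windows, temperature,
# threshold, topk, residual_scale.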


@dataclass
class FutureProposalCandidate:
    """A scored future window considered as a jump target for an anchor."""

    start: int
    end: int
    repr: torch.Tensor
    score: torch.Tensor
    root_token: int | None


class FutureProposalHead(nn.Module):
    """Scores candidate future windows and fuses the best ones into a proposal."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        hidden_dim = max(32, int(cfg.anchor_future_proposal_hidden))
        # Scores the 10-dimensional feature vector assembled in _build_candidates.
        self.score_mlp = nn.Sequential(
            nn.Linear(10, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        # Fuses [anchor, window, window - anchor, window * anchor] back to d_model.
        self.repr_delta = nn.Sequential(
            nn.Linear(cfg.d_model * 4, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, cfg.d_model),
        )

    @staticmethod
    def _cosine01_tensor(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Mean cosine similarity remapped from [-1, 1] into [0, 1].
        cosine = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0), dim=-1).mean()
        return (cosine + 1.0) * 0.5

    def _candidate_lengths(
        self,
        span_len: int,
        available: int,
    ) -> list[int]:
        # Try windows of roughly 0.5x, 1x, 1.5x, and 2x the anchor span length,
        # deduplicated and clamped to the available search range.
        if available <= 0:
            return []
        lengths = {
            max(1, span_len // 2),
            max(1, span_len),
            max(1, min(available, span_len + max(1, span_len // 2))),
            max(1, min(available, span_len * 2)),
        }
        return sorted(length for length in lengths if 1 <= length <= available)

    def _search_bounds(
        self,
        anchor,
        seq_len: int,
    ) -> tuple[int, int, int]:
        # The search region starts just past the anchor and extends over a horizon
        # scaled by the anchor's TTL and span length, capped at max_horizon.
        span_len = max(int(anchor.end_idx) - int(anchor.start_idx) + 1, 1)
        start = min(int(anchor.end_idx) + 1, seq_len)
        if start >= seq_len:
            return start, start, span_len
        base_horizon = max(
            int(float(anchor.ttl) * float(self.cfg.anchor_future_proposal_horizon_scale)),
            int(span_len * float(self.cfg.anchor_future_proposal_span_scale)),
        )
        horizon = min(max(base_horizon, span_len), int(self.cfg.anchor_future_proposal_max_horizon))
        stop = min(seq_len, start + max(horizon, 1))
        return start, stop, span_len

    def _subsample_candidates(
        self,
        candidates: list[FutureProposalCandidate],
    ) -> list[FutureProposalCandidate]:
        # Keep at most max_windows candidates, sampled evenly across the list.
        max_windows = max(1, int(self.cfg.anchor_future_proposal_max_windows))
        if len(candidates) <= max_windows:
            return candidates
        idx = torch.linspace(0, len(candidates) - 1, steps=max_windows).round().long().tolist()
        return [candidates[i] for i in idx]

    def _build_candidates(
        self,
        seq_hidden: torch.Tensor,
        seq_ids: torch.Tensor | None,
        anchor,
    ) -> list[FutureProposalCandidate]:
        seq_len = seq_hidden.size(0)
        start, stop, span_len = self._search_bounds(anchor, seq_len)
        if stop <= start:
            return []
        available = stop - start
        lengths = self._candidate_lengths(span_len, available)
        if not lengths:
            return []
        anchor_hidden_span = seq_hidden[int(anchor.start_idx): int(anchor.end_idx) + 1]
        # First differences of the anchor's hidden states capture its transition direction.
        anchor_delta = (
            anchor_hidden_span[1:] - anchor_hidden_span[:-1]
            if anchor_hidden_span.size(0) > 1
            else None
        )
        candidates: list[FutureProposalCandidate] = []
        for length in lengths:
            max_offset = stop - length + 1
            for offset in range(start, max_offset):
                window_hidden = seq_hidden[offset: offset + length]
                window_mean = window_hidden.mean(dim=0)
                # Similarity of the window to the anchor, and its complement.
                mean_sim = self._cosine01_tensor(anchor.repr, window_mean)
                contrast = 1.0 - mean_sim
                if anchor_delta is not None and anchor_delta.numel() > 0 and window_hidden.size(0) > 1:
                    window_delta = window_hidden[1:] - window_hidden[:-1]
                    transition_sim = self._cosine01_tensor(anchor_delta.mean(dim=0), window_delta.mean(dim=0))
                else:
                    transition_sim = mean_sim
                # Internal coherence of the window around its own mean.
                coherence = ((F.cosine_similarity(window_hidden, window_mean.unsqueeze(0), dim=-1) + 1.0) * 0.5).mean()
                tail_hidden = seq_hidden[offset + length: stop]
                if tail_hidden.numel() > 0:
                    tail_support = self._cosine01_tensor(window_mean, tail_hidden.mean(dim=0))
                else:
                    tail_support = coherence
                if seq_ids is None:
                    token_overlap = seq_hidden.new_tensor(0.0)
                    root_token = None
                else:
                    anchor_ids = seq_ids[int(anchor.start_idx): int(anchor.end_idx) + 1]
                    window_ids = seq_ids[offset: offset + length]
                    anchor_token_set = {int(token) for token in anchor_ids.tolist()}
                    window_token_set = {int(token) for token in window_ids.tolist()}
                    token_overlap = seq_hidden.new_tensor(
                        len(anchor_token_set & window_token_set) / max(len(anchor_token_set), 1)
                    )
                    root_token = int(window_ids[-1].item())
                # Decay by distance from the anchor end, normalized by span length.
                distance = max(0, offset - int(anchor.end_idx))
                distance_decay = seq_hidden.new_tensor(1.0 / (1.0 + distance / max(float(span_len), 1.0)))
                pressure = seq_hidden.new_tensor(float(anchor.contradiction_pressure))
                viability_gap = seq_hidden.new_tensor(1.0 - float(anchor.viability))
                descendant_gap = seq_hidden.new_tensor(1.0 - float(anchor.descendant_coherence or 0.0))
                # Composite heuristics: how much the window diverges from the anchor,
                # how plausible it looks as a continuation, and how ready the anchor
                # is for repair.
                conflict_signal = 0.55 * contrast + 0.25 * (1.0 - transition_sim) + 0.20 * (1.0 - token_overlap)
                plausibility = 0.45 * coherence + 0.35 * tail_support + 0.20 * distance_decay
                repair_readiness = 0.60 * pressure + 0.40 * viability_gap
                # Gate out windows with too little conflict or repair pressure.
                if float(conflict_signal.item()) < 0.18 or float(repair_readiness.item()) < 0.35:
                    continue
                feature_vec = torch.stack(
                    [
                        contrast,
                        mean_sim,
                        transition_sim,
                        coherence,
                        tail_support,
                        token_overlap,
                        distance_decay,
                        pressure,
                        viability_gap,
                        descendant_gap,
                    ],
                    dim=0,
                ).to(device=seq_hidden.device, dtype=seq_hidden.dtype)
                # Blend a down-weighted learned score with a fixed heuristic logit.
                learned_logit = 0.25 * self.score_mlp(feature_vec.unsqueeze(0)).squeeze(0).squeeze(-1)
                heuristic_logit = (
                    2.4 * (conflict_signal - 0.35)
                    + 2.0 * (plausibility - 0.55)
                    + 1.4 * (repair_readiness - 0.50)
                    + 0.5 * (descendant_gap - 0.35)
                )
                score = torch.sigmoid(
                    (heuristic_logit + learned_logit) / max(float(self.cfg.anchor_future_proposal_temperature), 1e-6)
                )
                candidates.append(
                    FutureProposalCandidate(
                        start=offset,
                        end=offset + length - 1,
                        repr=window_mean,
                        score=score,
                        root_token=root_token,
                    )
                )
        return self._subsample_candidates(candidates)

    def propose(
        self,
        seq_hidden: torch.Tensor,
        seq_ids: torch.Tensor | None,
        anchor,
    ) -> dict | None:
        candidates = self._build_candidates(seq_hidden=seq_hidden, seq_ids=seq_ids, anchor=anchor)
        if not candidates:
            return None
        scores = torch.stack([candidate.score for candidate in candidates], dim=0)
        best_score, best_idx = scores.max(dim=0)
        if float(best_score.item()) < float(self.cfg.anchor_future_proposal_threshold):
            return None
        # Softmax-weight the top-k windows and fuse each with the anchor representation.
        topk = min(int(self.cfg.anchor_future_proposal_topk), len(candidates))
        top_scores, top_idx = torch.topk(scores, k=topk)
        top_weights = torch.softmax(
            top_scores / max(float(self.cfg.anchor_future_proposal_temperature), 1e-6),
            dim=0,
        )
        top_repr = torch.stack([candidates[int(idx.item())].repr for idx in top_idx], dim=0)
        anchor_repr = anchor.repr.unsqueeze(0).expand_as(top_repr)
        fusion_in = torch.cat(
            [anchor_repr, top_repr, top_repr - anchor_repr, top_repr * anchor_repr],
            dim=-1,
        )
        fused_repr = top_repr + float(self.cfg.anchor_future_proposal_residual_scale) * self.repr_delta(fusion_in)
        proposal_repr = (top_weights.unsqueeze(-1) * fused_repr).sum(dim=0)
        best_candidate = candidates[int(best_idx.item())]
        return {
            "repr": proposal_repr,
            "proposal_type": "future_window_head",
            "proposal_score": float(best_score.item()),
            "proposal_score_tensor": best_score,
            "proposal_span": (best_candidate.start, best_candidate.end),
            "proposal_root_token": best_candidate.root_token,
            "proposal_candidate_count": len(candidates),
        }
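

# Minimal usage sketch (an assumption-laden smoke test, not part of the
# original module). ModelConfig's true fields live in src.model.config; the
# SimpleNamespace below only stands in the attributes this head reads, with
# illustrative values, and the anchor stub mimics the duck-typed anchor
# object that `propose` expects.
if __name__ == "__main__":
    from types import SimpleNamespace

    torch.manual_seed(0)
    cfg = SimpleNamespace(
        d_model=32,
        anchor_future_proposal_hidden=64,
        anchor_future_proposal_horizon_scale=1.0,
        anchor_future_proposal_span_scale=2.0,
        anchor_future_proposal_max_horizon=16,
        anchor_future_proposal_max_windows=8,
        anchor_future_proposal_temperature=1.0,
        anchor_future_proposal_threshold=0.0,  # accept any best score for the demo
        anchor_future_proposal_topk=3,
        anchor_future_proposal_residual_scale=0.1,
    )
    seq_hidden = torch.randn(32, cfg.d_model)
    seq_ids = torch.randint(0, 100, (32,))
    anchor = SimpleNamespace(
        start_idx=2,
        end_idx=5,
        ttl=8,
        repr=seq_hidden[2:6].mean(dim=0),
        contradiction_pressure=0.8,  # high pressure so repair_readiness clears its gate
        viability=0.2,
        descendant_coherence=0.5,
    )
    head = FutureProposalHead(cfg)
    proposal = head.propose(seq_hidden, seq_ids, anchor)
    # `propose` may legitimately return None when every window is gated out.
    if proposal is None:
        print("no proposal")
    else:
        print(proposal["proposal_type"], proposal["proposal_score"], proposal["proposal_span"])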