Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| from typing import Any | |
| import torch | |
# Short function words that carry no signal on their own. A hint consisting of
# at most two of these words is rejected by is_informative_hint_text.
_STOPWORD_HINTS = {
    "a",
    "an",
    "the",
    "and",
    "or",
    "to",
    "of",
    "in",
    "on",
    "by",
    "for",
    "that",
    "same",
}


def is_informative_hint_text(text: str) -> bool:
    """Return whether *text* is worth surfacing as a hint.

    Text is rejected when it is blank, contains no alphanumeric character, or
    consists of at most two words that are all in ``_STOPWORD_HINTS``.

    Before words are compared with the stopword set:

    * hyphens and the two-character escape ``\\n`` (which
      ``decode_span_text`` emits in place of a newline) are treated as word
      separators, and
    * each word is reduced to its alphanumeric characters, so punctuation
      glued to a stopword (``"the,"``, ``"(of"``) does not make it count as
      informative.

    Args:
        text: Candidate hint text, typically a decoded token span.

    Returns:
        ``True`` if the text should be kept as a hint, ``False`` otherwise.
    """
    cleaned = text.strip().lower()
    if not any(char.isalnum() for char in cleaned):
        return False
    # An escaped newline is whitespace, not the word "n".
    normalized = cleaned.replace("\\n", " ").replace("-", " ")
    words = ["".join(filter(str.isalnum, word)) for word in normalized.split()]
    words = [word for word in words if word]
    if not words:
        return False
    return not (len(words) <= 2 and all(word in _STOPWORD_HINTS for word in words))
def decode_span_text(tokenizer: Any, token_ids: list[int]) -> str:
    """Render a sequence of token ids as a single-line string.

    Without a tokenizer, the ids themselves are joined with single spaces.
    Otherwise ``tokenizer.decode`` is called with ``skip_special_tokens=False``
    so special tokens stay visible. If that call raises ``TypeError`` (e.g. the
    decoder does not accept the keyword), it is retried with the ids alone.
    Newlines in the decoded text are replaced by the two characters ``\\n`` so
    the result stays on one line.
    """
    if tokenizer is not None:
        try:
            decoded = tokenizer.decode(token_ids, skip_special_tokens=False)
        except TypeError:
            # Decoder without the keyword argument: fall back to ids only.
            decoded = tokenizer.decode(token_ids)
        return decoded.replace("\n", "\\n")
    return " ".join(map(str, token_ids))
def safe_decode_token(tokenizer: Any, token_id: int) -> str:
    """Render a single token id as a one-line string.

    With no tokenizer, this returns ``str(token_id)``. Otherwise the id is
    decoded via ``tokenizer.decode([token_id], skip_special_tokens=False)``.
    If that raises ``TypeError``, it is retried as
    ``tokenizer.decode([token_id])``. Newlines in the output are replaced by
    the two characters ``\\n``.
    """
    if tokenizer is not None:
        try:
            piece = tokenizer.decode([token_id], skip_special_tokens=False)
        except TypeError:
            # Decoder without the keyword argument: fall back to ids only.
            piece = tokenizer.decode([token_id])
        return piece.replace("\n", "\\n")
    return str(token_id)
def spans_overlap(span_a: dict[str, Any], span_b: dict[str, Any]) -> bool:
    """Return whether two inclusive ``[start, end]`` spans share an index.

    The ``"start"`` and ``"end"`` values are coerced with ``int()``. Touching
    endpoints (``span_a["end"] == span_b["start"]``) count as overlapping.
    """
    if int(span_a["end"]) < int(span_b["start"]):
        # span_a ends before span_b begins.
        return False
    return int(span_b["end"]) >= int(span_a["start"])
def _group_consecutive(indices: list[int]) -> list[tuple[int, int]]:
    """Collapse ascending indices into inclusive ``(start, end)`` runs of consecutive values."""
    runs: list[tuple[int, int]] = []
    for idx in indices:
        if runs and idx == runs[-1][1] + 1:
            runs[-1] = (runs[-1][0], idx)
        else:
            runs.append((idx, idx))
    return runs


def extract_high_influence_spans(
    scores: torch.Tensor,
    input_ids: torch.Tensor,
    tokenizer: Any,
    min_score: float,
    top_spans: int,
) -> list[dict[str, Any]]:
    """Group consecutive above-threshold positions into spans and rank them.

    Every index whose score is ``>= min_score`` is selected. Runs of
    consecutive selected indices are merged into inclusive ``(start, end)``
    spans. Each span records its bounds, length, mean and max score, the
    token ids at those positions, and their decoded text. Spans are sorted by
    ``(mean_score, length, max_score)`` in descending order, and the first
    ``top_spans`` are returned.

    Args:
        scores: Per-position scores. ``scores.tolist()`` must produce a flat
            list of numbers, i.e. a 1-D tensor. It is presumably aligned with
            ``input_ids``; this is not checked here.
        input_ids: Token ids indexed by the same positions as ``scores``.
        tokenizer: Object with a ``decode`` method, or ``None`` to render
            spans as space-separated ids.
        min_score: Inclusive score threshold for selecting a position.
        top_spans: Maximum number of spans to return. Non-positive values
            return an empty list.

    Returns:
        Up to ``top_spans`` span dicts with keys ``start``, ``end``,
        ``length``, ``mean_score``, ``max_score``, ``token_ids`` and ``text``.
    """
    limit = int(top_spans)
    if limit <= 0:
        # Guard explicitly: a negative slice bound would drop spans from the end.
        return []

    threshold = float(min_score)
    selected = [idx for idx, value in enumerate(scores.tolist()) if float(value) >= threshold]
    if not selected:
        return []

    # torch.mean is undefined for integer/bool tensors; promote those so
    # integer-valued scores do not crash. Floating tensors are used unchanged.
    score_tensor = scores if scores.is_floating_point() else scores.float()

    ranked: list[dict[str, Any]] = []
    for start_idx, end_idx in _group_consecutive(selected):
        span_scores = score_tensor[start_idx : end_idx + 1]
        token_ids = [int(token) for token in input_ids[start_idx : end_idx + 1].tolist()]
        ranked.append(
            {
                "start": int(start_idx),
                "end": int(end_idx),
                "length": int(end_idx - start_idx + 1),
                "mean_score": float(span_scores.mean().item()),
                "max_score": float(span_scores.max().item()),
                "token_ids": token_ids,
                "text": decode_span_text(tokenizer, token_ids),
            }
        )
    ranked.sort(key=lambda item: (item["mean_score"], item["length"], item["max_score"]), reverse=True)
    return ranked[:limit]
def compute_span_anchor_overlap(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> dict[str, float]:
    """Measure how much the future spans and the active anchor spans overlap.

    ``future_span_overlap_ratio`` is the fraction of future spans that overlap
    (per ``spans_overlap``) at least one anchor span.
    ``anchor_span_overlap_ratio`` is the fraction of anchor spans that overlap
    at least one future span. Both ratios are ``0.0`` when ``future_spans`` is
    empty. The anchor ratio is also ``0.0`` when there are no anchor spans.
    """
    future_ratio = 0.0
    anchor_ratio = 0.0
    if future_spans:
        future_hits = 0
        for span in future_spans:
            if any(spans_overlap(span, anchor) for anchor in active_anchor_spans):
                future_hits += 1
        future_ratio = future_hits / len(future_spans)

        if active_anchor_spans:
            anchor_hits = 0
            for anchor in active_anchor_spans:
                if any(spans_overlap(anchor, span) for span in future_spans):
                    anchor_hits += 1
            anchor_ratio = anchor_hits / len(active_anchor_spans)

    return {
        "future_span_overlap_ratio": future_ratio,
        "anchor_span_overlap_ratio": anchor_ratio,
    }
def build_future_hint_candidates(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> list[dict[str, Any]]:
    """Select future spans that could serve as hints, best first.

    A span is kept only if it overlaps none of ``active_anchor_spans`` and its
    text passes ``is_informative_hint_text``. Each kept span is copied into a
    new dict holding ``start``, ``end``, ``text``, ``mean_score``,
    ``max_score`` and ``length``. The result is sorted by
    ``(mean_score, length, max_score)`` in descending order.
    """

    def _clear_of_anchors(span: dict[str, Any]) -> bool:
        return not any(spans_overlap(span, anchor_span) for anchor_span in active_anchor_spans)

    candidates = [
        {
            "start": int(span["start"]),
            "end": int(span["end"]),
            "text": span["text"],
            "mean_score": float(span["mean_score"]),
            "max_score": float(span["max_score"]),
            "length": int(span["length"]),
        }
        for span in future_spans
        if _clear_of_anchors(span) and is_informative_hint_text(str(span["text"]))
    ]
    return sorted(
        candidates,
        key=lambda hint: (hint["mean_score"], hint["length"], hint["max_score"]),
        reverse=True,
    )
def build_auxiliary_future_proposals(
    hidden: torch.Tensor,
    input_ids: torch.Tensor,
    future_hint_candidates: list[dict[str, Any]],
    tokenizer: Any,
    max_candidates: int = 3,
) -> list[dict[str, Any]]:
    """Turn the leading hint candidates into mean-pooled proposals.

    For each of the first ``max_candidates`` hints, ``start`` and ``end`` are
    clamped into ``[0, hidden.size(0) - 1]`` (with ``end >= start``). The
    rows of ``hidden`` in that inclusive range are averaged along dim 0 and
    detached, and a proposal dict is built from the hint's ``mean_score``,
    the clamped span, the span's last token id and its decoded text.

    Args:
        hidden: Tensor indexed by token position along dim 0 (presumably
            ``(seq_len, hidden_dim)`` — TODO confirm with callers).
        input_ids: Token ids indexed by the same positions as ``hidden``.
        future_hint_candidates: Hint dicts with at least ``start``, ``end``
            and ``mean_score``, in priority order.
        tokenizer: Object with a ``decode`` method, or ``None``.
        max_candidates: Maximum number of hints to convert. Non-positive
            values yield no proposals.

    Returns:
        One proposal dict per processed hint. The list is empty when
        ``hidden`` has no positions along dim 0.
    """
    seq_len = int(hidden.size(0))
    limit = int(max_candidates)
    if seq_len == 0 or limit <= 0:
        # With no positions, clamping yields an empty slice whose mean is NaN.
        # A negative limit would otherwise slice hints off the end.
        return []

    last_index = seq_len - 1
    proposals: list[dict[str, Any]] = []
    for hint in future_hint_candidates[:limit]:
        start = max(0, min(int(hint["start"]), last_index))
        end = max(start, min(int(hint["end"]), last_index))
        span_ids = [int(token) for token in input_ids[start : end + 1].tolist()]
        proposals.append(
            {
                "proposal_type": "future_hint_span",
                "proposal_score": float(hint["mean_score"]),
                "proposal_span": (start, end),
                # input_ids may be shorter than hidden, so span_ids can be empty.
                "proposal_root_token": span_ids[-1] if span_ids else None,
                "proposal_text": decode_span_text(tokenizer, span_ids),
                "repr": hidden[start : end + 1].mean(dim=0).detach(),
            }
        )
    return proposals
def summarize_auxiliary_proposals(
    proposal_batches: list[list[dict[str, Any]]],
) -> dict[str, float]:
    """Aggregate per-batch proposal lists into summary statistics.

    Returns the following keys:

    * the total number of proposals,
    * how many batches had at least one proposal,
    * the mean number of proposals per batch,
    * the mean and max of every proposal's ``proposal_score``.

    Each mean or max falls back to ``0.0`` when there is nothing to average.
    """
    per_batch = list(map(len, proposal_batches))
    total = sum(per_batch)
    scores = [float(proposal["proposal_score"]) for batch in proposal_batches for proposal in batch]
    non_empty_batches = len([count for count in per_batch if count > 0])

    if scores:
        mean_score = float(sum(scores) / len(scores))
        peak_score = max(scores)
    else:
        mean_score = 0.0
        peak_score = 0.0

    return {
        "proposal_count": int(total),
        "batch_with_proposals_count": int(non_empty_batches),
        "mean_proposal_count_per_batch": float(total / max(len(per_batch), 1)),
        "mean_proposal_score": mean_score,
        "max_proposal_score": peak_score,
    }