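"""Core Reframr model.

Combines a trained tokenizer, PPMI embeddings, and analytical memory units with
a ridge-regression readout, then blends that readout with answer-memory,
associative, transition, copy, source-evidence, and preference priors during
decoding. numpy is optional; the *_array fields cache vectorized mirrors of the
pure-Python structures for the fast decode path.
"""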
import json
import hashlib
import random
import site
import string
import sys
import unicodedata
from dataclasses import dataclass, field
from pathlib import Path
from collections.abc import Sequence
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
if _vendor_path.exists():
vendor_text = str(_vendor_path)
if vendor_text not in sys.path:
sys.path.insert(0, vendor_text)
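# numpy is optional. Prefer a vendored copy, then the user site-packages, and
# fall back to np = None (pure-Python decoding) if the import fails or the
# module lacks asarray (e.g. a stub package).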
try:
import numpy as np
except ModuleNotFoundError:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.append(user_site)
try:
import numpy as np
except ModuleNotFoundError:
np = None
if np is not None and not hasattr(np, "asarray"):
np = None
from .checkpoint import read_safetensor_file, write_safetensor_file
from .config import ReframrConfig
from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens
from .hippo import AnalyticalMemoryUnit, analytical_embedding_drive, analytical_embedding_drive_fast
from .linalg import Vector, dot, mean, norm, softmax, zeros_vector
from .reservoir import apply_readout, ridge_regression_readout
from .reasoning import TOOL_PROTOCOL_TOKENS, reasoning_prefix
from .sparse_context import HashedSparseAttention
from .ternary import apply_ternary_mask, derive_ternary_mask_from_states
from .tokenizer import NativeTokenizer
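# Relative mixing weights for the decode-time priors (readout "base",
# associative, transition, copy, ...); the FAST_* names appear to be the tuning
# used on the numpy fast path.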
ASSOCIATIVE_BLEND = 0.42
TRANSITION_BLEND = 0.08
COPY_BLEND = 0.04
BASE_BLEND = 0.34
FAST_ASSOCIATIVE_BLEND = 0.06
FAST_TRANSITION_BLEND = 0.14
FAST_COPY_BLEND = 0.12
FAST_BASE_BLEND = 0.72
FAST_PREFERENCE_BLEND = 0.15
FAST_ANSWER_BLEND = 0.16
FAST_SOURCE_EVIDENCE_BLEND = 0.52
PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48
PROMPT_START_READOUT_CONFIDENCE_FLOOR = 0.45
ASSOCIATIVE_TOP_K = 12
ANSWER_TOP_K = 48
ANSWER_START_TOP_K = 32
MIN_COMPLETE_ANSWER_WORDS = 6
MIN_COMPLETE_MULTI_SENTENCE_WORDS = 4
ANSWER_SEQUENCE_MATCH_FLOOR = 0.27
ANSWER_START_CONFIDENCE_FLOOR = 0.45
ANSWER_START_MATCH_SUPPORT_FLOOR = 0.18
ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45
ANSWER_SEQUENCE_LOCK_FLOOR = 0.55
ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80
READOUT_LOGIT_ZSCORE_SCALE = 0.22
TRACE_IDENTITY_SCALE = 0.78
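# Multiplier/increment constants borrowed from well-known linear congruential
# generators (glibc rand, Numerical Recipes, MSVC, MINSTD, Park-Miller, ...),
# reused here as cheap deterministic hash mixers for trace identities.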
TRACE_IDENTITY_HASHES = (
(1103515245, 12345, 214013, 2531011),
(1664525, 1013904223, 22695477, 1),
(69069, 362437, 134775813, 17),
(134775813, 97, 1103515245, 31),
(22695477, 911, 1664525, 73),
(214013, 2531011, 69069, 19),
(48271, 0, 69621, 11),
(16807, 37, 40692, 101),
(279470273, 173, 1299709, 53),
(39916801, 29, 2147483629, 7),
)
PROMPT_ENVELOPE_TERMS = frozenset(
{"system", "instruction", "user", "human", "assistant", "question", "answer"}
)
NGRAM_KEY_SEPARATOR = "\u0001"
TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1)
DEFAULT_GENERATION_TEMPERATURE = 0.82
DEFAULT_GENERATION_TOP_K = 24
DEFAULT_GENERATION_TOP_P = 0.92
DEFAULT_REPETITION_PENALTY = 1.18
ANSWER_SEQUENCE_MAX_TOKENS = 192
ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT = 8192
ANSWER_SEQUENCE_VARIATION_TEMPERATURE = 0.65
ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT = 4
ANSWER_SEQUENCE_CREATIVE_TEMPERATURE = 1.10
ANSWER_REPLAY_PREFIX_TEMPERATURE = 0.95
ANSWER_REPLAY_PREFIX_MIN_TOKENS = 64
ANSWER_REPLAY_PREFIX_PENALTY = 0.18
CREATIVE_EARLY_POOL_TEMPERATURE = 1.05
CREATIVE_EARLY_POOL_WORD_LIMIT = 6
CREATIVE_EARLY_POOL_MAX = 8
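# Prompt terms that signal a request for fresh or external information and so
# favor emitting a tool call.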
TOOL_CALL_CONTEXT_TERMS = frozenset(
{
"current",
"latest",
"today",
"yesterday",
"tonight",
"now",
"fresh",
"recent",
"web",
"search",
"real-time",
"price",
"weather",
"election",
"news",
"official",
"result",
"live",
}
)
RUNTIME_GENERATION_HISTORY_LIMIT = 8
AVOID_SEQUENCE_MIN_TOKENS = 6
WORD_COMPLETION_OVERFLOW_TOKENS = 16
ANSWER_FINGERPRINT_WORDS = 4
SPARSE_CONTEXT_MIN_TOKENS = 16
SPARSE_CONTEXT_TOP_K = 64
SPARSE_CONTEXT_HASH_BITS = 12
SPARSE_CONTEXT_PROBE_RADIUS = 1
SPARSE_CONTEXT_CANDIDATE_MULTIPLIER = 16
SPARSE_CONTEXT_TRACE_BLEND = 0.35
RUNTIME_ARRAY_DTYPE = np.float32 if np is not None else None
@dataclass(frozen=True, slots=True)
class CharacterCountFact:
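    """Parsed character-counting request; `focused` is False when the prompt
    carries extra content words beyond the character and the target word."""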
character: str
word: str
count: int
surface_seed: int
focused: bool
@dataclass(frozen=True, slots=True)
class GenerationTokenMeta:
rendered: str
stripped: str
starts_new_word: bool
punctuation_piece: bool
structural_punctuation: bool
structural_symbol: bool
word_joiner: bool
alphanumeric: str
common_connector: bool
def _normalize_vector(values: Vector) -> Vector:
total = sum(values)
if total <= 0.0:
return [0.0 for _ in values]
return [value / total for value in values]
def _encode_ngram_key(tokens: tuple[str, ...]) -> str:
return NGRAM_KEY_SEPARATOR.join(tokens)
def _decode_ngram_key(key: str) -> tuple[str, ...]:
return tuple(part for part in key.split(NGRAM_KEY_SEPARATOR) if part)
def _last_index(values: list[str], target: str) -> int | None:
for index in range(len(values) - 1, -1, -1):
if values[index] == target:
return index
return None
def _first_index(values: list[str], target: str) -> int | None:
for index, value in enumerate(values):
if value == target:
return index
return None
@dataclass(slots=True)
class DecodeState:
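    """Per-decode cache: running hidden states and context traces plus lazily
    computed answer-memory match lists and prompt-readout priors."""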
hidden_states: list[Vector]
context_traces: list[Vector]
combined_state: Vector
context_tokens: list[str]
answer_anchor_state: Vector | None = None
answer_matches: list[tuple[float, int, int]] | None = None
answer_start_matches: list[tuple[float, int, int]] | None = None
answer_sequence_matches: list[tuple[float, int, int]] | None = None
prompt_answer_prior: object | None = None
prompt_answer_start_prior: object | None = None
@dataclass(slots=True)
class ReframrModel:
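    """Reframr language model.

    The list-backed parameters (keys, values, readout weights) are the source
    of truth; the matching *_array fields are optional numpy caches rebuilt by
    _refresh_numeric_caches() and consulted only on the fast decode path.
    """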
config: ReframrConfig
tokenizer: NativeTokenizer | None = None
embedding_model: EmbeddingModel | None = None
memory_units: list[AnalyticalMemoryUnit] | None = None
ternary_scale: float = 1.0
ternary_mask: list[int] | None = None
ternary_mask_array: object | None = None
readout_weights: list[list[float]] | None = None
readout_weights_array: object | None = None
readout_bias: Vector | None = None
readout_bias_array: object | None = None
prompt_answer_weights: list[list[float]] | None = None
prompt_answer_weights_array: object | None = None
prompt_answer_bias: Vector | None = None
prompt_answer_bias_array: object | None = None
prompt_answer_start_weights: list[list[float]] | None = None
prompt_answer_start_weights_array: object | None = None
prompt_answer_start_bias: Vector | None = None
prompt_answer_start_bias_array: object | None = None
trace_token_weights: Vector | None = None
trace_token_weights_array: object | None = None
trace_embedding_table_array: object | None = None
preference_bias: Vector | None = None
preference_bias_array: object | None = None
preference_valid_mask_array: object | None = None
state_offset: Vector | None = None
state_offset_array: object | None = None
associative_keys: list[Vector] | None = None
associative_keys_array: object | None = None
associative_key_norms: list[float] | None = None
associative_key_norms_array: object | None = None
associative_values: list[int] | None = None
associative_values_array: object | None = None
associative_valid_mask_array: object | None = None
answer_keys: list[Vector] | None = None
answer_keys_array: object | None = None
answer_key_norms: list[float] | None = None
answer_key_norms_array: object | None = None
answer_similarity_keys_array: object | None = None
answer_similarity_key_norms_array: object | None = None
answer_similarity_mask_array: object | None = None
answer_values: list[int] | None = None
answer_values_array: object | None = None
answer_valid_mask_array: object | None = None
answer_start_keys: list[Vector] | None = None
answer_start_keys_array: object | None = None
answer_start_key_norms: list[float] | None = None
answer_start_key_norms_array: object | None = None
answer_start_similarity_keys_array: object | None = None
answer_start_similarity_key_norms_array: object | None = None
answer_start_values: list[int] | None = None
answer_start_values_array: object | None = None
answer_start_valid_mask_array: object | None = None
answer_sequence_keys: list[Vector] | None = None
answer_sequence_keys_array: object | None = None
answer_sequence_key_norms: list[float] | None = None
answer_sequence_key_norms_array: object | None = None
answer_sequence_similarity_keys_array: object | None = None
answer_sequence_similarity_key_norms_array: object | None = None
answer_sequence_prompt_tokens: list[list[int]] | None = None
answer_sequence_prompt_tokens_array: object | None = None
answer_sequence_tokens: list[list[int]] | None = None
answer_sequence_tokens_array: object | None = None
answer_sequence_token_id_rows: list[list[int]] | None = None
answer_sequence_prompt_weight_maps: list[dict[int, float]] | None = None
answer_sequence_prompt_weight_norms: list[float] | None = None
answer_sequence_prompt_bigram_sets: list[set[tuple[int, int]]] | None = None
answer_sequence_prompt_trigram_sets: list[set[tuple[int, int, int]]] | None = None
answer_sequence_prompt_number_sets: list[set[str]] | None = None
answer_sequence_prompt_inverted_index: dict[int, list[int]] | None = None
answer_sequence_prompt_specificity: dict[int, float] | None = None
prompt_overlap_valid_token_mask_array: object | None = None
answer_fingerprint_hashes: set[tuple[int, ...]] | None = None
answer_fingerprint_token_lengths: set[int] | None = None
answer_fingerprint_token_sequences_by_length: dict[int, set[tuple[int, ...]]] | None = None
answer_sequence_prefixes_by_length: dict[int, set[tuple[int, ...]]] | None = None
transition_tables: dict[int, dict[tuple[str, ...], dict[str, float]]] | None = None
transition_id_tables: dict[int, dict[tuple[int, ...], tuple[object, object]]] | None = None
transition_tensor_cache: dict[str, object] | None = None
transition_built_orders: set[int] | None = None
generation_token_meta_cache: dict[str, GenerationTokenMeta] | None = None
runtime_generation_history: dict[str, list[str]] = field(default_factory=dict, repr=False)
def fit(self, text: str) -> "ReframrModel":
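        """Train the tokenizer, embeddings, analytical memories, ridge readout,
        transition tables, and answer memories from raw text, then refresh the
        numpy caches. Raises ValueError if the text yields fewer than two tokens."""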
self.generation_token_meta_cache = None
self.answer_sequence_prefixes_by_length = None
self.tokenizer = NativeTokenizer.train(
text,
vocab_size=self.config.tokenizer_vocab_size,
min_pair_frequency=self.config.tokenizer_min_pair_frequency,
lowercase=self.config.lowercase,
)
tokens = self.tokenizer.encode(text)
if len(tokens) < 2:
raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.")
self.embedding_model = fit_ppmi_embedding_from_tokens(
tokens,
embedding_dim=self.config.embedding_dim,
window_size=self.config.window_size,
min_frequency=self.config.min_frequency,
max_vocab=self.config.max_vocab,
required_tokens=self.tokenizer.vocab,
)
self.memory_units = [
AnalyticalMemoryUnit(self.config.state_dim, timescale)
for timescale in self.config.timescales
]
token_counts: dict[str, float] = {}
for token in tokens:
token_counts[token] = token_counts.get(token, 0.0) + 1.0
self.trace_token_weights = self._derive_trace_token_weights_from_counts(token_counts)
raw_states, targets, target_ids = self._collect_training_examples(tokens)
self.ternary_scale, self.ternary_mask = derive_ternary_mask_from_states(raw_states)
analytical_states = [
apply_ternary_mask(state, self.ternary_mask, self.ternary_scale)
for state in raw_states
]
self.associative_keys = [state[:] for state in analytical_states]
self.associative_key_norms = [norm(state) for state in analytical_states]
self.associative_values = target_ids[:]
self.answer_keys = []
self.answer_key_norms = []
self.answer_values = []
self.answer_start_keys = []
self.answer_start_key_norms = []
self.answer_start_values = []
self.answer_sequence_keys = []
self.answer_sequence_key_norms = []
self.answer_sequence_prompt_tokens = []
self.answer_sequence_tokens = []
self.prompt_answer_weights = []
self.prompt_answer_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.prompt_answer_start_weights = []
self.prompt_answer_start_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.transition_tables = self._build_transition_tables(tokens)
self._fit_answer_memory_from_text(text)
self._refresh_answer_fingerprint_hashes()
self.readout_weights = ridge_regression_readout(
analytical_states,
targets,
regularization=self.config.regularization,
)
self.readout_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.preference_bias = [0.0 for _ in self.embedding_model.id_to_token]
self.state_offset = [0.0 for _ in analytical_states[0]] if analytical_states else []
self._refresh_numeric_caches()
return self
def _fit_answer_memory_from_text(self, text: str) -> None:
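        """Harvest '<answer>'-delimited lines of the training text into the
        answer, answer-start, and answer-sequence memories; token-id rows are
        right-padded with -1 up to ANSWER_SEQUENCE_MAX_TOKENS."""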
assert self.tokenizer is not None
assert self.embedding_model is not None
if (
self.answer_keys is None
or self.answer_key_norms is None
or self.answer_values is None
or self.answer_start_keys is None
or self.answer_start_key_norms is None
or self.answer_start_values is None
or self.answer_sequence_keys is None
or self.answer_sequence_key_norms is None
or self.answer_sequence_prompt_tokens is None
or self.answer_sequence_tokens is None
):
return
for line in text.splitlines():
if "<answer>" not in line:
continue
prompt_text, answer_text = line.split("<answer>", 1)
prompt_text = prompt_text.strip()
answer_text = answer_text.strip()
if not prompt_text or not answer_text:
continue
prompt_tokens = self.tokenizer.encode(prompt_text) + ["<answer>"]
answer_tokens = [
token
for token in self.tokenizer.encode(answer_text)
if token in self.embedding_model.token_to_id
and (
token not in self.tokenizer.special_tokens
or token in TOOL_PROTOCOL_TOKENS
)
]
if not prompt_tokens or not answer_tokens:
continue
key = self._encode_context(prompt_tokens)
key_norm = norm(key)
if key_norm <= 0.0:
continue
answer_ids = [
self.embedding_model.token_to_id[token]
for token in answer_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
]
prompt_ids = [
self.embedding_model.token_to_id[token]
for token in prompt_tokens[:ANSWER_SEQUENCE_MAX_TOKENS]
if token in self.embedding_model.token_to_id
and (
token not in self.tokenizer.special_tokens
or token in TOOL_PROTOCOL_TOKENS
)
]
if not answer_ids:
continue
self.answer_keys.append(key[:])
self.answer_key_norms.append(key_norm)
self.answer_values.append(answer_ids[0])
self.answer_start_keys.append(key[:])
self.answer_start_key_norms.append(key_norm)
self.answer_start_values.append(answer_ids[0])
self.answer_sequence_keys.append(key[:])
self.answer_sequence_key_norms.append(key_norm)
self.answer_sequence_prompt_tokens.append(
prompt_ids
+ [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(prompt_ids))]
)
self.answer_sequence_tokens.append(
answer_ids
+ [-1 for _ in range(ANSWER_SEQUENCE_MAX_TOKENS - len(answer_ids))]
)
def predict_next_distribution(
self,
context: str,
*,
reasoning_mode: str | None = None,
) -> dict[str, float]:
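        """Predict the next token and fold sub-token probabilities together by
        rendered surface form."""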
self._require_fit()
assert self.tokenizer is not None
assert self.embedding_model is not None
probabilities = self.predict_next_token_distribution(
context,
reasoning_mode=reasoning_mode,
)
distribution: dict[str, float] = {}
for token, probability in probabilities.items():
rendered = self._render_token(token)
distribution[rendered] = distribution.get(rendered, 0.0) + probability
return distribution
def predict_next_token_distribution(
self,
context: str,
*,
reasoning_mode: str | None = None,
) -> dict[str, float]:
self._require_fit()
assert self.tokenizer is not None
assert self.embedding_model is not None
assert self.readout_weights is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
return self._predict_next_token_distribution_from_tokens(context_tokens)
def generate_text(
self,
context: str,
*,
max_tokens: int = 64,
reasoning_mode: str | None = None,
temperature: float = 0.0,
top_k: int = DEFAULT_GENERATION_TOP_K,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
avoid_texts: Sequence[str] | None = None,
) -> str:
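        """Generate a continuation of `context`.

        Character-counting prompts are answered directly. Otherwise decoding
        uses the numpy fast path when the readout arrays are cached and the
        vocabulary has at least 1024 tokens, falling back to the exact
        pure-Python loop below.
        """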
character_count_response = self._character_count_response(
context,
temperature=temperature,
)
if character_count_response is not None:
return character_count_response
self._require_fit()
self._ensure_numeric_caches()
assert self.tokenizer is not None
runtime_avoid_texts = self._runtime_avoid_texts(
context,
avoid_texts,
temperature=temperature,
)
avoid_token_sequences = self._avoid_text_token_sequences(runtime_avoid_texts)
if (
np is not None
and self.readout_weights_array is not None
and self.embedding_model is not None
and len(self.embedding_model.id_to_token) >= 1024
):
generated_text = self._generate_text_fast(
context,
max_tokens=max_tokens,
reasoning_mode=reasoning_mode,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
avoid_token_sequences=avoid_token_sequences,
)
self._remember_runtime_generation(
context,
generated_text,
temperature=temperature,
)
return generated_text
active_mode = reasoning_mode or self.config.default_reasoning_profile
_, context_tokens = self._generation_prompt_tokens(context, active_mode)
decode_state = self._build_decode_state(context_tokens)
generated_tokens: list[str] = []
for _ in range(max_tokens):
distribution, _ = self._score_next_token_from_state(
decode_state,
include_trace=False,
generated_tokens=generated_tokens,
temperature=temperature,
avoid_token_sequences=avoid_token_sequences,
)
forced_source_token = self._source_evidence_next_token(
decode_state.context_tokens,
generated_tokens,
)
next_token = forced_source_token or self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
avoid_token_sequences=avoid_token_sequences,
preserve_dominant_candidates=(
self._answer_decode_has_continuation(decode_state, generated_tokens)
or self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
),
)
if not next_token:
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
if self._should_stop_answer_sequence(decode_state, generated_tokens):
break
if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
break
if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens):
break
if (
self._should_stop_generation(generated_tokens)
and not self._answer_decode_has_continuation(decode_state, generated_tokens)
and not self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
):
break
overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens)
while generated_tokens and overflow_budget > 0:
has_answer_continuation = self._answer_decode_has_continuation(
decode_state,
generated_tokens,
)
has_source_continuation = self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
if (
self._starts_new_word(generated_tokens[-1])
and not has_answer_continuation
and not has_source_continuation
):
break
distribution, _ = self._score_next_token_from_state(
decode_state,
include_trace=False,
generated_tokens=generated_tokens,
temperature=temperature,
avoid_token_sequences=avoid_token_sequences,
)
forced_source_token = self._source_evidence_next_token(
decode_state.context_tokens,
generated_tokens,
)
next_token = forced_source_token or self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
avoid_token_sequences=avoid_token_sequences,
preserve_dominant_candidates=has_answer_continuation
or has_source_continuation,
)
if not next_token:
break
if (
self._starts_new_word(next_token)
and not has_answer_continuation
and not has_source_continuation
):
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
overflow_budget -= 1
generated_text = self._finalize_generated_text(
self._normalize_generated_tool_protocol_text(
self._decode_tokens(generated_tokens),
context=context,
)
)
self._remember_runtime_generation(
context,
generated_text,
temperature=temperature,
)
return generated_text
@staticmethod
def _character_count_fact(context: str) -> CharacterCountFact | None:
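        """Detect character-counting prompts, e.g. 'how many r letters are in strawberry'.

        Requires a counting term plus either a character/letter unit or the
        word 'count', a single-character target, and a target word; returns
        None when any piece is missing."""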
normalized = unicodedata.normalize("NFKC", context).strip()
tokens = ReframrModel._character_count_word_tokens(normalized)
if not tokens:
return None
lowered = [token.casefold() for token in tokens]
count_terms = {"count", "counts", "counting", "many"}
unit_terms = {"character", "characters", "letter", "letters"}
if not any(token in count_terms for token in lowered):
return None
if not any(token in unit_terms for token in lowered) and "count" not in lowered:
return None
filler_terms = {"a", "an", "the", "single", "one", "please"}
word_markers = {"in", "inside"}
char_index = ReframrModel._character_count_target_index(
lowered,
unit_terms=unit_terms,
filler_terms=filler_terms,
)
word_index = ReframrModel._character_count_word_index(
lowered,
char_index=char_index,
filler_terms=filler_terms,
word_markers=word_markers,
)
if char_index is None or word_index is None:
return None
character = tokens[char_index]
word = tokens[word_index]
if len(character) != 1 or not word:
return None
order_offset = 0 if char_index < word_index else 1
surface_seed = ((char_index + 1) * 7 + (word_index + 1) * 3 + len(tokens) + order_offset) % 4
structural_terms = (
count_terms
| unit_terms
| filler_terms
| word_markers
| {
"for",
"of",
"to",
"how",
"do",
"does",
"there",
"are",
"is",
"appear",
"appears",
"times",
"word",
}
)
extra_content_tokens = [
token
for index, token in enumerate(lowered)
if index not in {char_index, word_index}
and token not in structural_terms
]
return CharacterCountFact(
character=character,
word=word,
count=word.casefold().count(character.casefold()),
surface_seed=surface_seed,
focused=not extra_content_tokens,
)
@staticmethod
def _character_count_word_tokens(text: str) -> list[str]:
tokens: list[str] = []
current: list[str] = []
for character in text:
if character != "_" and character.isalnum():
current.append(character)
continue
if current:
tokens.append("".join(current))
current = []
if current:
tokens.append("".join(current))
return tokens
@staticmethod
def _character_count_target_index(
tokens: list[str],
*,
unit_terms: set[str],
filler_terms: set[str],
) -> int | None:
for index, token in enumerate(tokens):
if token not in unit_terms:
continue
for adjacent in (index - 1, index + 1):
if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1:
return adjacent
before = ReframrModel._nearest_content_index(tokens, index - 1, -1, filler_terms)
after = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
for candidate in (before, after):
if candidate is not None and len(tokens[candidate]) == 1:
return candidate
for index, token in enumerate(tokens):
if token not in {"count", "counts", "counting"}:
continue
candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
if candidate is not None and tokens[candidate] in unit_terms:
candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms)
if candidate is not None and len(tokens[candidate]) == 1:
return candidate
return None
@staticmethod
def _character_count_word_index(
tokens: list[str],
*,
char_index: int | None,
filler_terms: set[str],
word_markers: set[str],
) -> int | None:
for index, token in enumerate(tokens):
if token != "word":
continue
candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
return candidate
for index, token in enumerate(tokens):
if token not in word_markers:
continue
candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, filler_terms)
if candidate is not None and tokens[candidate] == "word":
candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, filler_terms)
if candidate is not None and candidate != char_index and len(tokens[candidate]) > 1:
return candidate
skipped_terms = {
"how",
"many",
"do",
"does",
"count",
"counts",
"counting",
"letter",
"letters",
"character",
"characters",
"word",
"there",
"are",
"is",
"appear",
"appears",
"times",
} | filler_terms | word_markers
for index in range(len(tokens) - 1, -1, -1):
if index == char_index:
continue
if len(tokens[index]) <= 1 or tokens[index] in skipped_terms:
continue
return index
return None
@staticmethod
def _nearest_content_index(
tokens: list[str],
start: int,
direction: int,
skipped_terms: set[str],
) -> int | None:
index = start
while 0 <= index < len(tokens):
if tokens[index] not in skipped_terms:
return index
index += direction
return None
@classmethod
def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None:
fact = cls._character_count_fact(context)
if fact is None:
return None
if not fact.focused:
return None
return cls._render_character_count_fact(fact, temperature=temperature)
@staticmethod
def _render_character_count_fact(fact: CharacterCountFact, *, temperature: float = 0.0) -> str:
character_label = f"'{fact.character}'"
word_label = f"'{fact.word}'"
character_noun = "character" if fact.count == 1 else "characters"
return f"{word_label} has {fact.count} {character_label} {character_noun}."
@classmethod
def _runtime_source_grounded_response(cls, context: str) -> str | None:
return None
@classmethod
def _runtime_source_records(cls, context: str) -> list[tuple[str, str, str]]:
records: list[tuple[str, str, str]] = []
marker = "<source>"
search_from = 0
while True:
source_start = context.find(marker, search_from)
if source_start < 0:
break
content_start = source_start + len(marker)
content_end = cls._runtime_source_record_end(context, content_start)
raw_record = context[content_start:content_end].strip()
record = cls._parse_runtime_source_record(raw_record)
if record is not None:
records.append(record)
search_from = max(content_end, content_start + 1)
return records
@staticmethod
def _runtime_source_record_end(context: str, start: int) -> int:
boundaries = [
position
for marker in (
"\n",
"<source>",
"<tool_call>",
"<tool_result>",
"<final>",
"<answer>",
"<reason>",
)
if (position := context.find(marker, start)) >= 0
]
return min(boundaries) if boundaries else len(context)
@staticmethod
def _parse_runtime_source_record(raw_record: str) -> tuple[str, str, str] | None:
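        """Parse a 'title | url | snippet' record; with fewer than three pipe
        fields only the last field is kept (as the snippet) and a placeholder
        title and empty url are substituted."""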
if not raw_record:
return None
pieces = [piece.strip() for piece in raw_record.split("|", 2)]
if len(pieces) >= 3:
title, url, snippet = pieces[0], pieces[1], pieces[2]
else:
title, url, snippet = "the provided source", "", pieces[-1]
title = ReframrModel._clean_runtime_source_field(title) or "the provided source"
url = ReframrModel._clean_runtime_source_field(url)
snippet = ReframrModel._clean_runtime_source_field(snippet)
if not snippet:
return None
return title, url, snippet
@staticmethod
def _clean_runtime_source_field(text: str) -> str:
normalized = unicodedata.normalize("NFKC", text)
cleaned = " ".join(normalized.split())
return cleaned.strip(" \t\r\n|")
def _generate_text_fast(
self,
context: str,
*,
max_tokens: int,
reasoning_mode: str | None,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
avoid_token_sequences: Sequence[Sequence[str]] | None = None,
) -> str:
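        """numpy-accelerated mirror of the pure-Python loop in generate_text;
        the guards (answer continuation, source evidence, word-completion
        overflow) are kept behavior-matching."""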
assert self.tokenizer is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
_, context_tokens = self._generation_prompt_tokens(context, active_mode)
decode_state = self._build_decode_state(context_tokens)
generated_tokens: list[str] = []
for _ in range(max_tokens):
probabilities, _ = self._score_next_token_array_from_state(
decode_state,
include_associative=not generated_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
avoid_token_sequences=avoid_token_sequences,
)
forced_source_token = self._source_evidence_next_token(
decode_state.context_tokens,
generated_tokens,
)
next_token = forced_source_token or self._select_generation_token_from_array(
probabilities,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
avoid_token_sequences=avoid_token_sequences,
preserve_dominant_candidates=(
self._answer_decode_has_continuation(decode_state, generated_tokens)
or self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
),
)
if not next_token:
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
if self._should_stop_answer_sequence(decode_state, generated_tokens):
break
if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
break
if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens):
break
if (
self._should_stop_generation(generated_tokens)
and not self._answer_decode_has_continuation(decode_state, generated_tokens)
and not self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
):
break
overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens)
while generated_tokens and overflow_budget > 0:
has_answer_continuation = self._answer_decode_has_continuation(
decode_state,
generated_tokens,
)
has_source_continuation = self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
if (
self._starts_new_word(generated_tokens[-1])
and not has_answer_continuation
and not has_source_continuation
):
break
probabilities, _ = self._score_next_token_array_from_state(
decode_state,
include_associative=False,
generated_tokens=generated_tokens,
temperature=temperature,
avoid_token_sequences=avoid_token_sequences,
)
forced_source_token = self._source_evidence_next_token(
decode_state.context_tokens,
generated_tokens,
)
next_token = forced_source_token or self._select_generation_token_from_array(
probabilities,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
avoid_token_sequences=avoid_token_sequences,
preserve_dominant_candidates=has_answer_continuation
or has_source_continuation,
)
if not next_token:
break
if (
self._starts_new_word(next_token)
and not has_answer_continuation
and not has_source_continuation
):
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
overflow_budget -= 1
return self._finalize_generated_text(
self._normalize_generated_tool_protocol_text(
self._decode_tokens(generated_tokens),
context=context,
)
)
def trace_next_token(
self,
context: str,
*,
reasoning_mode: str | None = None,
top_k: int = 5,
) -> dict[str, object]:
self._require_fit()
assert self.tokenizer is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
context_tokens = reasoning_prefix(active_mode) + self.tokenizer.encode(context)
_, trace = self._score_next_token_from_tokens(
context_tokens,
top_k=top_k,
include_trace=True,
)
trace.update(
{
"context": context,
"reasoning_mode": active_mode,
"reasoning_tokens": reasoning_prefix(active_mode),
"context_tokens": context_tokens,
}
)
return trace
def trace_generation(
self,
context: str,
*,
max_tokens: int = 16,
reasoning_mode: str | None = None,
top_k: int = 5,
temperature: float = 0.0,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> dict[str, object]:
character_count_response = self._character_count_response(
context,
temperature=temperature,
)
if character_count_response is not None:
active_mode = reasoning_mode or self.config.default_reasoning_profile
prompt = context if "<answer>" in context else f"{context} <answer>"
return {
"context": context,
"prompt": prompt,
"reasoning_mode": active_mode,
"reasoning_tokens": reasoning_prefix(active_mode),
"generation_policy": {
"temperature": temperature,
"top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
"top_p": top_p,
"repetition_penalty": repetition_penalty,
},
"prompt_tokens": [],
"generated_tokens": [],
"generated_text": character_count_response,
"generated_token_count": len(character_count_response.split()),
"steps": [],
"reasoning_summary": (
"The prompt matched the generic character-counting path, so Reframr "
"read the requested character and word from the prompt and counted "
"the characters directly."
),
}
self._require_fit()
assert self.tokenizer is not None
active_mode = reasoning_mode or self.config.default_reasoning_profile
prompt, context_tokens = self._generation_prompt_tokens(context, active_mode)
decode_state = self._build_decode_state(context_tokens)
prompt_tokens = decode_state.context_tokens[:]
generated_tokens: list[str] = []
steps: list[dict[str, object]] = []
for step_index in range(1, max_tokens + 1):
distribution, trace = self._score_next_token_from_state(
decode_state,
top_k=top_k,
include_trace=True,
generated_tokens=generated_tokens,
temperature=temperature,
)
forced_source_token = self._source_evidence_next_token(
decode_state.context_tokens,
generated_tokens,
)
next_token = forced_source_token or self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=(
self._answer_decode_has_continuation(decode_state, generated_tokens)
or self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
),
)
if not next_token:
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
trace["step"] = step_index
trace["chosen_token"] = next_token
trace["chosen_text"] = self._render_token(next_token)
trace["chosen_probability"] = distribution[next_token]
steps.append(trace)
if self._should_stop_answer_sequence(decode_state, generated_tokens):
break
if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
break
if self._source_evidence_is_complete(decode_state.context_tokens, generated_tokens):
break
if (
self._should_stop_generation(generated_tokens)
and not self._answer_decode_has_continuation(decode_state, generated_tokens)
and not self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
):
break
overflow_budget = max(WORD_COMPLETION_OVERFLOW_TOKENS, max_tokens)
while generated_tokens and overflow_budget > 0:
has_answer_continuation = self._answer_decode_has_continuation(
decode_state,
generated_tokens,
)
has_source_continuation = self._source_evidence_has_continuation(
decode_state.context_tokens,
generated_tokens,
)
if (
self._starts_new_word(generated_tokens[-1])
and not has_answer_continuation
and not has_source_continuation
):
break
distribution, trace = self._score_next_token_from_state(
decode_state,
top_k=top_k,
include_trace=True,
generated_tokens=generated_tokens,
temperature=temperature,
)
forced_source_token = self._source_evidence_next_token(
decode_state.context_tokens,
generated_tokens,
)
next_token = forced_source_token or self._select_generation_token(
distribution,
context_tokens=decode_state.context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=max(DEFAULT_GENERATION_TOP_K, top_k),
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=has_answer_continuation
or has_source_continuation,
)
if not next_token:
break
if (
self._starts_new_word(next_token)
and not has_answer_continuation
and not has_source_continuation
):
break
generated_tokens.append(next_token)
self._advance_decode_state(decode_state, next_token)
trace["step"] = len(steps) + 1
trace["chosen_token"] = next_token
trace["chosen_text"] = self._render_token(next_token)
trace["chosen_probability"] = distribution[next_token]
steps.append(trace)
if self._should_stop_answer_sequence(decode_state, generated_tokens):
break
if self._should_stop_after_answer_path_drift(decode_state, generated_tokens):
break
overflow_budget -= 1
return {
"context": context,
"prompt": prompt,
"reasoning_mode": active_mode,
"reasoning_tokens": reasoning_prefix(active_mode),
"generation_policy": {
"temperature": temperature,
"top_k": max(DEFAULT_GENERATION_TOP_K, top_k),
"top_p": top_p,
"repetition_penalty": repetition_penalty,
},
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
"generated_text": self._finalize_generated_text(
self._normalize_generated_tool_protocol_text(
self._decode_tokens(generated_tokens),
context=context,
)
),
"generated_token_count": len(generated_tokens),
"steps": steps,
}
def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]:
assert self.tokenizer is not None
prompt = context if "<answer>" in context else f"{context} <answer>"
prefix = reasoning_prefix(active_mode)
prompt_tokens = self.tokenizer.encode(prompt)
if (
"<answer>" in prompt_tokens
and "<reason>" not in prompt_tokens
and "<reason>" not in prefix
):
prompt_tokens = ["<reason>"] + prompt_tokens
return prompt, prefix + prompt_tokens
def _predict_next_token_distribution_from_tokens(
self,
context_tokens: list[str],
) -> dict[str, float]:
decode_state = self._build_decode_state(context_tokens)
return self._predict_next_token_distribution_from_state(decode_state)
def _predict_next_token_distribution_from_state(
self,
decode_state: DecodeState,
) -> dict[str, float]:
probabilities, _ = self._score_next_token_from_state(
decode_state,
include_trace=False,
)
return probabilities
@staticmethod
def _answer_memory_is_confident(
*,
answer_sequence_match_confidence: float,
answer_start_confidence: float,
generated_count: int,
) -> bool:
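        """Gate the answer-memory priors. Mid-generation only the
        sequence-match floor applies; for the first token the distributed-lock,
        plain-match, and start-confidence floors are tried in turn."""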
if generated_count > 0:
return answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
if answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR:
return True
if answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR:
return True
if answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR + ANSWER_SEQUENCE_MATCH_FLOOR:
return True
return (
answer_sequence_match_confidence >= ANSWER_START_MATCH_SUPPORT_FLOOR
and answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR
and answer_start_confidence <= answer_sequence_match_confidence + ANSWER_START_CONFIDENCE_FLOOR
)
@staticmethod
def _answer_sequence_should_lock(
*,
answer_sequence_confidence: float,
answer_sequence_match_confidence: float,
has_answer_sequence_prior: bool,
) -> bool:
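        """Decide whether decoding should lock onto retrieved answer sequences,
        trading off match confidence against how spiked the sequence prior is."""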
if not has_answer_sequence_prior or answer_sequence_confidence <= 0.0:
return False
if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
return True
if (
answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
and answer_sequence_confidence >= 0.30
and answer_sequence_confidence <= 0.65
):
return True
return (
answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE
)
def _prompt_start_readout_is_confident(
self,
prior: object,
tokens: Sequence[str] | None = None,
) -> bool:
if self.tokenizer is None:
return False
if tokens is None:
if self.embedding_model is None:
return False
tokens = self.embedding_model.id_to_token
values = prior.tolist() if hasattr(prior, "tolist") else list(prior)
if not values or not tokens:
return False
limit = min(len(values), len(tokens))
if limit <= 0:
return False
best_index = max(range(limit), key=lambda index: float(values[index]))
best_probability = float(values[best_index])
if best_probability < PROMPT_START_READOUT_CONFIDENCE_FLOOR:
return False
meta = self._generation_token_meta(tokens[best_index])
return (
meta.starts_new_word
and bool(meta.alphanumeric)
and not meta.structural_punctuation
and not meta.structural_symbol
)
def _locked_answer_sequence_matches(
self,
matches: list[tuple[float, int, int]],
*,
generated_tokens: list[str],
temperature: float,
answer_sequence_confidence: float,
answer_sequence_match_confidence: float,
) -> list[tuple[float, int, int]]:
if not matches:
return []
if generated_tokens:
aligned_matches = [
match
for match in matches[:ANSWER_START_TOP_K]
if self._answer_sequence_match_has_continuation(
match,
generated_tokens,
)
]
return aligned_matches[:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT] or matches[:1]
best_similarity = matches[0][0]
near_match_floor = max(ANSWER_SEQUENCE_MATCH_FLOOR, best_similarity - 0.08)
varied = [
match
for match in matches[:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT]
if match[0] >= near_match_floor
]
if (
temperature < ANSWER_SEQUENCE_VARIATION_TEMPERATURE
and answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR
and len(varied) <= 1
):
return matches[:1]
return varied or matches[:1]
@staticmethod
def _answer_sequence_matches_are_ambiguous(
matches: Sequence[tuple[float, int, int]],
) -> bool:
if len(matches) < 2:
return False
best_similarity = float(matches[0][0])
if best_similarity < ANSWER_SEQUENCE_MATCH_FLOOR:
return False
near_match_floor = max(ANSWER_SEQUENCE_MATCH_FLOOR, best_similarity - 0.08)
return any(
float(match[0]) >= near_match_floor
for match in matches[1:ANSWER_SEQUENCE_VARIATION_MATCH_LIMIT]
)
def _answer_sequence_match_has_continuation(
self,
match: tuple[float, int, int],
generated_tokens: list[str],
) -> bool:
if (
self.embedding_model is None
or self.answer_sequence_tokens is None
or not generated_tokens
):
return False
similarity, sequence_index, _ = match
if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens):
return False
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not generated_ids:
return False
row = self.answer_sequence_tokens[sequence_index]
token_ids = [
int(value)
for value in (row.tolist() if hasattr(row, "tolist") else row)
if int(value) >= 0
]
if not token_ids:
return False
next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
if next_token_id is None:
return False
token = self.embedding_model.id_to_token[next_token_id]
return self._allowed_answer_sequence_token(token, generated_tokens)
def _allowed_answer_sequence_token(
self,
token: str,
generated_tokens: list[str],
) -> bool:
assert self.tokenizer is not None
if token == self.tokenizer.unk_token:
return False
if token in self.tokenizer.special_tokens:
return self._allowed_generation_token(token, generated_tokens)
return True
def _should_relax_answer_sequence_memory(
self,
matches: list[tuple[float, int, int]],
answer_sequence_prior: Sequence[float],
*,
generated_tokens: list[str],
temperature: float,
) -> bool:
if temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE or not matches:
return False
if self._is_inside_tool_protocol_continuation(generated_tokens):
return False
if self._answer_sequence_prior_prefers_tool_protocol(answer_sequence_prior):
return False
return True
def _answer_sequence_prior_prefers_tool_protocol(
self,
answer_sequence_prior: Sequence[float],
) -> bool:
if self.embedding_model is None or not answer_sequence_prior:
return False
best_index = -1
best_value = 0.0
for index, value in enumerate(answer_sequence_prior):
if value > best_value:
best_index = index
best_value = float(value)
return (
best_index >= 0
and best_index < len(self.embedding_model.id_to_token)
and best_value > 0.0
and self.embedding_model.id_to_token[best_index] in TOOL_PROTOCOL_TOKENS
)
@staticmethod
def _answer_start_blend_weights(
*,
answer_sequence_match_confidence: float,
temperature: float = 0.0,
) -> dict[str, float]:
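        """Choose mixing weights for the answer-start blend by regime: creative
        temperature, locked sequence, moderate sequence match, or the default
        answer-start-dominated mix."""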
if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE:
return {
"prompt_answer_start": 0.46,
"prompt_answer": 0.24,
"answer_sequence": 0.10,
"answer_start": 0.20,
}
if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
return {
"prompt_answer_start": 0.35,
"prompt_answer": 0.10,
"answer_sequence": 0.45,
"answer_start": 0.10,
}
if answer_sequence_match_confidence >= 0.40:
return {
"prompt_answer_start": 0.25,
"prompt_answer": 0.12,
"answer_sequence": 0.53,
"answer_start": 0.10,
}
return {
"prompt_answer_start": 0.08,
"prompt_answer": 0.10,
"answer_sequence": 0.02,
"answer_start": 0.80,
}
def _score_next_token_from_tokens(
self,
context_tokens: list[str],
*,
top_k: int = 5,
include_trace: bool = True,
) -> tuple[dict[str, float], dict[str, object]]:
decode_state = self._build_decode_state(context_tokens)
return self._score_next_token_from_state(
decode_state,
top_k=top_k,
include_trace=include_trace,
)
def _score_next_token_from_state(
self,
decode_state: DecodeState,
*,
top_k: int = 5,
include_trace: bool = True,
generated_tokens: list[str] | None = None,
temperature: float = 0.0,
avoid_token_sequences: Sequence[Sequence[str]] | None = None,
) -> tuple[dict[str, float], dict[str, object]]:
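        """Score the next token from a decode state.

        Blends the calibrated readout probabilities with the answer,
        associative, transition, copy, source-evidence, and preference priors,
        applying the lock/relax rules for answer-sequence memory; optionally
        returns a detailed trace dict."""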
assert self.embedding_model is not None
assert self.readout_weights is not None
generated_tokens = generated_tokens or []
state = self._masked_decode_state(decode_state)
logits = self._apply_readout_fast(state)
base_probabilities = self._calibrated_softmax(logits)
if decode_state.answer_matches is None:
decode_state.answer_matches = self._score_answer_matches(
decode_state.answer_anchor_state,
limit=max(ANSWER_TOP_K, top_k) if include_trace else ANSWER_TOP_K,
)
answer_matches = decode_state.answer_matches
if decode_state.answer_start_matches is None:
decode_state.answer_start_matches = self._score_answer_start_matches(
decode_state.answer_anchor_state,
limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
)
answer_start_matches = decode_state.answer_start_matches
if decode_state.answer_sequence_matches is None:
decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
)
answer_sequence_matches = self._filter_avoided_answer_sequence_matches(
decode_state.answer_sequence_matches,
avoid_token_sequences,
)
if not answer_start_matches and answer_sequence_matches:
answer_start_matches = self._answer_start_matches_from_sequences(
answer_sequence_matches
)
decode_state.answer_start_matches = answer_start_matches
answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens)
answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens)
answer_sequence_prior = self._answer_sequence_prior_from_matches(
answer_sequence_matches,
generated_tokens,
temperature=temperature,
)
answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0
answer_sequence_match_confidence = (
answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
)
answer_start_confidence = answer_start_matches[0][0] if answer_start_matches else 0.0
prompt_copy_is_distinctive = (
not generated_tokens
and self._prompt_copy_evidence_is_distinctive(decode_state.context_tokens)
)
answer_memory_confident = self._answer_memory_is_confident(
answer_sequence_match_confidence=answer_sequence_match_confidence,
answer_start_confidence=answer_start_confidence,
generated_count=len(generated_tokens),
)
if prompt_copy_is_distinctive and not answer_sequence_matches:
answer_memory_confident = False
has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior)
if not answer_memory_confident:
zero_prior = [0.0 for _ in self.embedding_model.id_to_token]
answer_prior = zero_prior
answer_start_prior = zero_prior
answer_sequence_prior = zero_prior
answer_sequence_confidence = 0.0
has_answer_sequence_prior = False
answer_locked = self._answer_sequence_should_lock(
answer_sequence_confidence=answer_sequence_confidence,
answer_sequence_match_confidence=answer_sequence_match_confidence,
has_answer_sequence_prior=has_answer_sequence_prior,
) or (
bool(generated_tokens)
and temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE
and self._answer_sequence_has_continuation(
generated_tokens,
answer_sequence_matches,
)
)
if self._should_relax_answer_sequence_memory(
answer_sequence_matches,
answer_sequence_prior,
generated_tokens=generated_tokens,
temperature=temperature,
):
answer_locked = False
if decode_state.prompt_answer_prior is None:
decode_state.prompt_answer_prior = self._prompt_answer_readout_prior(
decode_state.answer_anchor_state,
start=False,
)
prompt_answer_prior = decode_state.prompt_answer_prior
prompt_answer_start_prior = (
decode_state.prompt_answer_start_prior
if not generated_tokens
else [0.0 for _ in self.embedding_model.id_to_token]
)
if not generated_tokens and prompt_answer_start_prior is None:
decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior(
decode_state.answer_anchor_state,
start=True,
)
prompt_answer_start_prior = decode_state.prompt_answer_start_prior
prompt_start_readout_confident = (
not generated_tokens
and prompt_answer_start_prior is not None
and self._prompt_start_readout_is_confident(prompt_answer_start_prior)
)
prompt_readout_supported = answer_memory_confident and (
answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
or answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR
)
if prompt_start_readout_confident:
prompt_readout_supported = True
if not prompt_readout_supported:
prompt_answer_prior = [0.0 for _ in self.embedding_model.id_to_token]
prompt_answer_start_prior = [0.0 for _ in self.embedding_model.id_to_token]
use_answer_start = (
not generated_tokens
and (
any(value > 0.0 for value in answer_start_prior)
or any(value > 0.0 for value in prompt_answer_start_prior)
)
)
if answer_locked:
locked_matches = self._locked_answer_sequence_matches(
answer_sequence_matches,
generated_tokens=generated_tokens,
temperature=temperature,
answer_sequence_confidence=answer_sequence_confidence,
answer_sequence_match_confidence=answer_sequence_match_confidence,
)
answer_sequence_prior = self._answer_sequence_prior_from_matches(
locked_matches,
generated_tokens,
temperature=temperature,
)
answer_prior = answer_sequence_prior
elif use_answer_start:
start_blend = self._answer_start_blend_weights(
answer_sequence_match_confidence=answer_sequence_match_confidence,
temperature=temperature,
)
answer_prior = self._weighted_prior_sum(
[
(start_blend["prompt_answer_start"], prompt_answer_start_prior),
(start_blend["prompt_answer"], prompt_answer_prior),
(start_blend["answer_sequence"], answer_sequence_prior),
(start_blend["answer_start"], answer_start_prior),
],
)
elif any(value > 0.0 for value in answer_sequence_prior):
sequence_weight = (
0.10
if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE
else 0.30
)
answer_prior = self._weighted_prior_sum(
[
(0.55, prompt_answer_prior),
(sequence_weight, answer_sequence_prior),
(0.20, answer_prior),
],
)
elif any(value > 0.0 for value in prompt_answer_prior):
answer_prior = self._weighted_prior_sum(
[
(0.65, prompt_answer_prior),
(0.35, answer_prior),
],
)
answer_guided = (
max(answer_prior) >= 0.08
if answer_prior
else False
)
associative_matches = (
[]
if use_answer_start or answer_guided
else self._score_associative_matches(
state,
limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K,
)
)
associative_prior = (
[0.0 for _ in self.embedding_model.id_to_token]
if use_answer_start or answer_guided
else self._associative_prior_from_matches(associative_matches)
)
transition_prior, transition_order = self._transition_prior_with_order(decode_state.context_tokens)
copy_prior = self._copy_prior(decode_state.context_tokens)
source_evidence_prior = self._source_evidence_prior(
decode_state.context_tokens,
generated_tokens,
)
preference_prior = self._preference_prior()
probabilities, blend_weights = self._blend_probabilities(
base_probabilities,
answer_prior,
associative_prior,
transition_prior,
copy_prior,
source_evidence_prior,
preference_prior,
transition_order=transition_order,
generated_count=len(generated_tokens),
answer_locked=answer_locked,
answer_guided_start=use_answer_start,
copy_guided_start=prompt_copy_is_distinctive,
)
probabilities = self._focus_answer_start_probabilities(
probabilities,
answer_sequence_prior,
generated_tokens=generated_tokens,
answer_memory_confident=answer_memory_confident,
has_answer_sequence_prior=has_answer_sequence_prior,
sequence_focus_allowed=answer_sequence_match_confidence >= 0.40 or answer_locked,
temperature=temperature,
)
distribution = {
token: probabilities[index]
for index, token in enumerate(self.embedding_model.id_to_token)
}
if not include_trace:
return distribution, {}
trace = {
"state_norm": norm(state),
"blend_weights": blend_weights,
"transition_order": transition_order,
"base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k),
"answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k),
"prompt_answer_top_predictions": self._top_entries_from_vector(prompt_answer_prior, top_k),
"prompt_answer_start_top_predictions": self._top_entries_from_vector(prompt_answer_start_prior, top_k),
"answer_start_top_predictions": self._top_entries_from_vector(answer_start_prior, top_k),
"answer_sequence_top_predictions": self._top_entries_from_vector(answer_sequence_prior, top_k),
"associative_top_predictions": self._top_entries_from_vector(associative_prior, top_k),
"transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k),
"copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k),
"source_evidence_top_predictions": self._top_entries_from_vector(source_evidence_prior, top_k),
"preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k),
"final_top_predictions": self._top_entries_from_vector(probabilities, top_k),
"associative_matches": [
{
"example_index": example_index,
"similarity": similarity,
**self._token_entry(token_id, similarity),
}
for similarity, token_id, example_index in associative_matches[:top_k]
],
"answer_matches": [
{
"example_index": example_index,
"similarity": similarity,
**self._token_entry(token_id, similarity),
}
for similarity, token_id, example_index in answer_matches[:top_k]
],
"answer_start_matches": [
{
"example_index": example_index,
"similarity": similarity,
**self._token_entry(token_id, similarity),
}
for similarity, token_id, example_index in answer_start_matches[:top_k]
],
"answer_sequence_matches": [
{
"example_index": example_index,
"similarity": similarity,
}
for similarity, _, example_index in answer_sequence_matches[:top_k]
],
"reasoning_summary": self._build_reasoning_summary(
transition_order,
blend_weights,
),
}
return distribution, trace
def _score_next_token_array_from_state(
self,
decode_state: DecodeState,
*,
include_associative: bool,
generated_tokens: list[str] | None = None,
temperature: float = 0.0,
avoid_token_sequences: Sequence[Sequence[str]] | None = None,
) -> tuple[object, dict[str, float]]:
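        """Array-based counterpart of _score_next_token_from_state (no trace
        support); its callers here discard the second element of the return
        value."""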
assert np is not None
assert self.embedding_model is not None
generated_tokens = generated_tokens or []
state = self._masked_decode_state_array(decode_state)
logits = self._apply_readout_array(state)
base_probabilities = self._calibrated_softmax_array(logits)
if decode_state.answer_matches is None:
decode_state.answer_matches = self._score_answer_matches(decode_state.answer_anchor_state)
answer_prior = np.asarray(
self._answer_prior_from_matches(
decode_state.answer_matches,
generated_tokens,
),
dtype=np.float64,
)
if decode_state.answer_sequence_matches is None:
decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
)
answer_sequence_matches = self._filter_avoided_answer_sequence_matches(
decode_state.answer_sequence_matches,
avoid_token_sequences,
)
if not decode_state.answer_start_matches and answer_sequence_matches:
decode_state.answer_start_matches = self._answer_start_matches_from_sequences(
answer_sequence_matches
)
answer_sequence_prior = np.asarray(
self._answer_sequence_prior_from_matches(
answer_sequence_matches,
generated_tokens,
temperature=temperature,
),
dtype=np.float64,
)
answer_sequence_confidence = (
float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0
)
answer_sequence_match_confidence = (
answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
)
if not generated_tokens and decode_state.answer_start_matches is None:
decode_state.answer_start_matches = self._score_answer_start_matches(
decode_state.answer_anchor_state
)
answer_start_confidence = (
decode_state.answer_start_matches[0][0]
if not generated_tokens and decode_state.answer_start_matches
else 0.0
)
prompt_copy_is_distinctive = (
not generated_tokens
and self._prompt_copy_evidence_is_distinctive(decode_state.context_tokens)
)
answer_memory_confident = self._answer_memory_is_confident(
answer_sequence_match_confidence=answer_sequence_match_confidence,
answer_start_confidence=answer_start_confidence,
generated_count=len(generated_tokens),
)
if prompt_copy_is_distinctive and not answer_sequence_matches:
answer_memory_confident = False
has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0))
if not answer_memory_confident:
answer_prior = np.zeros_like(base_probabilities)
answer_sequence_prior = np.zeros_like(base_probabilities)
answer_sequence_confidence = 0.0
has_answer_sequence_prior = False
answer_locked = self._answer_sequence_should_lock(
answer_sequence_confidence=answer_sequence_confidence,
answer_sequence_match_confidence=answer_sequence_match_confidence,
has_answer_sequence_prior=has_answer_sequence_prior,
) or (
bool(generated_tokens)
and temperature < ANSWER_SEQUENCE_CREATIVE_TEMPERATURE
and self._answer_sequence_has_continuation(
generated_tokens,
answer_sequence_matches,
)
)
if self._should_relax_answer_sequence_memory(
answer_sequence_matches,
answer_sequence_prior.tolist(),
generated_tokens=generated_tokens,
temperature=temperature,
):
answer_locked = False
if decode_state.prompt_answer_prior is None:
decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array(
decode_state.answer_anchor_state,
start=False,
)
prompt_answer_prior = decode_state.prompt_answer_prior
prompt_answer_start_prior = np.zeros_like(base_probabilities)
use_answer_start = False
if answer_locked:
locked_matches = self._locked_answer_sequence_matches(
answer_sequence_matches,
generated_tokens=generated_tokens,
temperature=temperature,
answer_sequence_confidence=answer_sequence_confidence,
answer_sequence_match_confidence=answer_sequence_match_confidence,
)
answer_sequence_prior = np.asarray(
self._answer_sequence_prior_from_matches(
locked_matches,
generated_tokens,
temperature=temperature,
),
dtype=np.float64,
)
answer_prior = answer_sequence_prior
elif not generated_tokens:
if decode_state.prompt_answer_start_prior is None:
decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array(
decode_state.answer_anchor_state,
start=True,
)
prompt_answer_start_prior = (
decode_state.prompt_answer_start_prior
if decode_state.prompt_answer_start_prior is not None
else np.zeros_like(base_probabilities)
)
prompt_start_readout_confident = self._prompt_start_readout_is_confident(
prompt_answer_start_prior
)
prompt_readout_supported = answer_memory_confident and (
answer_sequence_match_confidence >= ANSWER_SEQUENCE_MATCH_FLOOR
or answer_start_confidence >= ANSWER_START_CONFIDENCE_FLOOR
)
if prompt_start_readout_confident:
prompt_readout_supported = True
if not prompt_readout_supported:
prompt_answer_prior = np.zeros_like(base_probabilities)
prompt_answer_start_prior = np.zeros_like(base_probabilities)
answer_start_prior = np.asarray(
self._answer_prior_from_matches(
decode_state.answer_start_matches,
generated_tokens,
),
dtype=np.float64,
)
if not answer_memory_confident:
answer_start_prior = np.zeros_like(base_probabilities)
if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0):
start_blend = self._answer_start_blend_weights(
answer_sequence_match_confidence=answer_sequence_match_confidence,
temperature=temperature,
)
answer_prior = self._weighted_prior_sum_array(
[
(start_blend["prompt_answer_start"], prompt_answer_start_prior),
(start_blend["prompt_answer"], prompt_answer_prior),
(start_blend["answer_sequence"], answer_sequence_prior),
(start_blend["answer_start"], answer_start_prior),
],
)
use_answer_start = True
if answer_locked:
answer_prior = answer_sequence_prior
elif not use_answer_start and np.any(answer_sequence_prior > 0.0):
sequence_weight = (
0.10
if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE
else 0.30
)
answer_prior = self._weighted_prior_sum_array(
[
(0.55, prompt_answer_prior),
(sequence_weight, answer_sequence_prior),
(0.20, answer_prior),
],
)
elif not use_answer_start and np.any(prompt_answer_prior > 0.0):
answer_prior = self._weighted_prior_sum_array(
[
(0.65, prompt_answer_prior),
(0.35, answer_prior),
],
)
answer_guided = bool(answer_prior.size and float(np.max(answer_prior)) >= 0.08)
if include_associative and not use_answer_start and not answer_guided:
associative_prior = np.asarray(
self._associative_prior_from_matches(
self._score_associative_matches(state)
),
dtype=np.float64,
)
else:
associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
transition_prior, transition_order = self._transition_prior_array_with_order(
decode_state.context_tokens
)
copy_prior = self._copy_prior_array(decode_state.context_tokens)
source_evidence_prior = self._source_evidence_prior_array(
decode_state.context_tokens,
generated_tokens,
)
preference_prior = self._preference_prior_array()
probabilities, blend_weights = self._blend_probability_arrays(
base_probabilities,
answer_prior,
associative_prior,
transition_prior,
copy_prior,
source_evidence_prior,
preference_prior,
transition_order=transition_order,
generated_count=len(generated_tokens),
answer_locked=answer_locked,
answer_guided_start=use_answer_start,
)
probabilities = self._focus_answer_start_probability_array(
probabilities,
answer_sequence_prior,
generated_tokens=generated_tokens,
answer_memory_confident=answer_memory_confident,
has_answer_sequence_prior=has_answer_sequence_prior,
sequence_focus_allowed=answer_sequence_match_confidence >= 0.40 or answer_locked,
temperature=temperature,
)
return probabilities, blend_weights
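# Summary of the scoring pipeline above: the calibrated readout softmax is
# mixed with memory-derived priors (answer, answer-sequence, associative,
# transition, copy, source-evidence, preference), and a confident or locked
# answer memory can dominate the blend entirely. Illustrative sketch with
# hypothetical weights, not a literal trace of _blend_probability_arrays:
#   >>> base = [0.5, 0.3, 0.2]            # calibrated readout
#   >>> answer_prior = [0.0, 0.9, 0.1]    # confident answer memory
#   >>> [0.6 * b + 0.4 * a for b, a in zip(base, answer_prior)]
#   [0.3, 0.54, 0.16]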
@staticmethod
def _focus_answer_start_probabilities(
probabilities: Vector,
answer_sequence_prior: Vector,
*,
generated_tokens: list[str],
answer_memory_confident: bool,
has_answer_sequence_prior: bool,
sequence_focus_allowed: bool | None = None,
temperature: float = 0.0,
) -> Vector:
if sequence_focus_allowed is None:
sequence_focus_allowed = has_answer_sequence_prior
if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE:
return probabilities
if (
generated_tokens
or not answer_memory_confident
or not has_answer_sequence_prior
or not sequence_focus_allowed
):
return probabilities
if not probabilities or not answer_sequence_prior:
return probabilities
focused = [
probability
if index < len(answer_sequence_prior) and answer_sequence_prior[index] > 0.0
else probability * 0.02
for index, probability in enumerate(probabilities)
]
total = sum(focused)
if total <= 0.0:
return probabilities
return [value / total for value in focused]
@staticmethod
def _focus_answer_start_probability_array(
probabilities: object,
answer_sequence_prior: object,
*,
generated_tokens: list[str],
answer_memory_confident: bool,
has_answer_sequence_prior: bool,
sequence_focus_allowed: bool | None = None,
temperature: float = 0.0,
) -> object:
if sequence_focus_allowed is None:
sequence_focus_allowed = has_answer_sequence_prior
if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE:
return probabilities
if (
np is None
or generated_tokens
or not answer_memory_confident
or not has_answer_sequence_prior
or not sequence_focus_allowed
):
return probabilities
values = np.asarray(probabilities, dtype=np.float64)
prior = np.asarray(answer_sequence_prior, dtype=np.float64)
if values.size == 0 or prior.size != values.size or not np.any(prior > 0.0):
return probabilities
focused = values.copy()
focused[prior <= 0.0] *= 0.02
total = float(focused.sum())
if total <= 0.0:
return probabilities
return focused / total
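# Both focus variants keep mass on tokens supported by the answer-sequence
# prior and damp everything else by 0.02 before renormalizing. Worked example
# with hypothetical values:
#   >>> probs, prior = [0.5, 0.3, 0.2], [0.0, 1.0, 0.0]
#   >>> focused = [p if q > 0.0 else p * 0.02 for p, q in zip(probs, prior)]
#   >>> [round(v / sum(focused), 3) for v in focused]
#   [0.032, 0.955, 0.013]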
def _calibrated_softmax(
self,
logits: Vector,
*,
scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> Vector:
if np is not None:
return self._calibrated_softmax_array(
np.asarray(logits, dtype=np.float64),
scale=scale,
).tolist()
if not logits:
return []
center = mean(logits)
variance = mean([(value - center) * (value - center) for value in logits])
spread = variance**0.5
if spread <= 1e-12:
return softmax(logits)
calibrated = [
max(-20.0, min(20.0, ((value - center) / spread) * scale))
for value in logits
]
return softmax(calibrated)
def _calibrated_softmax_array(
self,
logits: object,
*,
scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> object:
assert np is not None
values = np.asarray(logits, dtype=np.float64)
if values.size == 0:
return values
spread = float(values.std())
if spread > 1e-12:
values = ((values - float(values.mean())) / spread) * scale
values = np.clip(values, -20.0, 20.0)
values = values - float(values.max())
exponentials = np.exp(values)
total = float(exponentials.sum())
if total <= 0.0:
return np.full(values.shape, 1.0 / max(1, values.size), dtype=np.float64)
return exponentials / total
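# Calibrated softmax in both paths: z-score the logits, rescale, clip, then
# exponentiate, i.e. p_i = softmax(clip(scale * (x_i - mean) / std, -20, 20)).
# The output entropy is therefore set by `scale` (READOUT_LOGIT_ZSCORE_SCALE
# by default, PROMPT_READOUT_LOGIT_ZSCORE_SCALE for the prompt readouts)
# rather than by the raw magnitude of the readout logits.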
def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector:
assert self.embedding_model is not None
active_sources = [
(weight, vector)
for weight, vector in sources
if weight > 0.0 and any(value > 0.0 for value in vector)
]
if not active_sources:
return [0.0 for _ in self.embedding_model.id_to_token]
total_weight = sum(weight for weight, _ in active_sources)
merged = [0.0 for _ in self.embedding_model.id_to_token]
for weight, vector in active_sources:
normalized_weight = weight / total_weight
for index, value in enumerate(vector):
merged[index] += normalized_weight * value
return _normalize_vector(merged)
def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object:
assert np is not None
assert self.embedding_model is not None
active_sources = [
(weight, np.asarray(vector, dtype=np.float64))
for weight, vector in sources
if weight > 0.0 and np.any(np.asarray(vector, dtype=np.float64) > 0.0)
]
if not active_sources:
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
total_weight = sum(weight for weight, _ in active_sources)
merged = np.zeros_like(active_sources[0][1], dtype=np.float64)
for weight, vector in active_sources:
merged += (weight / total_weight) * vector
total = float(merged.sum())
if total > 0.0:
merged /= total
return merged
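# Prior mixing in both variants: sources with zero weight or no positive mass
# are dropped, the rest are averaged by normalized weight, and the merged
# vector is renormalized to sum to 1. Hypothetical example: active weights
# 0.6 and 0.2 contribute 0.75 and 0.25 of the mixture respectively.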
def _prompt_answer_readout_prior(
self,
answer_anchor_state: Vector | None,
*,
start: bool,
) -> Vector:
assert self.embedding_model is not None
if answer_anchor_state is None:
return [0.0 for _ in self.embedding_model.id_to_token]
weights = self.prompt_answer_start_weights if start else self.prompt_answer_weights
bias = self.prompt_answer_start_bias if start else self.prompt_answer_bias
if np is not None:
return self._prompt_answer_readout_prior_array(
answer_anchor_state,
start=start,
).tolist()
if not weights:
return [0.0 for _ in self.embedding_model.id_to_token]
state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
logits = apply_readout(weights, state)
if bias:
logits = [value + bias[index] for index, value in enumerate(logits)]
return self._calibrated_softmax(
logits,
scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
)
def _prompt_answer_readout_prior_array(
self,
answer_anchor_state: Vector | None,
*,
start: bool,
) -> object:
assert np is not None
assert self.embedding_model is not None
if answer_anchor_state is None:
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
weights = (
self.prompt_answer_start_weights_array
if start
else self.prompt_answer_weights_array
)
bias = self.prompt_answer_start_bias_array if start else self.prompt_answer_bias_array
if weights is None:
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
state_array = self._center_state_array(
self._masked_combined_state_array(answer_anchor_state)
)
logits = weights @ state_array
if bias is not None and bias.shape == logits.shape:
logits = logits + bias
return self._calibrated_softmax_array(
logits,
scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE,
)
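# Both readout-prior paths implement a linear probe over the anchored state:
# calibrated_softmax(W @ center(mask(answer_anchor_state)) + b), with separate
# (weights, bias) pairs for answer-start and answer-continuation positions and
# a zero prior whenever the anchor or the trained weights are missing.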
def save(self, path: str | Path) -> None:
self._require_fit()
assert self.tokenizer is not None
assert self.embedding_model is not None
assert self.ternary_mask is not None
assert self.readout_weights is not None
assert self.associative_keys is not None
assert self.associative_values is not None
assert self.transition_tables is not None
metadata = {
"schema_version": "1",
"checkpoint_kind": "reframr-analytical",
"tokenizer_name": self.tokenizer.name,
"config": json.dumps(self.config.to_dict(), separators=(",", ":")),
"tokenizer": json.dumps(self.tokenizer.to_dict(), separators=(",", ":")),
"embedding_id_to_token": json.dumps(self.embedding_model.id_to_token, separators=(",", ":")),
"tokenizer_vocab_size": str(self.tokenizer.vocab_size),
"transition_table_format": "tensor-v1",
}
self._refresh_answer_fingerprint_hashes()
if np is not None:
self._refresh_numeric_caches()
transition_tensors = self._transition_table_tensors()
tensors = {
"embedding_table": self.embedding_model.embeddings,
"ternary_scale": [self.ternary_scale],
"ternary_mask": self.ternary_mask,
"readout_weights": self.readout_weights,
"readout_bias": self.readout_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"prompt_answer_weights": self.prompt_answer_weights
if self.prompt_answer_weights is not None
else [],
"prompt_answer_bias": self.prompt_answer_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"prompt_answer_start_weights": self.prompt_answer_start_weights
if self.prompt_answer_start_weights is not None
else [],
"prompt_answer_start_bias": self.prompt_answer_start_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"trace_token_weights": self.trace_token_weights
or [1.0 for _ in self.embedding_model.id_to_token],
"preference_bias": self.preference_bias
or [0.0 for _ in self.embedding_model.id_to_token],
"state_offset": self.state_offset
or [0.0 for _ in range(self._combined_state_width())],
"associative_keys": self.associative_keys,
"associative_key_norms": self.associative_key_norms_array
if self.associative_key_norms_array is not None
else self.associative_key_norms or [],
"associative_values": self.associative_values,
"answer_keys": self.answer_keys if self.answer_keys is not None else [],
"answer_key_norms": self.answer_key_norms_array
if self.answer_key_norms_array is not None
else self.answer_key_norms or [],
"answer_similarity_keys": self.answer_similarity_keys_array
if self.answer_similarity_keys_array is not None
else [],
"answer_similarity_key_norms": self.answer_similarity_key_norms_array
if self.answer_similarity_key_norms_array is not None
else [],
"answer_values": self.answer_values if self.answer_values is not None else [],
"answer_start_keys": self.answer_start_keys if self.answer_start_keys is not None else [],
"answer_start_key_norms": self.answer_start_key_norms_array
if self.answer_start_key_norms_array is not None
else self.answer_start_key_norms or [],
"answer_start_similarity_keys": self.answer_start_similarity_keys_array
if self.answer_start_similarity_keys_array is not None
else [],
"answer_start_similarity_key_norms": self.answer_start_similarity_key_norms_array
if self.answer_start_similarity_key_norms_array is not None
else [],
"answer_start_values": self.answer_start_values if self.answer_start_values is not None else [],
"answer_sequence_keys": self.answer_sequence_keys if self.answer_sequence_keys is not None else [],
"answer_sequence_key_norms": self.answer_sequence_key_norms_array
if self.answer_sequence_key_norms_array is not None
else self.answer_sequence_key_norms or [],
"answer_sequence_similarity_keys": self.answer_sequence_similarity_keys_array
if self.answer_sequence_similarity_keys_array is not None
else [],
"answer_sequence_similarity_key_norms": self.answer_sequence_similarity_key_norms_array
if self.answer_sequence_similarity_key_norms_array is not None
else [],
"answer_sequence_prompt_tokens": self.answer_sequence_prompt_tokens if self.answer_sequence_prompt_tokens is not None else [],
"answer_sequence_tokens": self.answer_sequence_tokens if self.answer_sequence_tokens is not None else [],
"answer_fingerprint_hashes": self._answer_fingerprint_tensor(),
**transition_tensors,
}
write_safetensor_file(path, tensors, metadata=metadata)
@classmethod
def load(cls, path: str | Path) -> "ReframrModel":
checkpoint_path = Path(path)
checkpoint = read_safetensor_file(
checkpoint_path,
arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000,
)
metadata = checkpoint.metadata
config = ReframrConfig.from_dict(json.loads(metadata["config"]))
model = cls(config)
model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"]))
id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])]
embedding_table = checkpoint.tensors["embedding_table"]
if np is not None and hasattr(embedding_table, "shape"):
embeddings = embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False)
else:
embeddings = [[float(value) for value in row] for row in embedding_table]
model.embedding_model = EmbeddingModel(
token_to_id={token: index for index, token in enumerate(id_to_token)},
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=[],
)
model.memory_units = [
AnalyticalMemoryUnit(model.config.state_dim, timescale)
for timescale in model.config.timescales
]
model.ternary_scale = float(checkpoint.tensors["ternary_scale"][0])
model.ternary_mask = [int(value) for value in checkpoint.tensors["ternary_mask"]]
readout_tensor = checkpoint.tensors["readout_weights"]
model.readout_weights = (
readout_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if np is not None and hasattr(readout_tensor, "shape")
else [[float(value) for value in row] for row in readout_tensor]
)
readout_bias_tensor = checkpoint.tensors.get("readout_bias", [])
model.readout_bias = [
float(value) for value in (
readout_bias_tensor.tolist()
if hasattr(readout_bias_tensor, "tolist")
else readout_bias_tensor
)
]
if not model.readout_bias:
model.readout_bias = [0.0 for _ in id_to_token]
prompt_answer_tensor = checkpoint.tensors.get("prompt_answer_weights", [])
model.prompt_answer_weights = (
prompt_answer_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if np is not None
and hasattr(prompt_answer_tensor, "shape")
and len(prompt_answer_tensor.shape) == 2
else [[float(value) for value in row] for row in prompt_answer_tensor]
)
prompt_answer_bias_tensor = checkpoint.tensors.get("prompt_answer_bias", [])
model.prompt_answer_bias = [
float(value) for value in (
prompt_answer_bias_tensor.tolist()
if hasattr(prompt_answer_bias_tensor, "tolist")
else prompt_answer_bias_tensor
)
]
if not model.prompt_answer_bias:
model.prompt_answer_bias = [0.0 for _ in id_to_token]
prompt_answer_start_tensor = checkpoint.tensors.get("prompt_answer_start_weights", [])
model.prompt_answer_start_weights = (
prompt_answer_start_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if np is not None
and hasattr(prompt_answer_start_tensor, "shape")
and len(prompt_answer_start_tensor.shape) == 2
else [[float(value) for value in row] for row in prompt_answer_start_tensor]
)
prompt_answer_start_bias_tensor = checkpoint.tensors.get("prompt_answer_start_bias", [])
model.prompt_answer_start_bias = [
float(value) for value in (
prompt_answer_start_bias_tensor.tolist()
if hasattr(prompt_answer_start_bias_tensor, "tolist")
else prompt_answer_start_bias_tensor
)
]
if not model.prompt_answer_start_bias:
model.prompt_answer_start_bias = [0.0 for _ in id_to_token]
trace_weight_tensor = checkpoint.tensors.get("trace_token_weights", [])
model.trace_token_weights = [
float(value) for value in (
trace_weight_tensor.tolist()
if hasattr(trace_weight_tensor, "tolist")
else trace_weight_tensor
)
]
if not model.trace_token_weights:
model.trace_token_weights = [
1.0 if token in TOOL_PROTOCOL_TOKENS else 0.0 if token in model.tokenizer.special_tokens else 1.0
for token in id_to_token
]
preference_bias_tensor = checkpoint.tensors.get("preference_bias", [])
model.preference_bias = [
float(value) for value in (
preference_bias_tensor.tolist()
if hasattr(preference_bias_tensor, "tolist")
else preference_bias_tensor
)
]
if not model.preference_bias:
model.preference_bias = [0.0 for _ in id_to_token]
state_offset_tensor = checkpoint.tensors.get("state_offset", [])
model.state_offset = [
float(value) for value in (
state_offset_tensor.tolist()
if hasattr(state_offset_tensor, "tolist")
else state_offset_tensor
)
]
if not model.state_offset:
model.state_offset = [0.0 for _ in range(model._combined_state_width())]
def _runtime_vector_tensor(name: str) -> object | None:
tensor = checkpoint.tensors.get(name, [])
if np is not None and hasattr(tensor, "shape"):
if len(tensor.shape) == 1 and int(tensor.shape[0]) > 0:
return tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
return None
values = tensor.tolist() if hasattr(tensor, "tolist") else tensor
return [float(value) for value in values] if values else None
def _runtime_matrix_tensor(name: str) -> object | None:
tensor = checkpoint.tensors.get(name, [])
if (
np is not None
and hasattr(tensor, "shape")
and len(tensor.shape) == 2
and int(tensor.shape[0]) > 0
):
return tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
return None
associative_tensor = checkpoint.tensors.get("associative_keys", [])
model.associative_keys = (
associative_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if np is not None and hasattr(associative_tensor, "shape")
else [[float(value) for value in row] for row in associative_tensor]
)
cached_associative_key_norms = _runtime_vector_tensor("associative_key_norms")
if cached_associative_key_norms is not None:
model.associative_key_norms = cached_associative_key_norms
elif np is not None and hasattr(model.associative_keys, "shape"):
model.associative_key_norms = None
else:
model.associative_key_norms = [norm(key) for key in model.associative_keys]
raw_associative_values = checkpoint.tensors.get("associative_values", [])
model.associative_values = [
int(value) for value in (
raw_associative_values.tolist()
if hasattr(raw_associative_values, "tolist")
else raw_associative_values
)
]
answer_tensor = checkpoint.tensors.get("answer_keys", [])
if np is not None and hasattr(answer_tensor, "shape"):
model.answer_keys = (
answer_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if len(answer_tensor.shape) == 2
else []
)
else:
model.answer_keys = [[float(value) for value in row] for row in answer_tensor]
if (
np is not None
and hasattr(model.answer_keys, "shape")
and len(model.answer_keys.shape) == 2
):
model.answer_key_norms = _runtime_vector_tensor("answer_key_norms")
else:
model.answer_key_norms = (
_runtime_vector_tensor("answer_key_norms")
or [norm(key) for key in model.answer_keys]
)
raw_answer_values = checkpoint.tensors.get("answer_values", [])
model.answer_values = [
int(value) for value in (
raw_answer_values.tolist()
if hasattr(raw_answer_values, "tolist")
else raw_answer_values
)
]
answer_start_tensor = checkpoint.tensors.get("answer_start_keys", [])
if np is not None and hasattr(answer_start_tensor, "shape"):
model.answer_start_keys = (
answer_start_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if len(answer_start_tensor.shape) == 2
else []
)
else:
model.answer_start_keys = [
[float(value) for value in row] for row in answer_start_tensor
]
if (
np is not None
and hasattr(model.answer_start_keys, "shape")
and len(model.answer_start_keys.shape) == 2
):
model.answer_start_key_norms = _runtime_vector_tensor("answer_start_key_norms")
else:
model.answer_start_key_norms = (
_runtime_vector_tensor("answer_start_key_norms")
or [norm(key) for key in model.answer_start_keys]
)
raw_answer_start_values = checkpoint.tensors.get("answer_start_values", [])
model.answer_start_values = [
int(value) for value in (
raw_answer_start_values.tolist()
if hasattr(raw_answer_start_values, "tolist")
else raw_answer_start_values
)
]
answer_sequence_tensor = checkpoint.tensors.get("answer_sequence_keys", [])
if np is not None and hasattr(answer_sequence_tensor, "shape"):
model.answer_sequence_keys = (
answer_sequence_tensor.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if len(answer_sequence_tensor.shape) == 2
else []
)
else:
model.answer_sequence_keys = [
[float(value) for value in row] for row in answer_sequence_tensor
]
if (
np is not None
and hasattr(model.answer_sequence_keys, "shape")
and len(model.answer_sequence_keys.shape) == 2
):
model.answer_sequence_key_norms = _runtime_vector_tensor("answer_sequence_key_norms")
else:
model.answer_sequence_key_norms = (
_runtime_vector_tensor("answer_sequence_key_norms")
or [norm(key) for key in model.answer_sequence_keys]
)
raw_answer_sequence_prompt_tokens = checkpoint.tensors.get("answer_sequence_prompt_tokens", [])
if np is not None and hasattr(raw_answer_sequence_prompt_tokens, "shape"):
model.answer_sequence_prompt_tokens = raw_answer_sequence_prompt_tokens.astype(int, copy=False)
else:
model.answer_sequence_prompt_tokens = [
[int(value) for value in row] for row in raw_answer_sequence_prompt_tokens
]
raw_answer_sequence_tokens = checkpoint.tensors.get("answer_sequence_tokens", [])
if np is not None and hasattr(raw_answer_sequence_tokens, "shape"):
model.answer_sequence_tokens = raw_answer_sequence_tokens.astype(int, copy=False)
else:
model.answer_sequence_tokens = [
[int(value) for value in row] for row in raw_answer_sequence_tokens
]
model.answer_sequence_token_id_rows = None
raw_fingerprints = checkpoint.tensors.get("answer_fingerprint_hashes", [])
model.answer_fingerprint_hashes = model._coerce_answer_fingerprint_hashes(
raw_fingerprints
)
model.answer_fingerprint_token_lengths = None
model.answer_fingerprint_token_sequences_by_length = None
if not model.answer_fingerprint_hashes:
model._refresh_answer_fingerprint_hashes()
model.answer_similarity_keys_array = _runtime_matrix_tensor("answer_similarity_keys")
model.answer_similarity_key_norms_array = _runtime_vector_tensor("answer_similarity_key_norms")
model.answer_start_similarity_keys_array = _runtime_matrix_tensor("answer_start_similarity_keys")
model.answer_start_similarity_key_norms_array = _runtime_vector_tensor("answer_start_similarity_key_norms")
model.answer_sequence_similarity_keys_array = _runtime_matrix_tensor("answer_sequence_similarity_keys")
model.answer_sequence_similarity_key_norms_array = _runtime_vector_tensor("answer_sequence_similarity_key_norms")
model.transition_id_tables = model._deserialize_transition_id_tables_from_tensors(
checkpoint.tensors
)
if model.transition_id_tables is not None:
model.transition_tables = {order: {} for order in sorted(TRANSITION_ORDERS)}
else:
model.transition_tables = model._deserialize_transition_tables(
json.loads(metadata.get("transition_tables", "{}"))
)
model._refresh_numeric_caches()
return model
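# Illustrative save/load round trip (path and fitted model are hypothetical):
#   >>> model.save("reframr-checkpoint.safetensors")
#   >>> restored = ReframrModel.load("reframr-checkpoint.safetensors")
# `load` rebuilds the tokenizer, embeddings, readouts, memory banks, and
# transition tables from one safetensors file, then calls
# _refresh_numeric_caches so the numpy fast paths are ready immediately.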
def _collect_training_examples(
self,
tokens: list[str],
) -> tuple[list[Vector], list[Vector], list[int]]:
assert self.embedding_model is not None
if np is not None:
hidden_states = [
np.zeros(self.config.state_dim, dtype=np.float64)
for _ in self.config.timescales
]
context_traces = [
np.zeros(self.config.embedding_dim, dtype=np.float64)
for _ in self.config.timescales
]
zero_embedding: Vector | object = np.zeros(self.config.embedding_dim, dtype=np.float64)
else:
hidden_states = [zeros_vector(self.config.state_dim) for _ in self.config.timescales]
context_traces = [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales]
zero_embedding = zeros_vector(self.config.embedding_dim)
states: list[Vector] = []
labels: list[Vector] = []
label_ids: list[int] = []
token_ids = [
self.embedding_model.token_to_id.get(token, -1)
for token in tokens
]
example_count = max(0, len(tokens) - 1)
stride = 1
if self.config.max_training_examples and example_count > self.config.max_training_examples:
stride = max(
1,
(example_count + self.config.max_training_examples - 1) // self.config.max_training_examples,
)
for index in range(len(tokens) - 1):
token = tokens[index]
token_id = token_ids[index]
embedding = (
self.embedding_model.embeddings[token_id]
if token_id >= 0
else zero_embedding
)
trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
hidden_states, context_traces, combined_state = self._step_hidden_states_from_embedding(
hidden_states,
context_traces,
embedding,
trace_embedding=trace_embedding,
)
if stride > 1 and index % stride != 0 and index != len(tokens) - 2:
continue
states.append(combined_state)
next_token_id = token_ids[index + 1]
labels.append(self._one_hot_from_id(next_token_id))
label_ids.append(next_token_id)
if self.config.max_training_examples and len(states) > self.config.max_training_examples:
states = states[: self.config.max_training_examples]
labels = labels[: self.config.max_training_examples]
label_ids = label_ids[: self.config.max_training_examples]
return states, labels, label_ids
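# Stride subsampling caps the ridge-regression workload while still stepping
# the recurrent state through every token. Hypothetical sizing: 10,000
# transitions with max_training_examples=2,500 gives
# stride = ceil(10000 / 2500) = 4, so roughly every fourth state is kept,
# and the final transition is always included.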
def _is_punctuation_piece(self, piece: str) -> bool:
return bool(piece) and all(character in string.punctuation for character in piece)
def _encode_context(self, tokens: list[str]) -> Vector:
return self._masked_decode_state(self._build_decode_state(tokens))
def _build_decode_state(self, tokens: list[str]) -> DecodeState:
assert self.memory_units is not None
state = DecodeState(
hidden_states=(
[
np.zeros(self.config.state_dim, dtype=np.float64)
for _ in self.config.timescales
]
if np is not None
else [zeros_vector(self.config.state_dim) for _ in self.config.timescales]
),
context_traces=(
[
np.zeros(self.config.embedding_dim, dtype=np.float64)
for _ in self.config.timescales
]
if np is not None
else [zeros_vector(self.config.embedding_dim) for _ in self.config.timescales]
),
combined_state=self._zero_combined_state(),
context_tokens=[],
)
for token in tokens:
self._advance_decode_state(state, token)
self._apply_sparse_context_anchor(state)
return state
def _advance_decode_state(self, state: DecodeState, token: str) -> DecodeState:
next_hidden_states, next_context_traces, combined_state = self._step_hidden_states(
state.hidden_states,
state.context_traces,
token,
)
state.hidden_states = next_hidden_states
state.context_traces = next_context_traces
state.combined_state = combined_state
state.context_tokens.append(token)
if token == "<answer>":
state.answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:]
state.answer_matches = None
state.answer_start_matches = None
state.answer_sequence_matches = None
state.prompt_answer_prior = None
state.prompt_answer_start_prior = None
return state
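# Seeing the literal "<answer>" token snapshots the current combined state as
# the answer anchor and invalidates all cached matches and priors, so they are
# recomputed against the fresh anchor on the next scoring call.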
def _apply_sparse_context_anchor(self, state: DecodeState) -> None:
if (
np is None
or self.embedding_model is None
or state.answer_anchor_state is None
or not state.context_tokens
):
return
answer_index = _last_index(state.context_tokens, "<answer>")
if answer_index is None or answer_index <= 0:
return
context_ids = self._long_context_sparse_token_ids(state.context_tokens[:answer_index])
if len(context_ids) < SPARSE_CONTEXT_MIN_TOKENS:
return
query_id = context_ids[-1]
embeddings = np.asarray(self.embedding_model.embeddings, dtype=np.float32)
if embeddings.ndim != 2 or embeddings.shape[0] == 0:
return
selector = HashedSparseAttention(
embeddings,
k_neighbors=min(SPARSE_CONTEXT_TOP_K, len(context_ids)),
hash_bits=SPARSE_CONTEXT_HASH_BITS,
probe_radius=SPARSE_CONTEXT_PROBE_RADIUS,
candidate_multiplier=SPARSE_CONTEXT_CANDIDATE_MULTIPLIER,
)
token_ids = np.asarray(context_ids, dtype=np.int64)
selector.build_context_index(token_ids)
selection = selector.select_positions_cached(query_id)
if not selection.positions:
return
selected_ids = token_ids[np.asarray(selection.positions, dtype=np.int64)]
selected_embeddings = embeddings[selected_ids]
scores = np.asarray(selection.scores, dtype=np.float32)
scores -= float(scores.max())
weights = np.exp(scores)
weights /= max(float(weights.sum()), 1e-8)
sparse_embedding = weights @ selected_embeddings
blended_anchor = self._blend_sparse_embedding_into_combined_state(
state.answer_anchor_state,
sparse_embedding,
state_dim=self.config.state_dim,
embedding_dim=self.config.embedding_dim,
timescale_count=len(self.config.timescales),
blend=SPARSE_CONTEXT_TRACE_BLEND,
)
state.answer_anchor_state = blended_anchor
if state.context_tokens and state.context_tokens[-1] == "<answer>":
state.combined_state = blended_anchor.copy()
state.answer_matches = None
state.answer_start_matches = None
state.answer_sequence_matches = None
state.prompt_answer_prior = None
state.prompt_answer_start_prior = None
def _long_context_sparse_token_ids(self, tokens: Sequence[str]) -> list[int]:
assert self.embedding_model is not None
special_tokens = self.tokenizer.special_tokens if self.tokenizer is not None else set()
ids: list[int] = []
for token in tokens:
if token in special_tokens and token not in TOOL_PROTOCOL_TOKENS:
continue
token_id = self._token_id_for_token(token)
if token_id >= 0:
ids.append(token_id)
return ids
@staticmethod
def _blend_sparse_embedding_into_combined_state(
combined_state: Vector,
sparse_embedding: object,
*,
state_dim: int,
embedding_dim: int,
timescale_count: int,
blend: float,
) -> Vector:
if np is None:
return combined_state
state_array = np.asarray(combined_state, dtype=np.float32).copy()
sparse_array = np.asarray(sparse_embedding, dtype=np.float32)
if sparse_array.shape[0] != embedding_dim:
return combined_state
block_width = state_dim + embedding_dim
expected_width = block_width * timescale_count
if state_array.shape[0] != expected_width:
return combined_state
alpha = min(1.0, max(0.0, float(blend)))
for block_index in range(timescale_count):
trace_start = block_index * block_width + state_dim
trace_end = trace_start + embedding_dim
state_array[trace_start:trace_end] = (
(1.0 - alpha) * state_array[trace_start:trace_end]
+ alpha * sparse_array
)
return state_array.tolist()
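# Each combined state is laid out as per-timescale blocks [hidden | trace].
# The sparse context embedding is blended into every trace slice as
#   trace' = (1 - alpha) * trace + alpha * sparse_embedding
# with alpha clamped to [0, 1]; any shape mismatch leaves the state unchanged.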
def _masked_decode_state(self, state: DecodeState) -> Vector:
assert self.ternary_mask is not None
return apply_ternary_mask(state.combined_state, self.ternary_mask, self.ternary_scale)
def _masked_combined_state(self, combined_state: Vector) -> Vector:
assert self.ternary_mask is not None
return apply_ternary_mask(combined_state, self.ternary_mask, self.ternary_scale)
def _masked_decode_state_array(self, state: DecodeState) -> object:
assert np is not None
if self.ternary_mask_array is None:
return np.asarray(self._masked_decode_state(state), dtype=RUNTIME_ARRAY_DTYPE)
return (
np.asarray(state.combined_state, dtype=RUNTIME_ARRAY_DTYPE)
* self.ternary_scale
* self.ternary_mask_array
)
def _masked_combined_state_array(self, combined_state: Vector) -> object:
assert np is not None
if self.ternary_mask_array is None:
return np.asarray(self._masked_combined_state(combined_state), dtype=RUNTIME_ARRAY_DTYPE)
return (
np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE)
* self.ternary_scale
* self.ternary_mask_array
)
def _center_state_vector(self, state: Vector) -> Vector:
if not self.state_offset or len(self.state_offset) != len(state):
return state
return [value - self.state_offset[index] for index, value in enumerate(state)]
def _center_state_array(self, state: object) -> object:
assert np is not None
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is None or self.state_offset_array.shape != state_array.shape:
return state_array
return state_array - self.state_offset_array
def _zero_combined_state(self) -> Vector:
return [0.0 for _ in range(self._combined_state_width())]
def _combined_state_width(self) -> int:
return (self.config.state_dim + self.config.embedding_dim) * len(self.config.timescales)
def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector:
assert self.embedding_model is not None
assert self.tokenizer is not None
counts = [
float(token_counts.get(token, 0.0))
for token in self.embedding_model.id_to_token
]
positive_counts = sorted(value for value in counts if value > 0.0)
reference = (
positive_counts[len(positive_counts) // 2]
if positive_counts
else 1.0
)
weights: Vector = []
for token, count in zip(self.embedding_model.id_to_token, counts):
if token in TOOL_PROTOCOL_TOKENS:
weights.append(1.0)
elif token in self.tokenizer.special_tokens:
weights.append(0.0)
elif count <= 0.0:
weights.append(1.0)
else:
weight = (reference / count) ** 0.75
weights.append(max(0.08, min(4.8, weight)))
return weights
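# Trace weights follow a smoothed inverse-frequency curve,
# weight = (median_count / count) ** 0.75, clamped to [0.08, 4.8].
# Hypothetical counts: median 40 vs. count 640 gives (40/640) ** 0.75 = 0.125,
# so frequent tokens are attenuated while rare tokens are boosted up to 4.8x;
# tool-protocol tokens stay at 1.0 and other special tokens are zeroed.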
def _token_id_for_token(self, token: str) -> int:
assert self.embedding_model is not None
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None and token.lower() != token:
token_id = self.embedding_model.token_to_id.get(token.lower())
return int(token_id) if token_id is not None else -1
def _trace_embedding_from_token_id(
self,
embedding: Vector | object,
token_id: int,
) -> Vector | object:
if token_id < 0:
return embedding
if self.trace_embedding_table_array is not None:
return self.trace_embedding_table_array[token_id]
weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0
dimension = self.config.embedding_dim
if hasattr(embedding, "shape"):
trace_embedding = embedding * weight
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
bucket = (token_id * bucket_multiplier + bucket_offset) % dimension
sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0
trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign
return trace_embedding
trace_values = [float(value) * weight for value in embedding]
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
bucket = (token_id * bucket_multiplier + bucket_offset) % dimension
sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0
trace_values[bucket] += weight * TRACE_IDENTITY_SCALE * sign
return trace_values
def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None:
if np is None or self.trace_token_weights is None:
return None
values = np.asarray(embedding_array, dtype=np.float64)
if values.size == 0 or len(values.shape) != 2:
return None
weights = np.asarray(self.trace_token_weights, dtype=np.float64)
if weights.shape[0] != values.shape[0]:
return None
trace_values = values * weights[:, None]
if values.shape[1] <= 0:
return trace_values
token_ids = np.arange(values.shape[0], dtype=np.int64)
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype(
np.int64,
copy=False,
)
signs = np.where(
((token_ids * sign_multiplier + sign_offset) & 1) == 0,
1.0,
-1.0,
)
np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs)
return trace_values
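# Vectorized counterpart of _trace_embedding_from_token_id for the full vocab:
# rows start as weight-scaled embeddings, then np.add.at scatters the signed
# TRACE_IDENTITY_SCALE spikes into every token's hash bucket, one vectorized
# scatter per TRACE_IDENTITY_HASHES tuple.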
def _runtime_key_norms_array(
self,
key_array: object | None,
key_norms: list[float] | None,
) -> object | None:
assert np is not None
if key_norms is not None and len(key_norms) > 0:
return np.asarray(key_norms, dtype=RUNTIME_ARRAY_DTYPE)
if key_array is None:
return None
keys = np.asarray(key_array, dtype=RUNTIME_ARRAY_DTYPE)
if len(keys.shape) != 2 or keys.shape[0] == 0:
return None
return np.linalg.norm(keys, axis=1).astype(RUNTIME_ARRAY_DTYPE, copy=False)
def _runtime_vector_cache(self, cached: object | None, length: int) -> object | None:
assert np is not None
if cached is None or not hasattr(cached, "shape"):
return None
array = np.asarray(cached, dtype=RUNTIME_ARRAY_DTYPE)
if len(array.shape) != 1 or int(array.shape[0]) != int(length):
return None
return array
def _runtime_matrix_cache(
self,
cached: object | None,
rows: int,
width: int,
) -> object | None:
assert np is not None
if cached is None or not hasattr(cached, "shape"):
return None
array = np.asarray(cached, dtype=RUNTIME_ARRAY_DTYPE)
if (
len(array.shape) != 2
or int(array.shape[0]) != int(rows)
or int(array.shape[1]) != int(width)
):
return None
return array
def _refresh_numeric_caches(self) -> None:
if np is None:
self.ternary_mask_array = None
self.readout_weights_array = None
self.readout_bias_array = None
self.prompt_answer_weights_array = None
self.prompt_answer_bias_array = None
self.prompt_answer_start_weights_array = None
self.prompt_answer_start_bias_array = None
self.trace_token_weights_array = None
self.trace_embedding_table_array = None
self.preference_bias_array = None
self.preference_valid_mask_array = None
self.state_offset_array = None
self.associative_keys_array = None
self.associative_key_norms_array = None
self.associative_values_array = None
self.associative_valid_mask_array = None
self.answer_keys_array = None
self.answer_key_norms_array = None
self.answer_similarity_keys_array = None
self.answer_similarity_key_norms_array = None
self.answer_similarity_mask_array = None
self.answer_values_array = None
self.answer_valid_mask_array = None
self.answer_start_keys_array = None
self.answer_start_key_norms_array = None
self.answer_start_similarity_keys_array = None
self.answer_start_similarity_key_norms_array = None
self.answer_start_values_array = None
self.answer_start_valid_mask_array = None
self.answer_sequence_keys_array = None
self.answer_sequence_key_norms_array = None
self.answer_sequence_similarity_keys_array = None
self.answer_sequence_similarity_key_norms_array = None
self.answer_sequence_prompt_tokens_array = None
self.answer_sequence_tokens_array = None
self.answer_sequence_prompt_weight_maps = None
self.answer_sequence_prompt_weight_norms = None
self.answer_sequence_prompt_bigram_sets = None
self.answer_sequence_prompt_trigram_sets = None
self.answer_sequence_prompt_number_sets = None
self.answer_sequence_prompt_inverted_index = None
self._refresh_answer_sequence_prompt_overlap_cache()
self.prompt_overlap_valid_token_mask_array = None
return
cached_associative_key_norms_array = self.associative_key_norms_array
cached_answer_key_norms_array = self.answer_key_norms_array
cached_answer_similarity_keys_array = self.answer_similarity_keys_array
cached_answer_similarity_key_norms_array = self.answer_similarity_key_norms_array
cached_answer_start_key_norms_array = self.answer_start_key_norms_array
cached_answer_start_similarity_keys_array = self.answer_start_similarity_keys_array
cached_answer_start_similarity_key_norms_array = self.answer_start_similarity_key_norms_array
cached_answer_sequence_key_norms_array = self.answer_sequence_key_norms_array
cached_answer_sequence_similarity_keys_array = self.answer_sequence_similarity_keys_array
cached_answer_sequence_similarity_key_norms_array = self.answer_sequence_similarity_key_norms_array
self.ternary_mask_array = (
np.asarray(self.ternary_mask, dtype=RUNTIME_ARRAY_DTYPE)
if self.ternary_mask is not None
else None
)
self.readout_weights_array = (
np.asarray(self.readout_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.readout_weights is not None
else None
)
self.readout_bias_array = (
np.asarray(self.readout_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.readout_bias is not None
else None
)
self.prompt_answer_weights_array = (
np.asarray(self.prompt_answer_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_weights is not None
and len(self.prompt_answer_weights) > 0
else None
)
self.prompt_answer_bias_array = (
np.asarray(self.prompt_answer_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_bias is not None
else None
)
self.prompt_answer_start_weights_array = (
np.asarray(self.prompt_answer_start_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_start_weights is not None
and len(self.prompt_answer_start_weights) > 0
else None
)
self.prompt_answer_start_bias_array = (
np.asarray(self.prompt_answer_start_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.prompt_answer_start_bias is not None
else None
)
self.trace_token_weights_array = (
np.asarray(self.trace_token_weights, dtype=RUNTIME_ARRAY_DTYPE)
if self.trace_token_weights is not None
else None
)
trace_embedding_table = (
self._build_trace_embedding_table_array(self.embedding_model.embeddings)
if self.embedding_model is not None and self.trace_token_weights is not None
else None
)
self.trace_embedding_table_array = (
trace_embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False)
if trace_embedding_table is not None
else None
)
self.preference_bias_array = (
np.asarray(self.preference_bias, dtype=RUNTIME_ARRAY_DTYPE)
if self.preference_bias is not None
else None
)
self.preference_valid_mask_array = (
np.asarray(
[
self._eligible_preference_token(token)
for token in self.embedding_model.id_to_token
],
dtype=bool,
)
if self.embedding_model is not None and self.tokenizer is not None
else None
)
self.state_offset_array = (
np.asarray(self.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset is not None
else None
)
self.associative_keys_array = (
np.asarray(self.associative_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.associative_keys is not None and len(self.associative_keys) > 0
else None
)
associative_key_norms_cache = (
self._runtime_vector_cache(
cached_associative_key_norms_array,
int(self.associative_keys_array.shape[0]),
)
if self.associative_keys_array is not None
else None
)
self.associative_key_norms_array = (
associative_key_norms_cache
if associative_key_norms_cache is not None
else self._runtime_key_norms_array(
self.associative_keys_array,
self.associative_key_norms,
)
)
self.associative_values_array = (
np.asarray(self.associative_values, dtype=np.int64)
if self.associative_values is not None and len(self.associative_values) > 0
else None
)
self.associative_valid_mask_array = (
self.associative_values_array >= 0
if self.associative_values_array is not None
else None
)
self.answer_keys_array = (
np.asarray(self.answer_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_keys is not None and len(self.answer_keys) > 0
else None
)
answer_key_norms_cache = (
self._runtime_vector_cache(
cached_answer_key_norms_array,
int(self.answer_keys_array.shape[0]),
)
if self.answer_keys_array is not None
else None
)
self.answer_key_norms_array = (
answer_key_norms_cache
if answer_key_norms_cache is not None
else self._runtime_key_norms_array(
self.answer_keys_array,
self.answer_key_norms,
)
)
self.answer_similarity_keys_array = None
self.answer_similarity_key_norms_array = None
self.answer_similarity_mask_array = None
if self.answer_keys_array is not None and len(self.answer_keys_array.shape) == 2:
width = int(self.answer_keys_array.shape[1])
block_width = self.config.state_dim + self.config.embedding_dim
expected_width = block_width * len(self.config.timescales)
if block_width > 0 and width == expected_width:
mask = np.zeros(width, dtype=RUNTIME_ARRAY_DTYPE)
for scale_index in range(len(self.config.timescales)):
start = scale_index * block_width + self.config.state_dim
end = start + self.config.embedding_dim
mask[start:end] = 1.0
self.answer_similarity_mask_array = mask
answer_similarity_keys_cache = self._runtime_matrix_cache(
cached_answer_similarity_keys_array,
int(self.answer_keys_array.shape[0]),
width,
)
answer_similarity_key_norms_cache = self._runtime_vector_cache(
cached_answer_similarity_key_norms_array,
int(self.answer_keys_array.shape[0]),
)
if (
answer_similarity_keys_cache is not None
and answer_similarity_key_norms_cache is not None
):
self.answer_similarity_keys_array = answer_similarity_keys_cache
self.answer_similarity_key_norms_array = answer_similarity_key_norms_cache
else:
self.answer_similarity_keys_array = self.answer_keys_array * mask[None, :]
self.answer_similarity_key_norms_array = np.linalg.norm(
self.answer_similarity_keys_array,
axis=1,
).astype(RUNTIME_ARRAY_DTYPE, copy=False)
self.answer_values_array = (
np.asarray(self.answer_values, dtype=np.int64)
if self.answer_values is not None and len(self.answer_values) > 0
else None
)
self.answer_valid_mask_array = (
self.answer_values_array >= 0
if self.answer_values_array is not None
else None
)
self.answer_start_keys_array = (
np.asarray(self.answer_start_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_start_keys is not None and len(self.answer_start_keys) > 0
else None
)
answer_start_key_norms_cache = (
self._runtime_vector_cache(
cached_answer_start_key_norms_array,
int(self.answer_start_keys_array.shape[0]),
)
if self.answer_start_keys_array is not None
else None
)
self.answer_start_key_norms_array = (
answer_start_key_norms_cache
if answer_start_key_norms_cache is not None
else self._runtime_key_norms_array(
self.answer_start_keys_array,
self.answer_start_key_norms,
)
)
self.answer_start_similarity_keys_array = None
self.answer_start_similarity_key_norms_array = None
if (
self.answer_start_keys_array is not None
and len(self.answer_start_keys_array.shape) == 2
and self.answer_similarity_mask_array is not None
and int(self.answer_start_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0])
):
answer_start_similarity_keys_cache = self._runtime_matrix_cache(
cached_answer_start_similarity_keys_array,
int(self.answer_start_keys_array.shape[0]),
int(self.answer_start_keys_array.shape[1]),
)
answer_start_similarity_key_norms_cache = self._runtime_vector_cache(
cached_answer_start_similarity_key_norms_array,
int(self.answer_start_keys_array.shape[0]),
)
if (
answer_start_similarity_keys_cache is not None
and answer_start_similarity_key_norms_cache is not None
):
self.answer_start_similarity_keys_array = answer_start_similarity_keys_cache
self.answer_start_similarity_key_norms_array = answer_start_similarity_key_norms_cache
else:
self.answer_start_similarity_keys_array = (
self.answer_start_keys_array * self.answer_similarity_mask_array[None, :]
)
self.answer_start_similarity_key_norms_array = np.linalg.norm(
self.answer_start_similarity_keys_array,
axis=1,
).astype(RUNTIME_ARRAY_DTYPE, copy=False)
self.answer_start_values_array = (
np.asarray(self.answer_start_values, dtype=np.int64)
if self.answer_start_values is not None and len(self.answer_start_values) > 0
else None
)
self.answer_start_valid_mask_array = (
self.answer_start_values_array >= 0
if self.answer_start_values_array is not None
else None
)
self.answer_sequence_keys_array = (
np.asarray(self.answer_sequence_keys, dtype=RUNTIME_ARRAY_DTYPE)
if self.answer_sequence_keys is not None and len(self.answer_sequence_keys) > 0
else None
)
answer_sequence_key_norms_cache = (
self._runtime_vector_cache(
cached_answer_sequence_key_norms_array,
int(self.answer_sequence_keys_array.shape[0]),
)
if self.answer_sequence_keys_array is not None
else None
)
self.answer_sequence_key_norms_array = (
answer_sequence_key_norms_cache
if answer_sequence_key_norms_cache is not None
else self._runtime_key_norms_array(
self.answer_sequence_keys_array,
self.answer_sequence_key_norms,
)
)
self.answer_sequence_similarity_keys_array = None
self.answer_sequence_similarity_key_norms_array = None
if (
self.answer_sequence_keys_array is not None
and len(self.answer_sequence_keys_array.shape) == 2
and self.answer_similarity_mask_array is not None
and int(self.answer_sequence_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0])
):
answer_sequence_similarity_keys_cache = self._runtime_matrix_cache(
cached_answer_sequence_similarity_keys_array,
int(self.answer_sequence_keys_array.shape[0]),
int(self.answer_sequence_keys_array.shape[1]),
)
answer_sequence_similarity_key_norms_cache = self._runtime_vector_cache(
cached_answer_sequence_similarity_key_norms_array,
int(self.answer_sequence_keys_array.shape[0]),
)
if (
answer_sequence_similarity_keys_cache is not None
and answer_sequence_similarity_key_norms_cache is not None
):
self.answer_sequence_similarity_keys_array = answer_sequence_similarity_keys_cache
self.answer_sequence_similarity_key_norms_array = answer_sequence_similarity_key_norms_cache
else:
self.answer_sequence_similarity_keys_array = (
self.answer_sequence_keys_array * self.answer_similarity_mask_array[None, :]
)
self.answer_sequence_similarity_key_norms_array = np.linalg.norm(
self.answer_sequence_similarity_keys_array,
axis=1,
).astype(RUNTIME_ARRAY_DTYPE, copy=False)
self.answer_sequence_tokens_array = (
np.asarray(self.answer_sequence_tokens, dtype=np.int64)
if self.answer_sequence_tokens is not None and len(self.answer_sequence_tokens) > 0
else None
)
self.answer_sequence_prompt_tokens_array = (
np.asarray(self.answer_sequence_prompt_tokens, dtype=np.int64)
if self.answer_sequence_prompt_tokens is not None
and len(self.answer_sequence_prompt_tokens) > 0
else None
)
self.prompt_overlap_valid_token_mask_array = None
# _refresh_answer_sequence_prompt_overlap_cache defers the heavy per-row
# caches itself once the row count exceeds the eager limit, so it can be
# called unconditionally here.
self._refresh_answer_sequence_prompt_overlap_cache()
def _defer_answer_sequence_prompt_overlap_cache(self) -> bool:
if self.answer_sequence_prompt_tokens is None:
return False
try:
row_count = len(self.answer_sequence_prompt_tokens)
except TypeError:
return False
return (
row_count > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT
and np is not None
and self.answer_sequence_prompt_tokens_array is not None
)
def _prompt_overlap_valid_token_mask(self) -> object | None:
if np is None or self.embedding_model is None:
return None
if (
self.prompt_overlap_valid_token_mask_array is not None
and int(self.prompt_overlap_valid_token_mask_array.shape[0]) == len(self.embedding_model.id_to_token)
):
return self.prompt_overlap_valid_token_mask_array
mask = np.fromiter(
(
not self._should_skip_prompt_overlap_token(token)
for token in self.embedding_model.id_to_token
),
dtype=bool,
count=len(self.embedding_model.id_to_token),
)
self.prompt_overlap_valid_token_mask_array = mask
return mask
def _answer_prompt_row_ids_from_array(self) -> tuple[dict[int, list[int]], list[list[int]] | None] | None:
if (
np is None
or self.answer_sequence_prompt_tokens_array is None
or self.trace_token_weights is None
or self.embedding_model is None
):
return None
rows = np.asarray(self.answer_sequence_prompt_tokens_array, dtype=np.int64)
if len(rows.shape) != 2 or rows.size == 0:
return {}, ([] if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT else None)
vocab_size = len(self.trace_token_weights)
if vocab_size <= 0:
return {}, ([] if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT else None)
valid_token_mask = self._prompt_overlap_valid_token_mask()
if valid_token_mask is None:
return None
bounded = (rows >= 0) & (rows < vocab_size)
clipped = np.clip(rows, 0, max(0, vocab_size - 1))
bounded &= valid_token_mask[clipped]
row_positions, column_positions = np.nonzero(bounded)
if row_positions.size == 0:
empty_rows = (
[[] for _ in range(int(rows.shape[0]))]
if rows.shape[0] <= ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT
else None
)
return {}, empty_rows
token_values = rows[row_positions, column_positions].astype(np.int64, copy=False)
order = np.lexsort((row_positions, token_values))
token_values = token_values[order]
row_positions = row_positions[order]
unique = np.ones(token_values.shape[0], dtype=bool)
unique[1:] = (token_values[1:] != token_values[:-1]) | (row_positions[1:] != row_positions[:-1])
token_values = token_values[unique]
row_positions = row_positions[unique]
boundaries = np.flatnonzero(token_values[1:] != token_values[:-1]) + 1
token_groups = np.split(token_values, boundaries)
row_groups = np.split(row_positions, boundaries)
inverted = {
int(token_group[0]): row_group.astype(np.int64, copy=False).tolist()
for token_group, row_group in zip(token_groups, row_groups)
if token_group.size
}
if rows.shape[0] > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT:
return inverted, None
# Rebuild per-row ids in sequence order; the sorted unique pairs above
# would scramble the bigram/trigram caches built from these lists.
row_id_lists: list[list[int]] = [
[int(value) for value, keep in zip(row_values, row_mask) if keep]
for row_values, row_mask in zip(rows.tolist(), bounded.tolist())
]
return inverted, row_id_lists
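# Construction note: the (row, token) pairs are lexsorted by token id,
# de-duplicated, and split at token boundaries, producing the inverted index
# token_id -> row indices without a Python-level loop over the matrix.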
def _refresh_answer_sequence_prompt_overlap_cache(self) -> None:
self.answer_sequence_prompt_weight_maps = None
self.answer_sequence_prompt_weight_norms = None
self.answer_sequence_prompt_bigram_sets = None
self.answer_sequence_prompt_trigram_sets = None
self.answer_sequence_prompt_number_sets = None
self.answer_sequence_prompt_inverted_index = None
self.answer_sequence_prompt_specificity = None
if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None:
return
array_index = self._answer_prompt_row_ids_from_array()
if array_index is not None:
inverted, row_id_lists = array_index
total_rows = (
int(self.answer_sequence_prompt_tokens_array.shape[0])
if self.answer_sequence_prompt_tokens_array is not None
else len(row_id_lists or [])
)
else:
inverted = {}
row_id_lists = []
for row in self.answer_sequence_prompt_tokens:
row_values = row.tolist() if hasattr(row, "tolist") else row
row_ids: list[int] = []
for raw_token_id in row_values:
token_id = int(raw_token_id)
if token_id < 0 or token_id >= len(self.trace_token_weights):
continue
if self.embedding_model is not None and self._should_skip_prompt_overlap_token(
self.embedding_model.id_to_token[token_id]
):
continue
row_ids.append(token_id)
sequence_index = len(row_id_lists)
for token_id in set(row_ids):
inverted.setdefault(token_id, []).append(sequence_index)
row_id_lists.append(row_ids)
total_rows = len(row_id_lists)
specificity = {
token_id: self._prompt_overlap_token_specificity(len(indices), total_rows)
for token_id, indices in inverted.items()
}
self.answer_sequence_prompt_inverted_index = inverted
self.answer_sequence_prompt_specificity = specificity
if total_rows > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT:
return
if row_id_lists is None:
return
weight_maps: list[dict[int, float]] = []
weight_norms: list[float] = []
bigram_sets: list[set[tuple[int, int]]] = []
trigram_sets: list[set[tuple[int, int, int]]] = []
number_sets: list[set[str]] = []
for row_index, row_ids in enumerate(row_id_lists):
row_weights: dict[int, float] = {}
for token_id in row_ids:
row_weights[token_id] = max(
row_weights.get(token_id, 0.0),
float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0),
)
weight_maps.append(row_weights)
weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5)
bigram_sets.append(
{
(row_ids[index], row_ids[index + 1])
for index in range(len(row_ids) - 1)
}
)
trigram_sets.append(
{
(row_ids[index], row_ids[index + 1], row_ids[index + 2])
for index in range(len(row_ids) - 2)
}
)
raw_row = self.answer_sequence_prompt_tokens[row_index]
raw_values = raw_row.tolist() if hasattr(raw_row, "tolist") else raw_row
raw_ids = [
int(value)
for value in raw_values
if 0 <= int(value) < len(self.embedding_model.id_to_token)
]
number_sets.append(self._number_strings_from_token_ids(raw_ids))
self.answer_sequence_prompt_weight_maps = weight_maps
self.answer_sequence_prompt_weight_norms = weight_norms
self.answer_sequence_prompt_bigram_sets = bigram_sets
self.answer_sequence_prompt_trigram_sets = trigram_sets
self.answer_sequence_prompt_number_sets = number_sets
@staticmethod
def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float:
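"""Map document frequency to a specificity weight.

Tokens that appear in most prompts score near the 0.02 floor while rare
tokens score near 1.0; e.g. df=4 of 16 rows gives 1 - sqrt(0.25) = 0.5.
Non-positive inputs fall back to 1.0.
"""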
if document_frequency <= 0 or total_documents <= 0:
return 1.0
coverage = min(1.0, document_frequency / total_documents)
return max(0.02, 1.0 - (coverage ** 0.5))
def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]:
assert self.embedding_model is not None
tokens = [
self.embedding_model.id_to_token[token_id]
for token_id in token_ids
if 0 <= token_id < len(self.embedding_model.id_to_token)
]
return self._number_strings_from_tokens(tokens)
def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]:
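"""Collect digit strings from a token stream.

Digits in word-initial tokens start a new number, digits in
continuation subwords extend an open run, and any non-digit or special
token flushes the run, so numerals split across subword tokens are
recovered intact.
"""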
numbers: set[str] = set()
current = ""
for token in tokens:
if self.tokenizer is not None and token in self.tokenizer.special_tokens:
if current:
numbers.add(current)
current = ""
continue
rendered = self._render_token(token)
digits = "".join(character for character in rendered if character.isdigit())
starts_number = self._starts_new_word(token) if self.tokenizer is not None else True
if digits and starts_number:
if current:
numbers.add(current)
current = digits
elif digits and current:
current += digits
else:
if current:
numbers.add(current)
current = ""
if current:
numbers.add(current)
return numbers
@staticmethod
def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool:
if not query_numbers:
return True
if not row_numbers:
return False
return query_numbers.issubset(row_numbers)
def _vector_answer_sequence_candidate_indices(
self,
query_token_ids: object,
) -> list[int] | None:
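"""Vectorised candidate lookup for the deferred overlap path.

Returns the prompt rows that share at least one token id with the
query, or None when NumPy or the 2-D prompt-token matrix is
unavailable; an empty query yields an empty list.
"""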
if (
np is None
or self.answer_sequence_prompt_tokens_array is None
or not hasattr(self.answer_sequence_prompt_tokens_array, "shape")
):
return None
query_ids = np.asarray(list(query_token_ids), dtype=np.int64)
if query_ids.size == 0:
return []
prompt_array = self.answer_sequence_prompt_tokens_array
if len(prompt_array.shape) != 2 or prompt_array.shape[0] == 0:
return None
mask = np.isin(prompt_array, query_ids).any(axis=1)
return [int(index) for index in np.flatnonzero(mask)]
def _vector_answer_sequence_local_frequency(
self,
token_id: int,
candidate_indices: list[int],
) -> int | None:
if (
np is None
or self.answer_sequence_prompt_tokens_array is None
or not hasattr(self.answer_sequence_prompt_tokens_array, "shape")
or not candidate_indices
):
return None
rows = self.answer_sequence_prompt_tokens_array[
np.asarray(candidate_indices, dtype=np.int64)
]
return int(np.any(rows == int(token_id), axis=1).sum())
def _apply_readout_fast(self, state: Vector) -> Vector:
if self.readout_weights_array is None or np is None:
assert self.readout_weights is not None
centered_state = self._center_state_vector(state)
logits = apply_readout(self.readout_weights, centered_state)
if self.readout_bias:
logits = [
value + self.readout_bias[index]
for index, value in enumerate(logits)
]
return logits
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape:
state_array = state_array - self.state_offset_array
logits = self.readout_weights_array @ state_array
if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape:
logits = logits + self.readout_bias_array
return logits.tolist()
def _apply_readout_array(self, state: object) -> object:
assert np is not None
assert self.readout_weights_array is not None
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape:
state_array = state_array - self.state_offset_array
logits = self.readout_weights_array @ state_array
if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape:
logits = logits + self.readout_bias_array
return logits
def _step_hidden_states(
self,
hidden_states: list[Vector],
context_traces: list[Vector],
token: str,
) -> tuple[list[Vector], list[Vector], Vector]:
assert self.embedding_model is not None
assert self.tokenizer is not None
token_id = self._token_id_for_token(token)
embedding = self.embedding_model.vector(token)
trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
return self._step_hidden_states_from_embedding(
hidden_states,
context_traces,
embedding,
trace_embedding=trace_embedding,
)
def _step_hidden_states_from_embedding(
self,
hidden_states: list[Vector],
context_traces: list[Vector],
embedding: Vector | object,
*,
trace_embedding: Vector | object | None = None,
) -> tuple[list[Vector], list[Vector], Vector]:
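"""Advance every analytical memory unit by one token.

Each unit integrates the embedding drive, and its context trace
accumulates the trace embedding scaled by timescale / (1 + timescale),
i.e. 1 - decay. The combined state concatenates state then trace, unit
by unit; a NumPy fast path is taken when the hidden states are arrays.
"""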
assert self.memory_units is not None
if trace_embedding is None:
trace_embedding = embedding
if np is not None and hidden_states and hasattr(hidden_states[0], "shape"):
embedding_array = (
embedding
if hasattr(embedding, "shape")
else np.asarray(embedding, dtype=np.float64)
)
trace_embedding_array = (
trace_embedding
if hasattr(trace_embedding, "shape")
else np.asarray(trace_embedding, dtype=np.float64)
)
drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim)
next_states: list[Vector] = []
next_traces: list[Vector] = []
combined_state: Vector = []
for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
next_state = unit.step_vector_fast(state, drive)
decay = 1.0 / (1.0 + unit.timescale)
next_trace = trace + ((1.0 - decay) * trace_embedding_array)
next_states.append(next_state)
next_traces.append(next_trace)
combined_state.extend(next_state.tolist())
combined_state.extend(next_trace.tolist())
return next_states, next_traces, combined_state
embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding
trace_embedding_vector = (
trace_embedding.tolist()
if hasattr(trace_embedding, "tolist")
else trace_embedding
)
drive = analytical_embedding_drive(embedding_vector, self.config.state_dim)
next_states: list[Vector] = []
next_traces: list[Vector] = []
combined_state: Vector = []
for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
next_state = unit.step_vector(state, drive)
decay = 1.0 / (1.0 + unit.timescale)
next_trace = [
previous + ((1.0 - decay) * value)
for previous, value in zip(trace, trace_embedding_vector)
]
next_states.append(next_state)
next_traces.append(next_trace)
combined_state.extend(next_state)
combined_state.extend(next_trace)
return next_states, next_traces, combined_state
def _one_hot(self, token: str) -> Vector:
assert self.embedding_model is not None
return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1))
def _one_hot_from_id(self, token_id: int) -> Vector:
assert self.embedding_model is not None
vector = [0.0 for _ in self.embedding_model.id_to_token]
if token_id >= 0:
vector[token_id] = 1.0
return vector
def _blend_probabilities(
self,
base: Vector,
answer: Vector,
associative: Vector,
transition: Vector,
copy: Vector,
source_evidence: Vector,
preference: Vector,
*,
transition_order: int | None,
generated_count: int = 0,
answer_locked: bool = False,
answer_guided_start: bool = False,
copy_guided_start: bool = False,
) -> tuple[Vector, dict[str, float]]:
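"""Blend the seven candidate distributions into one next-token prior.

Base weights come from the FAST_* constants and are re-scaled by the
decode regime (answer lock, guided starts, continuation, source
grounding, transition order). The base prior always participates; the
other six join only when they carry positive mass. Participating
weights are normalised to sum to one, and the blended distribution is
renormalised before being returned alongside the per-source weights
actually used.
"""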
base_weight = FAST_BASE_BLEND
answer_weight = FAST_ANSWER_BLEND
associative_weight = FAST_ASSOCIATIVE_BLEND
transition_weight = FAST_TRANSITION_BLEND
copy_weight = FAST_COPY_BLEND
source_evidence_weight = FAST_SOURCE_EVIDENCE_BLEND
preference_weight = FAST_PREFERENCE_BLEND
source_grounded = any(value > 0.0 for value in source_evidence)
if answer_locked:
base_weight *= 0.005
answer_weight *= 250.0
associative_weight *= 0.05
transition_weight *= 0.005
copy_weight *= 0.005
source_evidence_weight *= 0.05
preference_weight *= 0.05
elif answer_guided_start:
base_weight *= 0.45
answer_weight *= 3.1
associative_weight *= 0.2
transition_weight *= 0.35
copy_weight *= 0.2
source_evidence_weight *= 1.1
preference_weight *= 0.2
elif copy_guided_start:
base_weight *= 0.55
answer_weight *= 0.35
associative_weight *= 0.4
transition_weight *= 0.35
copy_weight *= 4.5
preference_weight *= 0.6
elif generated_count > 0:
answer_weight *= 0.32
transition_weight *= 2.0
copy_weight *= 0.75
source_evidence_weight *= 0.85
if source_grounded:
base_weight *= 0.45
answer_weight *= 0.35
associative_weight *= 0.50
transition_weight *= 0.25
copy_weight *= 0.50
source_evidence_weight *= 3.50
if source_grounded:
base_weight *= 0.60
answer_weight *= 0.35
associative_weight *= 0.50
transition_weight *= 0.80
copy_weight *= 0.20
source_evidence_weight *= 1.80
else:
source_evidence_weight = 0.0
if transition_order is None:
answer_weight *= 1.1
associative_weight *= 0.75
copy_weight += 0.02
elif transition_order <= 2:
answer_weight *= 1.15
associative_weight *= 0.65
transition_weight *= 0.55
copy_weight += 0.01
elif transition_order >= 5:
transition_weight *= 1.25
sources: list[tuple[str, float, Vector]] = [("base", base_weight, base)]
if any(value > 0.0 for value in answer):
sources.append(("answer", answer_weight, answer))
if any(value > 0.0 for value in associative):
sources.append(("associative", associative_weight, associative))
if any(value > 0.0 for value in transition):
sources.append(("transition", transition_weight, transition))
if any(value > 0.0 for value in copy):
sources.append(("copy", copy_weight, copy))
if any(value > 0.0 for value in source_evidence):
sources.append(("source_evidence", source_evidence_weight, source_evidence))
if any(value > 0.0 for value in preference):
sources.append(("preference", preference_weight, preference))
total_weight = sum(weight for _, weight, _ in sources)
blended = [0.0 for _ in base]
blend_weights: dict[str, float] = {}
for name, weight, source in sources:
normalized_weight = weight / total_weight if total_weight else 0.0
blend_weights[name] = normalized_weight
for index, value in enumerate(source):
blended[index] += normalized_weight * value
return _normalize_vector(blended), blend_weights
def _blend_probability_arrays(
self,
base: object,
answer: object,
associative: object,
transition: object,
copy: object,
source_evidence: object,
preference: object,
*,
transition_order: int | None,
generated_count: int = 0,
answer_locked: bool = False,
answer_guided_start: bool = False,
copy_guided_start: bool = False,
) -> tuple[object, dict[str, float]]:
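"""NumPy mirror of _blend_probabilities; falls back to the unblended
base distribution when the blend collapses to zero mass."""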
assert np is not None
base_weight = FAST_BASE_BLEND
answer_weight = FAST_ANSWER_BLEND
associative_weight = FAST_ASSOCIATIVE_BLEND
transition_weight = FAST_TRANSITION_BLEND
copy_weight = FAST_COPY_BLEND
source_evidence_weight = FAST_SOURCE_EVIDENCE_BLEND
preference_weight = FAST_PREFERENCE_BLEND
source_grounded = bool(np.any(source_evidence > 0.0))
if answer_locked:
base_weight *= 0.005
answer_weight *= 250.0
associative_weight *= 0.05
transition_weight *= 0.005
copy_weight *= 0.005
source_evidence_weight *= 0.05
preference_weight *= 0.05
elif answer_guided_start:
base_weight *= 0.45
answer_weight *= 3.1
associative_weight *= 0.2
transition_weight *= 0.35
copy_weight *= 0.2
source_evidence_weight *= 1.1
preference_weight *= 0.2
elif copy_guided_start:
base_weight *= 0.55
answer_weight *= 0.35
associative_weight *= 0.4
transition_weight *= 0.35
copy_weight *= 4.5
preference_weight *= 0.6
elif generated_count > 0:
answer_weight *= 0.32
transition_weight *= 2.0
copy_weight *= 0.75
source_evidence_weight *= 0.85
if source_grounded:
base_weight *= 0.45
answer_weight *= 0.35
associative_weight *= 0.50
transition_weight *= 0.25
copy_weight *= 0.50
source_evidence_weight *= 3.50
if source_grounded:
base_weight *= 0.60
answer_weight *= 0.35
associative_weight *= 0.50
transition_weight *= 0.80
copy_weight *= 0.20
source_evidence_weight *= 1.80
else:
source_evidence_weight = 0.0
if transition_order is None:
answer_weight *= 1.1
associative_weight *= 0.75
copy_weight += 0.02
elif transition_order <= 2:
answer_weight *= 1.15
associative_weight *= 0.65
transition_weight *= 0.55
copy_weight += 0.01
elif transition_order >= 5:
transition_weight *= 1.25
sources: list[tuple[str, float, object]] = [("base", base_weight, base)]
if np.any(answer > 0.0):
sources.append(("answer", answer_weight, answer))
if np.any(associative > 0.0):
sources.append(("associative", associative_weight, associative))
if np.any(transition > 0.0):
sources.append(("transition", transition_weight, transition))
if np.any(copy > 0.0):
sources.append(("copy", copy_weight, copy))
if np.any(source_evidence > 0.0):
sources.append(("source_evidence", source_evidence_weight, source_evidence))
if np.any(preference > 0.0):
sources.append(("preference", preference_weight, preference))
total_weight = sum(weight for _, weight, _ in sources)
blended = np.zeros_like(base, dtype=np.float64)
blend_weights: dict[str, float] = {}
for name, weight, source in sources:
normalized_weight = weight / total_weight if total_weight else 0.0
blend_weights[name] = normalized_weight
blended += normalized_weight * source
total = float(blended.sum())
if total <= 0.0:
return base, blend_weights
return blended / total, blend_weights
def _score_associative_matches(
self,
state: Vector,
*,
limit: int = ASSOCIATIVE_TOP_K,
) -> list[tuple[float, int, int]]:
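"""Rank associative-memory keys by cosine similarity to the state.

Centres the state, keeps strictly positive similarities over valid
keys, and returns up to `limit` (score, token_id, example_index)
triples, using the vectorised path when the key arrays are loaded and
a pure-Python scan otherwise.
"""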
if (
self.associative_keys is None
or self.associative_values is None
or len(self.associative_keys) == 0
or len(self.associative_values) == 0
):
return []
if (
np is not None
and self.associative_keys_array is not None
and self.associative_key_norms_array is not None
and self.associative_values_array is not None
and self.associative_valid_mask_array is not None
and limit > 0
):
state_array = self._center_state_array(state).astype(self.associative_keys_array.dtype, copy=False)
state_norm = float(np.linalg.norm(state_array))
if state_norm == 0.0:
return []
numerators = self.associative_keys_array @ state_array
denominators = self.associative_key_norms_array * state_norm
valid_mask = self.associative_valid_mask_array & (denominators > 0.0)
if np.any(valid_mask):
scores = np.zeros_like(numerators, dtype=self.associative_keys_array.dtype)
np.divide(numerators, denominators, out=scores, where=valid_mask)
positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
if positive_positions.size:
selected_positions = positive_positions
if positive_positions.size > limit:
partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
selected_positions = positive_positions[partition]
ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]]
return [
(
float(scores[position]),
int(self.associative_values_array[position]),
int(position),
)
for position in ordered_positions
]
if self.associative_key_norms is None or len(self.associative_key_norms) == 0:
return []
state = self._center_state_vector(state)
state_norm = norm(state)
if state_norm == 0.0:
return []
scored: list[tuple[float, int, int]] = []
for example_index, (key, key_norm, token_id) in enumerate(
zip(self.associative_keys, self.associative_key_norms, self.associative_values)
):
if token_id < 0:
continue
denominator = state_norm * key_norm
if denominator == 0.0:
continue
similarity = dot(state, key) / denominator
if similarity > 0.0:
scored.append((similarity, token_id, example_index))
scored.sort(key=lambda item: item[0], reverse=True)
return scored[:limit]
def _associative_prior_from_matches(
self,
matches: list[tuple[float, int, int]],
) -> Vector:
assert self.embedding_model is not None
if not matches:
return [0.0 for _ in self.embedding_model.id_to_token]
prior = [0.0 for _ in self.embedding_model.id_to_token]
for similarity, token_id, _ in matches[:ASSOCIATIVE_TOP_K]:
prior[token_id] += similarity
return _normalize_vector(prior)
def _associative_prior(self, state: Vector) -> Vector:
return self._associative_prior_from_matches(self._score_associative_matches(state))
def _score_answer_matches(
self,
answer_anchor_state: Vector | None,
*,
limit: int = ANSWER_TOP_K,
) -> list[tuple[float, int, int]]:
return self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_keys,
self.answer_key_norms,
self.answer_values,
self.answer_keys_array,
self.answer_key_norms_array,
self.answer_values_array,
self.answer_valid_mask_array,
self.answer_similarity_keys_array,
self.answer_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=limit,
)
def _score_answer_start_matches(
self,
answer_anchor_state: Vector | None,
*,
limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
matches = self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_start_keys,
self.answer_start_key_norms,
self.answer_start_values,
self.answer_start_keys_array,
self.answer_start_key_norms_array,
self.answer_start_values_array,
self.answer_start_valid_mask_array,
self.answer_start_similarity_keys_array,
self.answer_start_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=limit,
)
if matches:
return matches
return self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_start_keys,
self.answer_start_key_norms,
self.answer_start_values,
self.answer_start_keys_array,
self.answer_start_key_norms_array,
self.answer_start_values_array,
self.answer_start_valid_mask_array,
None,
None,
None,
limit=limit,
)
def _score_answer_sequence_matches(
self,
answer_anchor_state: Vector | None,
context_tokens: list[str],
*,
limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
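"""Rank stored answer sequences for the current prompt.

Anchor-state similarity is merged with lexical prompt overlap: rows
within 90% of the best overlap (floor 0.16) focus the candidate set,
anchor similarity contributes at most 0.20 per sequence, and overlap
contributes the remaining 0.80 before the top `limit` sequences are
returned. When overlap scoring is unavailable the anchor ranking alone
is returned.
"""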
if (
answer_anchor_state is None
or self.answer_sequence_keys is None
or self.answer_sequence_key_norms is None
or self.answer_sequence_tokens is None
):
return []
values = list(range(len(self.answer_sequence_tokens)))
values_array = np.arange(len(values), dtype=np.int64) if np is not None else None
anchor_matches = self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_sequence_keys,
self.answer_sequence_key_norms,
values,
self.answer_sequence_keys_array,
self.answer_sequence_key_norms_array,
values_array,
values_array >= 0 if values_array is not None else None,
self.answer_sequence_similarity_keys_array,
self.answer_sequence_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=max(limit * 4, limit),
)
overlap_scores = self._answer_sequence_prompt_overlap_scores(context_tokens)
if overlap_scores is None:
return anchor_matches[:limit]
if not overlap_scores:
return []
best_overlap = max(overlap_scores.values())
overlap_floor = max(0.16, best_overlap * 0.90)
focused_overlap_scores = {
sequence_index: overlap
for sequence_index, overlap in overlap_scores.items()
if overlap >= overlap_floor
}
if not focused_overlap_scores:
focused_overlap_scores = overlap_scores
focused_indices = set(focused_overlap_scores)
merged: dict[int, float] = {}
for similarity, sequence_index, _ in anchor_matches:
if sequence_index not in focused_indices:
continue
merged[sequence_index] = max(merged.get(sequence_index, 0.0), 0.20 * similarity)
for sequence_index, overlap in focused_overlap_scores.items():
merged[sequence_index] = merged.get(sequence_index, 0.0) + (0.80 * overlap)
ranked = [
(score, sequence_index, sequence_index)
for sequence_index, score in merged.items()
if score > 0.0
]
ranked.sort(key=lambda item: item[0], reverse=True)
return ranked[:limit]
def _answer_sequence_prompt_overlap_scores(
self,
context_tokens: list[str],
) -> dict[int, float] | None:
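"""Score stored prompts by lexical overlap with the live prompt.

Tokens before the <answer> marker are weighted by corpus specificity
and by segment position, with tool-evidence spans down-weighted by
0.35; the query's numbers must be a subset of a row's numbers for it
to qualify. Candidates come from the inverted index (or the vectorised
fallback), specificity is re-derived over that local pool, and rows
failing the coverage gates are dropped. Each survivor scores
0.35 * weighted cosine + 0.35 * query coverage + 0.15 * bigram +
0.15 * trigram, scaled by a length-fit penalty. Note the uncached
fallback path below applies looser coverage floors than the cached
scorer.
"""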
if (
self.embedding_model is None
or self.answer_sequence_prompt_tokens is None
or self.trace_token_weights is None
):
return None
answer_boundary = _last_index(context_tokens, "<answer>")
prompt_tokens = (
context_tokens[:answer_boundary]
if answer_boundary is not None
else context_tokens
)
if (
self.answer_sequence_prompt_specificity is None
and not self._defer_answer_sequence_prompt_overlap_cache()
):
self._refresh_answer_sequence_prompt_overlap_cache()
specificity_map = self.answer_sequence_prompt_specificity or {}
query_weights: dict[int, float] = {}
query_specificity: dict[int, float] = {}
query_segment_multipliers: dict[int, float] = {}
query_content_weight = 0.0
query_ids: list[int] = []
primary_query_ids: list[int] = []
inside_tool_evidence = False
prompt_segment_index = 0
for token in prompt_tokens:
if token in {"<tool_result>", "<source>"}:
inside_tool_evidence = True
continue
if token == "<final>":
inside_tool_evidence = False
continue
if self.tokenizer is not None and token in self.tokenizer.special_tokens:
continue
if self._is_structural_punctuation_token(token):
prompt_segment_index += 1
continue
if self._should_skip_prompt_overlap_token(token):
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
query_ids.append(token_id)
specificity = specificity_map.get(token_id, 1.0)
evidence_multiplier = 0.35 if inside_tool_evidence else 1.0
segment_multiplier = evidence_multiplier / (1.0 + prompt_segment_index)
weight = specificity * segment_multiplier
query_weights[token_id] = max(
query_weights.get(token_id, 0.0),
weight,
)
query_specificity[token_id] = max(
query_specificity.get(token_id, 0.0),
specificity,
)
query_segment_multipliers[token_id] = max(
query_segment_multipliers.get(token_id, 0.0),
segment_multiplier,
)
if not inside_tool_evidence:
primary_query_ids.append(token_id)
if specificity >= 0.20:
query_content_weight += weight
if not query_weights:
return None
full_query_token_ids = set(query_ids)
primary_query_token_ids = set(primary_query_ids)
has_tool_evidence = any(token in {"<tool_result>", "<source>"} for token in prompt_tokens)
query_norm = sum(value * value for value in query_weights.values()) ** 0.5
if query_norm <= 0.0:
return None
query_bigrams = {
(query_ids[index], query_ids[index + 1])
for index in range(len(query_ids) - 1)
}
query_trigrams = {
(query_ids[index], query_ids[index + 1], query_ids[index + 2])
for index in range(len(query_ids) - 2)
}
query_numbers = self._number_strings_from_tokens(prompt_tokens)
def ordered_ngram_score(
query_grams: set[tuple[int, ...]],
row_grams: set[tuple[int, ...]],
) -> float:
if not query_grams or not row_grams:
return 0.0
overlap = len(query_grams & row_grams)
if overlap <= 0:
return 0.0
return overlap / ((len(query_grams) * len(row_grams)) ** 0.5)
def prompt_length_fit(row_token_count: int) -> float:
query_token_count = len(full_query_token_ids)
if query_token_count <= 0 or row_token_count <= 0:
return 1.0
if row_token_count <= query_token_count:
return 1.0
extra_fraction = (row_token_count - query_token_count) / row_token_count
return max(0.25, 1.0 - extra_fraction)
cached_maps = self.answer_sequence_prompt_weight_maps
cached_norms = self.answer_sequence_prompt_weight_norms
cached_bigrams = self.answer_sequence_prompt_bigram_sets
cached_trigrams = self.answer_sequence_prompt_trigram_sets
cached_numbers = self.answer_sequence_prompt_number_sets
cached_index = self.answer_sequence_prompt_inverted_index
if (
cached_maps is not None
and cached_norms is not None
and cached_bigrams is not None
and cached_trigrams is not None
and cached_numbers is not None
and len(cached_maps) == len(self.answer_sequence_prompt_tokens)
):
candidate_indices: set[int] | range | list[int]
if cached_index is not None:
candidates: set[int] = set()
ranked_query_ids = sorted(
query_weights,
key=lambda token_id: specificity_map.get(token_id, 1.0),
reverse=True,
)
distinctive_query_ids = [
token_id
for token_id in ranked_query_ids
if specificity_map.get(token_id, 1.0) >= 0.75
] or ranked_query_ids[:4]
for token_id in distinctive_query_ids:
candidates.update(cached_index.get(token_id, ()))
candidate_indices = candidates if candidates else range(len(cached_maps))
else:
candidate_indices = range(len(cached_maps))
candidate_indices = list(candidate_indices)
if cached_index is not None and candidate_indices:
candidate_set = set(candidate_indices)
local_query_weights: dict[int, float] = {}
local_query_specificity: dict[int, float] = {}
local_query_content_weight = 0.0
for token_id in query_weights:
local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
if local_frequency <= 0:
continue
specificity = self._prompt_overlap_token_specificity(
local_frequency,
len(candidate_indices),
)
weight = specificity * query_segment_multipliers.get(token_id, 1.0)
local_query_weights[token_id] = weight
local_query_specificity[token_id] = specificity
if specificity >= 0.20:
local_query_content_weight += weight
local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
if local_query_norm > 0.0:
query_weights = local_query_weights
query_specificity = local_query_specificity
query_norm = local_query_norm
scores: dict[int, float] = {}
for sequence_index in candidate_indices:
row_weights = cached_maps[sequence_index]
if not row_weights:
continue
if query_numbers and not self._numeric_prompt_can_match(
query_numbers,
cached_numbers[sequence_index],
):
continue
matched_content_weight = sum(
query_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
if query_specificity.get(token_id, 0.0) >= 0.20
)
row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
1,
len(row_weights),
)
full_query_coverage = len(full_query_token_ids & row_weights.keys()) / max(
1,
len(full_query_token_ids),
)
primary_query_coverage = len(primary_query_token_ids & row_weights.keys()) / max(
1,
len(primary_query_token_ids),
)
if (
has_tool_evidence
and len(primary_query_token_ids) >= 3
and primary_query_coverage < 0.45
and row_token_coverage < 0.75
):
continue
partial_query_floor = 0.60 if len(full_query_token_ids) < 8 else 0.50
if (
len(full_query_token_ids) >= 5
and full_query_coverage <= partial_query_floor
and row_token_coverage < 0.75
):
continue
if (
len(full_query_token_ids) >= 12
and full_query_coverage < 0.45
and row_token_coverage <= 0.75
):
continue
if (
query_content_weight > 0.0
and matched_content_weight / query_content_weight < 0.40
and row_token_coverage < 0.75
and full_query_coverage < 0.60
):
continue
query_coverage = (
matched_content_weight / query_content_weight
if query_content_weight > 0.0
else row_token_coverage
)
numerator = sum(
query_weights[token_id] * row_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
)
if numerator <= 0.0:
continue
row_norm = cached_norms[sequence_index]
if row_norm <= 0.0:
continue
token_score = numerator / (query_norm * row_norm)
bigram_score = ordered_ngram_score(
query_bigrams,
cached_bigrams[sequence_index],
)
trigram_score = ordered_ngram_score(
query_trigrams,
cached_trigrams[sequence_index],
)
scores[sequence_index] = (
(0.35 * token_score)
+ (0.35 * query_coverage)
+ (0.15 * bigram_score)
+ (0.15 * trigram_score)
) * prompt_length_fit(len(row_weights))
return scores
vector_candidate_indices: list[int] | None = None
if cached_index is not None:
candidate_set: set[int] = set()
ranked_query_ids = sorted(
query_weights,
key=lambda token_id: specificity_map.get(token_id, 1.0),
reverse=True,
)
distinctive_query_ids = [
token_id
for token_id in ranked_query_ids
if specificity_map.get(token_id, 1.0) >= 0.75
] or ranked_query_ids[:4]
for token_id in distinctive_query_ids:
candidate_set.update(cached_index.get(token_id, ()))
if not candidate_set:
for token_id in ranked_query_ids:
candidate_set.update(cached_index.get(token_id, ()))
if candidate_set:
break
if not candidate_set:
candidate_indices = range(len(self.answer_sequence_prompt_tokens))
else:
candidate_indices = sorted(candidate_set)
local_query_weights: dict[int, float] = {}
local_query_specificity: dict[int, float] = {}
local_query_content_weight = 0.0
candidate_count = len(candidate_indices)
for token_id in query_weights:
local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
if local_frequency <= 0:
continue
specificity = self._prompt_overlap_token_specificity(
local_frequency,
candidate_count,
)
local_query_weights[token_id] = specificity * query_segment_multipliers.get(token_id, 1.0)
local_query_specificity[token_id] = specificity
if specificity >= 0.20:
local_query_content_weight += local_query_weights[token_id]
local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
if local_query_norm > 0.0:
query_weights = local_query_weights
query_specificity = local_query_specificity
query_norm = local_query_norm
elif self._defer_answer_sequence_prompt_overlap_cache():
vector_candidate_indices = self._vector_answer_sequence_candidate_indices(
query_weights.keys()
)
if vector_candidate_indices is not None:
if not vector_candidate_indices:
return {}
candidate_indices = vector_candidate_indices
local_query_weights = {}
local_query_specificity = {}
local_query_content_weight = 0.0
candidate_count = len(vector_candidate_indices)
for token_id in query_weights:
local_frequency = self._vector_answer_sequence_local_frequency(
token_id,
vector_candidate_indices,
)
if local_frequency is None or local_frequency <= 0:
continue
specificity = self._prompt_overlap_token_specificity(
local_frequency,
candidate_count,
)
local_query_weights[token_id] = specificity * query_segment_multipliers.get(token_id, 1.0)
local_query_specificity[token_id] = specificity
if specificity >= 0.20:
local_query_content_weight += local_query_weights[token_id]
local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
if local_query_norm > 0.0:
query_weights = local_query_weights
query_specificity = local_query_specificity
query_norm = local_query_norm
else:
candidate_indices = range(len(self.answer_sequence_prompt_tokens))
valid_token_mask = self._prompt_overlap_valid_token_mask()
scores: dict[int, float] = {}
for sequence_index in candidate_indices:
row = self.answer_sequence_prompt_tokens[sequence_index]
row_values = row.tolist() if hasattr(row, "tolist") else row
row_weights: dict[int, float] = {}
row_ids: list[int] = []
raw_row_ids: list[int] = []
for raw_token_id in row_values:
token_id = int(raw_token_id)
if token_id < 0 or token_id >= len(self.trace_token_weights):
continue
raw_row_ids.append(token_id)
if valid_token_mask is not None:
if token_id >= len(valid_token_mask) or not bool(valid_token_mask[token_id]):
continue
elif self._should_skip_prompt_overlap_token(
self.embedding_model.id_to_token[token_id]
):
continue
row_ids.append(token_id)
row_weights[token_id] = max(
row_weights.get(token_id, 0.0),
specificity_map.get(token_id, 1.0),
)
if not row_weights:
continue
if query_numbers and not self._numeric_prompt_can_match(
query_numbers,
self._number_strings_from_token_ids(raw_row_ids),
):
continue
matched_content_weight = sum(
query_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
if query_specificity.get(token_id, 0.0) >= 0.20
)
row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
1,
len(row_weights),
)
full_query_coverage = len(full_query_token_ids & row_weights.keys()) / max(
1,
len(full_query_token_ids),
)
primary_query_coverage = len(primary_query_token_ids & row_weights.keys()) / max(
1,
len(primary_query_token_ids),
)
if (
has_tool_evidence
and len(primary_query_token_ids) >= 3
and primary_query_coverage < 0.45
and row_token_coverage < 0.75
):
continue
partial_query_floor = 0.60 if len(full_query_token_ids) < 8 else 0.30
if (
len(full_query_token_ids) >= 5
and full_query_coverage <= partial_query_floor
and row_token_coverage < 0.75
):
continue
if (
len(full_query_token_ids) >= 12
and full_query_coverage < 0.25
and row_token_coverage <= 0.75
):
continue
if (
query_content_weight > 0.0
and matched_content_weight / query_content_weight < 0.25
and row_token_coverage < 0.75
and full_query_coverage < 0.60
):
continue
query_coverage = (
matched_content_weight / query_content_weight
if query_content_weight > 0.0
else row_token_coverage
)
numerator = sum(
query_weights[token_id] * row_weights[token_id]
for token_id in query_weights.keys() & row_weights.keys()
)
if numerator <= 0.0:
continue
row_norm = sum(value * value for value in row_weights.values()) ** 0.5
if row_norm > 0.0:
token_score = numerator / (query_norm * row_norm)
row_bigrams = {
(row_ids[index], row_ids[index + 1])
for index in range(len(row_ids) - 1)
}
row_trigrams = {
(row_ids[index], row_ids[index + 1], row_ids[index + 2])
for index in range(len(row_ids) - 2)
}
bigram_score = ordered_ngram_score(query_bigrams, row_bigrams)
trigram_score = ordered_ngram_score(query_trigrams, row_trigrams)
scores[sequence_index] = (
(0.35 * token_score)
+ (0.35 * query_coverage)
+ (0.15 * bigram_score)
+ (0.15 * trigram_score)
) * prompt_length_fit(len(row_weights))
return scores
def _score_prompt_anchor_matches(
self,
answer_anchor_state: Vector | None,
keys: object | None,
key_norms_list: object | None,
values: object | None,
keys_array: object | None,
key_norms_array: object | None,
values_array: object | None,
valid_mask_array: object | None,
similarity_keys_array: object | None,
similarity_key_norms_array: object | None,
similarity_mask_array: object | None,
*,
limit: int,
) -> list[tuple[float, int, int]]:
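"""Shared cosine ranking over a bank of prompt-anchor keys.

When a similarity mask is provided the centred masked combined state
is projected through it before scoring; otherwise the raw centred
state is used. Returns up to `limit` positive (score, value,
example_index) triples via the NumPy path when the arrays are loaded,
falling back to a pure-Python scan.
"""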
if (
answer_anchor_state is None
or keys is None
or key_norms_list is None
or values is None
):
return []
if (
np is not None
and keys_array is not None
and key_norms_array is not None
and values_array is not None
and valid_mask_array is not None
and limit > 0
):
key_array = keys_array
key_norms = key_norms_array
if (
similarity_keys_array is not None
and similarity_key_norms_array is not None
and similarity_mask_array is not None
):
state_array = self._center_state_array(
self._masked_combined_state_array(answer_anchor_state)
).astype(keys_array.dtype, copy=False)
state_array = state_array * similarity_mask_array
key_array = similarity_keys_array
key_norms = similarity_key_norms_array
else:
state_array = self._center_state_array(answer_anchor_state).astype(
keys_array.dtype,
copy=False,
)
state_norm = float(np.linalg.norm(state_array))
if state_norm == 0.0:
return []
numerators = key_array @ state_array
denominators = key_norms * state_norm
valid_mask = valid_mask_array & (denominators > 0.0)
if np.any(valid_mask):
scores = np.zeros_like(numerators, dtype=key_array.dtype)
np.divide(numerators, denominators, out=scores, where=valid_mask)
positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
if positive_positions.size:
selected_positions = positive_positions
if positive_positions.size > limit:
partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
selected_positions = positive_positions[partition]
ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]]
return [
(
float(scores[position]),
int(values_array[position]),
int(position),
)
for position in ordered_positions
]
if similarity_mask_array is not None:
state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
else:
state = self._center_state_vector(answer_anchor_state)
state_norm = norm(state)
if state_norm == 0.0:
return []
scored: list[tuple[float, int, int]] = []
for example_index, (key, key_norm, token_id) in enumerate(
zip(keys, key_norms_list, values)
):
if token_id < 0:
continue
denominator = state_norm * key_norm
if denominator == 0.0:
continue
similarity = dot(state, key) / denominator
if similarity > 0.0:
scored.append((similarity, token_id, example_index))
scored.sort(key=lambda item: item[0], reverse=True)
return scored[:limit]
def _answer_prior_from_matches(
self,
matches: list[tuple[float, int, int]],
generated_tokens: list[str],
) -> Vector:
assert self.embedding_model is not None
if not matches:
return [0.0 for _ in self.embedding_model.id_to_token]
prior = [0.0 for _ in self.embedding_model.id_to_token]
generated_ids = {
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
}
for similarity, token_id, _ in matches[:ANSWER_TOP_K]:
token = self.embedding_model.id_to_token[token_id]
if not self._allowed_generation_token(token, generated_tokens):
continue
if token_id in generated_ids:
prior[token_id] += similarity * 0.35
else:
prior[token_id] += similarity
return _normalize_vector(prior)
def _answer_start_matches_from_sequences(
self,
matches: list[tuple[float, int, int]],
) -> list[tuple[float, int, int]]:
if not matches or self.answer_sequence_tokens is None:
return []
start_matches: list[tuple[float, int, int]] = []
for similarity, sequence_index, example_index in matches[:ANSWER_START_TOP_K]:
if sequence_index >= len(self.answer_sequence_tokens):
continue
row = self.answer_sequence_tokens[sequence_index]
token_ids = [
int(value)
for value in (row.tolist() if hasattr(row, "tolist") else row)
if int(value) >= 0
]
if token_ids:
start_matches.append((similarity, token_ids[0], example_index))
return start_matches
def _answer_sequence_prior_from_matches(
self,
matches: list[tuple[float, int, int]],
generated_tokens: list[str],
*,
temperature: float = 0.0,
) -> Vector:
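"""Turn sequence matches into a next-token prior.

When the best match is at least 0.9, only matches within a small delta
of it survive (0.14 at creative temperatures, else 0.02); each
surviving sequence votes for the next token its prefix predicts after
the generated ids, weighted by how far it clears the floor.
"""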
assert self.embedding_model is not None
if not matches or self.answer_sequence_tokens is None:
return [0.0 for _ in self.embedding_model.id_to_token]
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
prior = [0.0 for _ in self.embedding_model.id_to_token]
best_similarity = matches[0][0]
if best_similarity >= 0.9:
floor_delta = 0.14 if temperature >= ANSWER_SEQUENCE_CREATIVE_TEMPERATURE else 0.02
match_floor = best_similarity - floor_delta
else:
match_floor = 0.0
for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
if similarity < ANSWER_SEQUENCE_MATCH_FLOOR:
continue
if similarity < match_floor:
continue
token_ids = self._answer_sequence_token_row(sequence_index)
if not token_ids:
continue
next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
if next_token_id is None:
continue
token = self.embedding_model.id_to_token[next_token_id]
if self._allowed_answer_sequence_token(token, generated_tokens):
prior[next_token_id] += max(1e-9, similarity - match_floor)
return _normalize_vector(prior)
def _answer_sequence_token_row(self, sequence_index: int) -> list[int]:
if sequence_index < 0 or self.answer_sequence_tokens is None:
return []
if self.answer_sequence_token_id_rows is not None:
if sequence_index >= len(self.answer_sequence_token_id_rows):
return []
return self.answer_sequence_token_id_rows[sequence_index]
if (
np is not None
and hasattr(self.answer_sequence_tokens, "shape")
and len(self.answer_sequence_tokens.shape) == 2
):
if sequence_index >= int(self.answer_sequence_tokens.shape[0]):
return []
row = np.asarray(self.answer_sequence_tokens[sequence_index])
return [int(value) for value in row.tolist() if int(value) >= 0]
try:
row = self.answer_sequence_tokens[sequence_index]
except (IndexError, TypeError):
return []
return self._answer_token_ids_from_row(row)
def _filter_avoided_answer_sequence_matches(
self,
matches: list[tuple[float, int, int]] | None,
avoid_token_sequences: Sequence[Sequence[str]] | None,
) -> list[tuple[float, int, int]]:
if (
not matches
or not avoid_token_sequences
or self.embedding_model is None
or self.answer_sequence_tokens is None
):
return list(matches or [])
token_to_id = self.embedding_model.token_to_id
avoided_id_sequences: set[tuple[int, ...]] = set()
for sequence in avoid_token_sequences:
ids: list[int] = []
for token in sequence:
token_id = token_to_id.get(token)
if token_id is None:
ids = []
break
ids.append(token_id)
if ids:
avoided_id_sequences.add(tuple(ids))
if not avoided_id_sequences:
return list(matches)
sequence_rows = self._answer_sequence_token_rows()
filtered: list[tuple[float, int, int]] = []
for match in matches:
_, sequence_index, _ = match
if sequence_index >= len(sequence_rows):
filtered.append(match)
continue
if tuple(sequence_rows[sequence_index]) in avoided_id_sequences:
continue
filtered.append(match)
return filtered
def _answer_sequence_token_rows(self) -> list[list[int]]:
if self.answer_sequence_token_id_rows is not None:
return self.answer_sequence_token_id_rows
rows: list[list[int]] = []
if (
np is not None
and self.answer_sequence_tokens is not None
and hasattr(self.answer_sequence_tokens, "shape")
and len(self.answer_sequence_tokens.shape) == 2
):
token_rows = np.asarray(self.answer_sequence_tokens).tolist()
rows = [
[int(value) for value in row if int(value) >= 0]
for row in token_rows
]
elif self.answer_sequence_tokens is not None:
for row in self.answer_sequence_tokens:
rows.append(self._answer_token_ids_from_row(row))
self.answer_sequence_token_id_rows = rows
return rows
@staticmethod
def _answer_token_ids_from_row(row: object) -> list[int]:
values = row.tolist() if hasattr(row, "tolist") else row
if not isinstance(values, list):
return []
return [int(value) for value in values if int(value) >= 0]
@staticmethod
def _answer_fingerprint_from_token_ids(token_ids: list[int]) -> tuple[int, ...]:
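"""Hash a token-id sequence into a fixed-width fingerprint.

The comma-joined ids are fed through BLAKE2s and the digest is split
into ANSWER_FINGERPRINT_WORDS signed 32-bit little-endian words, so
fingerprints serialise cleanly into integer tensors.
"""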
payload = ",".join(str(token_id) for token_id in token_ids).encode("ascii")
digest = hashlib.blake2s(
payload,
digest_size=ANSWER_FINGERPRINT_WORDS * 4,
).digest()
return tuple(
int.from_bytes(
digest[index * 4 : (index + 1) * 4],
"little",
signed=True,
)
for index in range(ANSWER_FINGERPRINT_WORDS)
)
def _refresh_answer_fingerprint_hashes(self) -> None:
hashes: set[tuple[int, ...]] = set()
lengths: set[int] = set()
sequences_by_length: dict[int, set[tuple[int, ...]]] = {}
if self.answer_sequence_tokens is not None:
for token_ids in self._answer_sequence_token_rows():
if token_ids:
token_length = len(token_ids)
lengths.add(token_length)
sequences_by_length.setdefault(token_length, set()).add(tuple(token_ids))
hashes.add(self._answer_fingerprint_from_token_ids(token_ids))
self.answer_fingerprint_hashes = hashes
self.answer_fingerprint_token_lengths = lengths
self.answer_fingerprint_token_sequences_by_length = sequences_by_length
def _answer_fingerprint_tensor(self) -> list[list[int]]:
if self.answer_fingerprint_hashes is None:
self._refresh_answer_fingerprint_hashes()
return [
list(fingerprint)
for fingerprint in sorted(self.answer_fingerprint_hashes or set())
]
@staticmethod
def _coerce_answer_fingerprint_hashes(raw_fingerprints: object) -> set[tuple[int, ...]]:
rows = raw_fingerprints.tolist() if hasattr(raw_fingerprints, "tolist") else raw_fingerprints
hashes: set[tuple[int, ...]] = set()
if not isinstance(rows, list):
return hashes
for row in rows:
values = row.tolist() if hasattr(row, "tolist") else row
if not isinstance(values, list):
continue
fingerprint = tuple(int(value) for value in values)
if len(fingerprint) == ANSWER_FINGERPRINT_WORDS:
hashes.add(fingerprint)
return hashes
def _answer_fingerprint_lengths(self) -> set[int]:
if self.answer_fingerprint_token_lengths is not None:
return self.answer_fingerprint_token_lengths
lengths: set[int] = set()
if (
np is not None
and self.answer_sequence_tokens is not None
and hasattr(self.answer_sequence_tokens, "shape")
and len(self.answer_sequence_tokens.shape) == 2
):
token_matrix = np.asarray(self.answer_sequence_tokens)
length_values = np.sum(token_matrix >= 0, axis=1)
lengths = {
int(length)
for length in np.unique(length_values).tolist()
if int(length) > 0
}
elif self.answer_sequence_tokens is not None:
for token_ids in self._answer_sequence_token_rows():
if token_ids:
lengths.add(len(token_ids))
self.answer_fingerprint_token_lengths = lengths
return lengths
def _use_runtime_fingerprint_blacklist(self) -> bool:
if (
np is None
or self.answer_sequence_tokens is None
or not hasattr(self.answer_sequence_tokens, "shape")
or len(self.answer_sequence_tokens.shape) != 2
):
return False
return int(self.answer_sequence_tokens.shape[0]) > ANSWER_SEQUENCE_EAGER_OVERLAP_CACHE_LIMIT
def _answer_fingerprint_token_sequence_sets(self) -> dict[int, set[tuple[int, ...]]]:
if self.answer_fingerprint_token_sequences_by_length is not None:
return self.answer_fingerprint_token_sequences_by_length
sequences_by_length: dict[int, set[tuple[int, ...]]] = {}
lengths: set[int] = set()
if self.answer_sequence_tokens is not None:
for token_ids in self._answer_sequence_token_rows():
if token_ids:
token_length = len(token_ids)
lengths.add(token_length)
sequences_by_length.setdefault(token_length, set()).add(tuple(token_ids))
self.answer_fingerprint_token_lengths = lengths
self.answer_fingerprint_token_sequences_by_length = sequences_by_length
return sequences_by_length
def _token_ids_for_generated_tokens(self, generated_tokens: Sequence[str]) -> list[int] | None:
if self.embedding_model is None:
return None
token_ids: list[int] = []
for token in generated_tokens:
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
return None
token_ids.append(token_id)
return token_ids
def _would_complete_blacklisted_answer(
self,
generated_tokens: list[str],
candidate: str,
) -> bool:
generated_token_ids = self._token_ids_for_generated_tokens(generated_tokens)
return self._would_complete_blacklisted_answer_ids(generated_token_ids, candidate)
def _would_complete_blacklisted_answer_ids(
self,
generated_token_ids: Sequence[int] | None,
candidate: str,
) -> bool:
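"""Check whether appending `candidate` would exactly reproduce a
stored answer.

Terminal-punctuation candidates are always allowed. Large corpora use
the length-gated fingerprint blacklist; smaller ones compare against
the exact per-length token-id sequence sets, with the fingerprint
check as a fallback when the raw sequences are unavailable.
"""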
if (
self.embedding_model is None
or not self.answer_fingerprint_hashes
or candidate not in self.embedding_model.token_to_id
or generated_token_ids is None
):
return False
candidate_id = self.embedding_model.token_to_id[candidate]
if self._is_terminal_punctuation_text(self._render_token(candidate)):
return False
candidate_length = len(generated_token_ids) + 1
if self._use_runtime_fingerprint_blacklist():
lengths = self._answer_fingerprint_lengths()
if lengths and candidate_length not in lengths:
return False
token_ids = [*generated_token_ids, candidate_id]
if not token_ids:
return False
return self._answer_fingerprint_from_token_ids(token_ids) in self.answer_fingerprint_hashes
sequence_sets = self._answer_fingerprint_token_sequence_sets()
candidate_sequences = sequence_sets.get(candidate_length)
if candidate_sequences is not None:
return (*generated_token_ids, candidate_id) in candidate_sequences
if self.answer_sequence_tokens is not None:
return False
lengths = self._answer_fingerprint_lengths()
if lengths and candidate_length not in lengths:
return False
token_ids = [*generated_token_ids, candidate_id]
if not token_ids:
return False
return self._answer_fingerprint_from_token_ids(token_ids) in self.answer_fingerprint_hashes
def _would_follow_blacklisted_answer_prefix_ids(
self,
generated_token_ids: Sequence[int] | None,
candidate: str,
*,
minimum_prefix_length: int = ANSWER_REPLAY_PREFIX_MIN_TOKENS,
) -> bool:
if (
self.embedding_model is None
or self.answer_sequence_tokens is None
or candidate not in self.embedding_model.token_to_id
or generated_token_ids is None
):
return False
candidate_id = self.embedding_model.token_to_id[candidate]
candidate_path = (*generated_token_ids, candidate_id)
if len(candidate_path) < minimum_prefix_length:
return False
prefix_sets = self._answer_sequence_prefix_sets(minimum_prefix_length)
return candidate_path in prefix_sets.get(len(candidate_path), set())
def _answer_sequence_prefix_sets(
self,
minimum_prefix_length: int = ANSWER_REPLAY_PREFIX_MIN_TOKENS,
) -> dict[int, set[tuple[int, ...]]]:
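"""Cache every stored-answer prefix of at least `minimum_prefix_length`
tokens, keyed by prefix length, for the replay-prefix blacklist."""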
cached = self.answer_sequence_prefixes_by_length
if cached is not None:
return cached
prefixes: dict[int, set[tuple[int, ...]]] = {}
for token_ids in self._answer_sequence_token_rows():
for length in range(minimum_prefix_length, len(token_ids) + 1):
prefixes.setdefault(length, set()).add(tuple(token_ids[:length]))
self.answer_sequence_prefixes_by_length = prefixes
return prefixes
def _avoid_text_token_sequences(
self,
avoid_texts: Sequence[str] | None,
) -> list[list[str]]:
if not avoid_texts or self.tokenizer is None:
return []
sequences: list[list[str]] = []
seen: set[tuple[str, ...]] = set()
for text in avoid_texts:
if not isinstance(text, str) or not text.strip():
continue
tokens = [
token
for token in self.tokenizer.encode(text)
if token not in self.tokenizer.special_tokens
]
key = tuple(tokens)
if tokens and key not in seen:
seen.add(key)
sequences.append(tokens)
return sequences
@staticmethod
def _runtime_generation_history_key(context: str) -> str:
return " ".join(context.split()).casefold()
@staticmethod
def _runtime_history_enabled(context: str, *, temperature: float) -> bool:
if temperature < ANSWER_REPLAY_PREFIX_TEMPERATURE:
return False
lowered = context.casefold()
return "<source>" not in lowered and "<tool_result>" not in lowered
def _runtime_avoid_texts(
self,
context: str,
avoid_texts: Sequence[str] | None,
*,
temperature: float,
) -> list[str]:
combined: list[str] = []
seen: set[str] = set()
for text in avoid_texts or ():
cleaned = " ".join(str(text).split())
if cleaned and cleaned not in seen:
combined.append(cleaned)
seen.add(cleaned)
if not self._runtime_history_enabled(context, temperature=temperature):
return combined
history = self.runtime_generation_history.get(
self._runtime_generation_history_key(context),
[],
)
for text in history:
cleaned = " ".join(str(text).split())
if cleaned and cleaned not in seen:
combined.append(cleaned)
seen.add(cleaned)
return combined
def _remember_runtime_generation(
self,
context: str,
generated_text: str,
*,
temperature: float,
) -> None:
if not self._runtime_history_enabled(context, temperature=temperature):
return
cleaned = " ".join(generated_text.split())
if not cleaned:
return
key = self._runtime_generation_history_key(context)
history = [
existing
for existing in self.runtime_generation_history.get(key, [])
if existing != cleaned
]
history.append(cleaned)
self.runtime_generation_history[key] = history[-RUNTIME_GENERATION_HISTORY_LIMIT:]
@staticmethod
def _would_follow_avoided_sequence(
generated_tokens: list[str],
candidate: str,
avoid_token_sequences: Sequence[Sequence[str]] | None,
) -> bool:
if not avoid_token_sequences:
return False
prefix_length = len(generated_tokens) + 1
if prefix_length < AVOID_SEQUENCE_MIN_TOKENS:
return False
candidate_path = [*generated_tokens, candidate]
for sequence in avoid_token_sequences:
if prefix_length <= len(sequence) and list(sequence[:prefix_length]) == candidate_path:
return True
return False
def _should_stop_answer_sequence(
self,
decode_state: DecodeState,
generated_tokens: list[str],
) -> bool:
matches = decode_state.answer_sequence_matches
if matches is None:
matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
)
return self._answer_sequence_is_complete(generated_tokens, matches)
def _should_stop_after_answer_path_drift(
self,
decode_state: DecodeState,
generated_tokens: list[str],
) -> bool:
matches = decode_state.answer_sequence_matches
if matches is None:
matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
)
if not matches or matches[0][0] < ANSWER_SEQUENCE_MATCH_FLOOR:
return False
if self._answer_sequence_has_continuation(generated_tokens, matches):
return False
if self._generated_answer_ends_terminal_sentence(generated_tokens):
return True
return self._generated_word_count(generated_tokens) >= 14
def _generated_answer_ends_terminal_sentence(self, generated_tokens: list[str]) -> bool:
if not generated_tokens:
return False
rendered = self._render_token(generated_tokens[-1])
if not self._is_terminal_punctuation_text(rendered):
return False
return self._generated_word_count(generated_tokens) > 0
def _answer_decode_has_continuation(
self,
decode_state: DecodeState,
generated_tokens: list[str],
) -> bool:
matches = decode_state.answer_sequence_matches
if matches is None:
matches = self._score_answer_sequence_matches(
decode_state.answer_anchor_state,
decode_state.context_tokens,
)
return self._answer_sequence_has_continuation(generated_tokens, matches)
def _answer_sequence_is_complete(
self,
generated_tokens: list[str],
matches: list[tuple[float, int, int]],
) -> bool:
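"""Decide whether the generated tokens already realise a stored answer.

True when the generated ids cover a full matched sequence, or when
they sit exactly one token short of a fingerprinted match whose final
token is not terminal punctuation and the current tail is not bare
structural punctuation.
"""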
if (
self.embedding_model is None
or self.answer_sequence_tokens is None
or not generated_tokens
or not matches
):
return False
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not generated_ids:
return False
for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens):
continue
row = self.answer_sequence_tokens[sequence_index]
token_ids = [
int(value)
for value in (row.tolist() if hasattr(row, "tolist") else row)
if int(value) >= 0
]
if not token_ids:
continue
if len(generated_ids) >= len(token_ids) and generated_ids[: len(token_ids)] == token_ids:
return True
if (
self.answer_fingerprint_hashes
and len(generated_ids) + 1 == len(token_ids)
and generated_ids == token_ids[: len(generated_ids)]
and self._answer_fingerprint_from_token_ids(token_ids)
in self.answer_fingerprint_hashes
):
generated_tail = self._render_token(generated_tokens[-1])
if self._is_structural_punctuation_text(
generated_tail
) and not self._is_terminal_punctuation_text(generated_tail):
continue
final_token = self.embedding_model.id_to_token[token_ids[-1]]
if self._is_terminal_punctuation_text(self._render_token(final_token)):
continue
return True
return False
def _answer_sequence_has_continuation(
self,
generated_tokens: list[str],
matches: list[tuple[float, int, int]],
) -> bool:
if (
self.embedding_model is None
or self.answer_sequence_tokens is None
or not generated_tokens
or not matches
):
return False
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not generated_ids:
return False
for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
if similarity < ANSWER_SEQUENCE_MATCH_FLOOR or sequence_index >= len(self.answer_sequence_tokens):
continue
row = self.answer_sequence_tokens[sequence_index]
token_ids = [
int(value)
for value in (row.tolist() if hasattr(row, "tolist") else row)
if int(value) >= 0
]
if not token_ids:
continue
next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
if next_token_id is None:
continue
token = self.embedding_model.id_to_token[next_token_id]
if self._allowed_answer_sequence_token(token, generated_tokens):
return True
return False
def _next_sequence_token_id(
self,
token_ids: list[int],
generated_ids: list[int],
) -> int | None:
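"""Return the continuation token a stored sequence predicts: its first
token when nothing has been generated, the token following the
generated ids when they form a strict prefix, and None otherwise."""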
if not generated_ids:
return token_ids[0]
if len(generated_ids) >= len(token_ids):
return None
if token_ids[: len(generated_ids)] != generated_ids:
return None
return token_ids[len(generated_ids)]
def _transition_prior(self, context_tokens: list[str]) -> Vector:
prior, _ = self._transition_prior_with_order(context_tokens)
return prior
def _transition_prior_with_order(
self,
context_tokens: list[str],
) -> tuple[Vector, int | None]:
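"""Backoff n-gram prior over the vocabulary.

Orders in TRANSITION_ORDERS are tried from highest to lowest; the
id-keyed tables (tensor lookup first, then the dict) take precedence
over the legacy string-keyed tables. Returns the normalised prior and
the order that matched, or a zero prior with None.
"""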
assert self.embedding_model is not None
if self.transition_id_tables:
for order in TRANSITION_ORDERS:
if len(context_tokens) < order:
continue
key_ids: list[int] = []
for token in context_tokens[-order:]:
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
key_ids = []
break
key_ids.append(token_id)
if not key_ids:
continue
transitions = self._transition_tensor_lookup(order, key_ids)
if transitions is None:
transitions = self.transition_id_tables.get(order, {}).get(tuple(key_ids))
if not transitions:
continue
next_token_ids, probabilities = transitions
prior = [0.0 for _ in self.embedding_model.id_to_token]
for token_id, probability in zip(next_token_ids, probabilities):
token_index = int(token_id)
if 0 <= token_index < len(prior):
prior[token_index] = float(probability)
return _normalize_vector(prior), order
if not self.transition_tables:
return [0.0 for _ in self.embedding_model.id_to_token], None
for order in TRANSITION_ORDERS:
if len(context_tokens) < order:
continue
key = tuple(context_tokens[-order:])
transitions = self.transition_tables.get(order, {}).get(key)
if not transitions:
continue
prior = [0.0 for _ in self.embedding_model.id_to_token]
for token, probability in transitions.items():
token_id = self.embedding_model.token_to_id.get(token)
if token_id is not None:
prior[token_id] = probability
return _normalize_vector(prior), order
return [0.0 for _ in self.embedding_model.id_to_token], None
def _transition_prior_array_with_order(
self,
context_tokens: list[str],
) -> tuple[object, int | None]:
assert np is not None
assert self.embedding_model is not None
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
if self.transition_id_tables:
for order in TRANSITION_ORDERS:
if len(context_tokens) < order:
continue
key_ids: list[int] = []
for token in context_tokens[-order:]:
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
key_ids = []
break
key_ids.append(token_id)
if not key_ids:
continue
transitions = self._transition_tensor_lookup(order, key_ids)
if transitions is None:
transitions = self.transition_id_tables.get(order, {}).get(tuple(key_ids))
if not transitions:
continue
next_token_ids, probabilities = transitions
token_ids_array = np.asarray(next_token_ids, dtype=np.int64)
probabilities_array = np.asarray(probabilities, dtype=np.float64)
valid = (
(token_ids_array >= 0)
& (token_ids_array < len(self.embedding_model.id_to_token))
& (probabilities_array > 0.0)
)
if np.any(valid):
prior[token_ids_array[valid]] = probabilities_array[valid]
total = float(prior.sum())
if total > 0.0:
prior /= total
return prior, order
return prior, None
if not self.transition_tables:
return prior, None
for order in TRANSITION_ORDERS:
if len(context_tokens) < order:
continue
key = tuple(context_tokens[-order:])
transitions = self.transition_tables.get(order, {}).get(key)
if not transitions:
continue
for token, probability in transitions.items():
token_id = self.embedding_model.token_to_id.get(token)
if token_id is not None:
prior[token_id] = probability
total = float(prior.sum())
if total > 0.0:
prior /= total
return prior, order
return prior, None
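    # Recency-weighted copy prior over the prompt (tokens before the last
    # "<answer>" marker when one is present): each eligible token contributes
    # decay**distance * distinctiveness, so with decay = 0.82 a token three
    # positions from the end still keeps 0.82**3 ~= 0.55 of its full weight.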
def _copy_prior(self, context_tokens: list[str]) -> Vector:
assert self.embedding_model is not None
assert self.tokenizer is not None
prior = [0.0 for _ in self.embedding_model.id_to_token]
decay = 0.82
answer_start = None
for index in range(len(context_tokens) - 1, -1, -1):
if context_tokens[index] == "<answer>":
answer_start = index + 1
break
source_tokens = (
context_tokens[: max(0, answer_start - 1)]
if answer_start is not None
else context_tokens
)
if not source_tokens:
return prior
for distance, token in enumerate(reversed(source_tokens)):
if token in self.tokenizer.special_tokens:
continue
if not self._eligible_copy_token(token):
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
prior[token_id] += (decay**distance) * self._copy_token_distinctiveness(token)
return _normalize_vector(prior)
def _copy_prior_array(self, context_tokens: list[str]) -> object:
assert np is not None
assert self.embedding_model is not None
assert self.tokenizer is not None
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
decay = 0.82
answer_start = None
for index in range(len(context_tokens) - 1, -1, -1):
if context_tokens[index] == "<answer>":
answer_start = index + 1
break
source_tokens = (
context_tokens[: max(0, answer_start - 1)]
if answer_start is not None
else context_tokens
)
for distance, token in enumerate(reversed(source_tokens)):
if token in self.tokenizer.special_tokens:
continue
if not self._eligible_copy_token(token):
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
prior[token_id] += (decay**distance) * self._copy_token_distinctiveness(token)
total = float(prior.sum())
if total > 0.0:
prior /= total
return prior
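    # Heuristic distinctiveness score: base 1.0, +0.8 for an uppercase
    # letter, +0.9 for any digit, +0.5 for any symbol, +0.2 for length >= 4.
    # Example: "GPT-4" scores 1.0 + 0.8 + 0.9 + 0.5 + 0.2 = 3.4.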
def _copy_token_distinctiveness(self, token: str) -> float:
rendered = self._render_token(token).strip()
if not rendered:
return 0.0
letters = sum(character.isalpha() for character in rendered)
digits = sum(character.isdigit() for character in rendered)
symbols = sum(
not character.isalnum() and not character.isspace()
for character in rendered
)
score = 1.0
if any(character.isupper() for character in rendered) and letters:
score += 0.8
if digits:
score += 0.9
if symbols:
score += 0.5
if len(rendered) >= 4:
score += 0.2
return score
def _prompt_copy_evidence_is_distinctive(self, context_tokens: list[str]) -> bool:
answer_start = None
for index in range(len(context_tokens) - 1, -1, -1):
if context_tokens[index] == "<answer>":
answer_start = index
break
prompt_tokens = context_tokens[:answer_start] if answer_start is not None else context_tokens
for token in prompt_tokens:
if self.tokenizer is not None and token in self.tokenizer.special_tokens:
continue
rendered = self._render_token(token).strip()
if any(character.isdigit() for character in rendered):
return True
if sum(character.isupper() for character in rendered) >= 2:
return True
return False
def _source_evidence_prior(
self,
context_tokens: list[str],
generated_tokens: list[str] | None = None,
) -> Vector:
assert self.embedding_model is not None
prior = [0.0 for _ in self.embedding_model.id_to_token]
for token_id, weight in self._source_evidence_token_weights(
context_tokens,
generated_tokens or [],
).items():
if 0 <= token_id < len(prior):
prior[token_id] += weight
return _normalize_vector(prior)
def _source_evidence_prior_array(
self,
context_tokens: list[str],
generated_tokens: list[str] | None = None,
) -> object:
assert np is not None
assert self.embedding_model is not None
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
for token_id, weight in self._source_evidence_token_weights(
context_tokens,
generated_tokens or [],
).items():
if 0 <= token_id < prior.size:
prior[token_id] += weight
total = float(prior.sum())
if total > 0.0:
prior /= total
return prior
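    # Source-grounded decoding weights. Once generation has started, only
    # snippet segments vote: a segment containing the generated suffix (up
    # to 8 ids) boosts the token that follows the match by
    # segment_weight * (3 + suffix length); unaligned segments fall back to
    # lexical weights that favour leading content tokens and tokens that
    # appear after a query anchor.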
def _source_evidence_token_weights(
self,
context_tokens: list[str],
generated_tokens: list[str],
) -> dict[int, float]:
if self.embedding_model is None or self.tokenizer is None:
return {}
segments = self._source_evidence_segments(context_tokens)
if not segments:
return {}
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
first_source_index = _first_index(context_tokens, "<source>")
query_tokens = (
context_tokens[:first_source_index]
if first_source_index is not None
else context_tokens
)
query_token_ids = {
self.embedding_model.token_to_id[token]
for token in query_tokens
if token in self.embedding_model.token_to_id
and token not in self.tokenizer.special_tokens
and self._eligible_copy_token(token)
}
weights: dict[int, float] = {}
def add_token(token: str, weight: float, *, allow_piece: bool = False) -> None:
if token in self.tokenizer.special_tokens:
return
if not allow_piece and not self._allowed_generation_token(token, generated_tokens):
return
if allow_piece:
rendered = self._render_token(token)
if not rendered or not rendered.strip():
return
elif not self._eligible_copy_token(token):
return
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
return
weights[token_id] = weights.get(token_id, 0.0) + weight
for segment_tokens, segment_weight, segment_role in segments[-6:]:
if generated_ids and segment_role != "snippet":
continue
token_ids = [
self.embedding_model.token_to_id[token]
for token in segment_tokens
if token in self.embedding_model.token_to_id
]
aligned = False
if generated_ids and token_ids:
max_suffix = min(8, len(generated_ids), len(token_ids))
for suffix_length in range(max_suffix, 0, -1):
suffix = generated_ids[-suffix_length:]
for index in range(len(token_ids) - suffix_length):
if token_ids[index : index + suffix_length] != suffix:
continue
next_token_id = token_ids[index + suffix_length]
next_token = self.embedding_model.id_to_token[next_token_id]
add_token(
next_token,
segment_weight * (3.0 + suffix_length),
allow_piece=True,
)
aligned = True
if aligned:
break
if aligned:
continue
content_rank = 0
anchor_seen = False
segment_has_query_anchor = any(token_id in query_token_ids for token_id in token_ids)
for token in segment_tokens:
rendered = self._render_token(token)
if "://" in rendered or rendered.casefold().startswith("http"):
continue
if not self._eligible_copy_token(token):
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
if segment_has_query_anchor:
in_query = token_id in query_token_ids
if in_query:
weight = segment_weight * 0.42
anchor_seen = True
elif anchor_seen:
weight = segment_weight * 2.10
else:
weight = segment_weight * 0.32
elif content_rank == 0:
weight = segment_weight * 4.0
elif content_rank == 1:
weight = segment_weight * 1.35
else:
weight = segment_weight * 0.65
weight *= 0.94 ** min(content_rank, 24)
add_token(token, weight)
content_rank += 1
return weights
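    # Collects "<source>" spans preceding the final "<answer>" marker,
    # ending each span at the next protocol token or newline. When a span
    # contains "|" separators (source lines appear to carry metadata before
    # a final "|"-delimited snippet), only the text after the last "|" is
    # kept at weight 1.0; otherwise the whole span is kept at weight 0.90.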
def _source_evidence_segments(self, context_tokens: list[str]) -> list[tuple[list[str], float, str]]:
if self.tokenizer is None:
return []
answer_boundary = _last_index(context_tokens, "<answer>")
upper_bound = answer_boundary if answer_boundary is not None else len(context_tokens)
boundary_tokens = {"<source>", "<tool_result>", "<tool_call>", "<final>", "<answer>"}
segments: list[tuple[list[str], float, str]] = []
index = 0
while index < upper_bound:
if context_tokens[index] != "<source>":
index += 1
continue
start = index + 1
end = start
while (
end < upper_bound
and context_tokens[end] not in boundary_tokens
and self._render_token(context_tokens[end]) != "\n"
):
end += 1
source_tokens = context_tokens[start:end]
pipe_positions = [
position
for position, token in enumerate(source_tokens)
if self._render_token(token).strip() == "|"
]
if pipe_positions:
snippet_tokens = source_tokens[pipe_positions[-1] + 1 :]
if snippet_tokens:
segments.append((snippet_tokens, 1.0, "snippet"))
elif source_tokens:
segments.append((source_tokens, 0.90, "snippet"))
index = end + 1
return segments
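    # Completion signal: the answer counts as complete when a suffix of at
    # least five generated ids aligns with a snippet and the match either
    # reaches the snippet's end or is followed by terminal punctuation that
    # does not merely join two numbers (e.g. the "." inside "3.14").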
def _source_evidence_is_complete(
self,
context_tokens: list[str],
generated_tokens: list[str],
) -> bool:
if (
self.embedding_model is None
or self.tokenizer is None
or self._generated_word_count(generated_tokens) < 5
):
return False
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not generated_ids:
return False
for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens):
if segment_role != "snippet":
continue
segment_ids = [
self.embedding_model.token_to_id[token]
for token in segment_tokens
if token in self.embedding_model.token_to_id
]
if len(generated_ids) > len(segment_ids):
continue
max_suffix = min(12, len(generated_ids), len(segment_ids))
for suffix_length in range(max_suffix, 4, -1):
suffix_ids = generated_ids[-suffix_length:]
for start in range(len(segment_ids) - suffix_length + 1):
if segment_ids[start : start + suffix_length] != suffix_ids:
continue
next_index = start + suffix_length
if next_index >= len(segment_ids):
return True
next_token = self.embedding_model.id_to_token[segment_ids[next_index]]
if self._source_punctuation_continues_numeric_span(
segment_ids,
next_index,
):
return False
if self._is_terminal_punctuation_text(self._render_token(next_token)):
return True
return False
def _source_evidence_has_continuation(
self,
context_tokens: list[str],
generated_tokens: list[str],
) -> bool:
if self.embedding_model is None or not generated_tokens:
return False
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not generated_ids:
return False
for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens):
if segment_role != "snippet":
continue
segment_ids = [
self.embedding_model.token_to_id[token]
for token in segment_tokens
if token in self.embedding_model.token_to_id
]
max_suffix = min(12, len(generated_ids), len(segment_ids))
for suffix_length in range(max_suffix, 0, -1):
suffix_ids = generated_ids[-suffix_length:]
for start in range(len(segment_ids) - suffix_length + 1):
if segment_ids[start : start + suffix_length] != suffix_ids:
continue
next_index = start + suffix_length
if next_index >= len(segment_ids):
return False
if self._source_punctuation_continues_numeric_span(
segment_ids,
next_index,
) or self._source_punctuation_continues_numeric_span(
segment_ids,
next_index - 1,
):
return True
next_token = self.embedding_model.id_to_token[segment_ids[next_index]]
return not self._is_terminal_punctuation_text(
self._render_token(next_token)
)
return False
def _source_evidence_next_token(
self,
context_tokens: list[str],
generated_tokens: list[str],
) -> str | None:
if self.embedding_model is None:
return None
for segment_tokens, _, segment_role in self._source_evidence_segments(context_tokens):
if segment_role != "snippet" or not segment_tokens:
continue
if not generated_tokens:
return segment_tokens[0]
segment_ids = [
self.embedding_model.token_to_id[token]
for token in segment_tokens
if token in self.embedding_model.token_to_id
]
generated_ids = [
self.embedding_model.token_to_id[token]
for token in generated_tokens
if token in self.embedding_model.token_to_id
]
if not segment_ids or not generated_ids:
continue
max_suffix = min(12, len(generated_ids), len(segment_ids))
for suffix_length in range(max_suffix, 0, -1):
suffix_ids = generated_ids[-suffix_length:]
for start in range(len(segment_ids) - suffix_length + 1):
if segment_ids[start : start + suffix_length] != suffix_ids:
continue
next_index = start + suffix_length
if next_index < len(segment_ids):
return self.embedding_model.id_to_token[segment_ids[next_index]]
return None
def _source_punctuation_continues_numeric_span(
self,
segment_ids: list[int],
punctuation_index: int,
) -> bool:
if self.embedding_model is None:
return False
if punctuation_index <= 0 or punctuation_index + 1 >= len(segment_ids):
return False
punctuation_text = self._render_token(
self.embedding_model.id_to_token[segment_ids[punctuation_index]]
).strip()
if not self._is_structural_punctuation_text(punctuation_text):
return False
previous_text = self._render_token(
self.embedding_model.id_to_token[segment_ids[punctuation_index - 1]]
)
next_text = self._render_token(
self.embedding_model.id_to_token[segment_ids[punctuation_index + 1]]
)
return any(character.isdigit() for character in previous_text) and any(
character.isdigit() for character in next_text
)
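    # Preference prior: a calibrated softmax over the positive entries of the
    # learned preference bias, restricted to word-initial, non-special tokens
    # with at least one alphanumeric character; all other tokens get zero.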
def _preference_prior(self) -> Vector:
assert self.embedding_model is not None
if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias):
return [0.0 for _ in self.embedding_model.id_to_token]
eligible_indices = [
index
for index, token in enumerate(self.embedding_model.id_to_token)
if self.preference_bias[index] > 0.0 and self._eligible_preference_token(token)
]
if not eligible_indices:
return [0.0 for _ in self.embedding_model.id_to_token]
eligible_probabilities = self._calibrated_softmax(
[self.preference_bias[index] for index in eligible_indices]
)
prior = [0.0 for _ in self.embedding_model.id_to_token]
for index, probability in zip(eligible_indices, eligible_probabilities):
prior[index] = probability
return prior
def _preference_prior_array(self) -> object:
assert np is not None
assert self.embedding_model is not None
if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
positive_mask = self.preference_bias_array > 0.0
active_mask = self.preference_valid_mask_array & positive_mask
if not np.any(active_mask):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
prior[active_mask] = self._calibrated_softmax_array(
self.preference_bias_array[active_mask]
)
return prior
def _eligible_preference_token(self, token: str) -> bool:
assert self.tokenizer is not None
if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
return False
if not self._starts_new_word(token):
return False
rendered = self._render_token(token)
if not rendered.strip() or self._is_punctuation_piece(rendered):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 1
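    # Counts next-token frequencies for every order in TRANSITION_ORDERS and
    # normalizes them into probabilities, keeping only the most frequent
    # contexts and continuations allowed by the config caps. For example,
    # tokens ["the", "cat", "sat", "the", "cat", "ran"] give the order-2 key
    # ("the", "cat") the distribution {"sat": 0.5, "ran": 0.5}.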
def _build_transition_tables(
self,
tokens: list[str],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
counts: dict[int, dict[tuple[str, ...], dict[str, int]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for order in sorted(TRANSITION_ORDERS):
for index in range(order - 1, len(tokens) - 1):
key = tuple(tokens[index - order + 1 : index + 1])
nxt = tokens[index + 1]
bucket = counts[order].setdefault(key, {})
bucket[nxt] = bucket.get(nxt, 0) + 1
probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for order, mapping in counts.items():
items = list(mapping.items())
items.sort(key=lambda item: (-sum(item[1].values()), item[0]))
if (
self.config.max_transition_contexts_per_order is not None
and self.config.max_transition_contexts_per_order >= 0
):
items = items[: self.config.max_transition_contexts_per_order]
for key, bucket in items:
next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0]))
if self.config.max_transition_next_tokens > 0:
next_items = next_items[: self.config.max_transition_next_tokens]
total = sum(value for _, value in next_items)
if total <= 0:
continue
probabilities[order][key] = {
token: value / total
for token, value in next_items
}
return probabilities
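    # Packs the transition tables into flat, safetensors-friendly arrays
    # using a CSR-like layout: row i has order transition_orders[i], key ids
    # in key_token_ids[key_offsets[i]:key_offsets[i + 1]], and continuations
    # in the next_* arrays between next_offsets[i] and next_offsets[i + 1].
    # Rows are emitted in ascending order, so each order is contiguous.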
def _transition_table_tensors(self) -> dict[str, object]:
assert self.embedding_model is not None
if self.transition_tensor_cache is not None:
return {
"transition_orders": self.transition_tensor_cache["orders"],
"transition_key_offsets": self.transition_tensor_cache["key_offsets"],
"transition_key_token_ids": self.transition_tensor_cache["key_token_ids"],
"transition_next_offsets": self.transition_tensor_cache["next_offsets"],
"transition_next_token_ids": self.transition_tensor_cache["next_token_ids"],
"transition_next_probabilities": self.transition_tensor_cache["next_probabilities"],
}
if not self.transition_tables:
return {
"transition_orders": [],
"transition_key_offsets": [0],
"transition_key_token_ids": [],
"transition_next_offsets": [0],
"transition_next_token_ids": [],
"transition_next_probabilities": [],
}
token_to_id = self.embedding_model.token_to_id
orders: list[int] = []
key_offsets: list[int] = [0]
key_token_ids: list[int] = []
next_offsets: list[int] = [0]
next_token_ids: list[int] = []
next_probabilities: list[float] = []
for order in sorted(self.transition_tables):
mapping = self.transition_tables.get(order, {})
for key, transitions in mapping.items():
key_ids = [token_to_id.get(token, -1) for token in key]
if len(key_ids) != order or any(token_id < 0 for token_id in key_ids):
continue
next_items = [
(token_to_id[token], float(probability))
for token, probability in transitions.items()
if token in token_to_id and probability > 0.0
]
if not next_items:
continue
orders.append(order)
key_token_ids.extend(key_ids)
key_offsets.append(len(key_token_ids))
for token_id, probability in next_items:
next_token_ids.append(token_id)
next_probabilities.append(probability)
next_offsets.append(len(next_token_ids))
return {
"transition_orders": orders,
"transition_key_offsets": key_offsets,
"transition_key_token_ids": key_token_ids,
"transition_next_offsets": next_offsets,
"transition_next_token_ids": next_token_ids,
"transition_next_probabilities": next_probabilities,
}
def _deserialize_transition_id_tables_from_tensors(
self,
tensors: dict[str, object],
) -> dict[int, dict[tuple[int, ...], tuple[object, object]]] | None:
required = (
"transition_orders",
"transition_key_offsets",
"transition_key_token_ids",
"transition_next_offsets",
"transition_next_token_ids",
"transition_next_probabilities",
)
if any(name not in tensors for name in required):
return None
def _as_sequence(name: str) -> object:
value = tensors.get(name, [])
return value if hasattr(value, "shape") else list(value)
orders = _as_sequence("transition_orders")
key_offsets = _as_sequence("transition_key_offsets")
key_token_ids = _as_sequence("transition_key_token_ids")
next_offsets = _as_sequence("transition_next_offsets")
next_token_ids = _as_sequence("transition_next_token_ids")
next_probabilities = _as_sequence("transition_next_probabilities")
row_count = len(orders)
if row_count == 0:
return {order: {} for order in sorted(TRANSITION_ORDERS)}
if len(key_offsets) != row_count + 1 or len(next_offsets) != row_count + 1:
return None
if np is not None and hasattr(orders, "shape"):
self.transition_tensor_cache = {
"orders": orders,
"key_offsets": key_offsets,
"key_token_ids": key_token_ids,
"next_offsets": next_offsets,
"next_token_ids": next_token_ids,
"next_probabilities": next_probabilities,
"order_spans": {},
}
self.transition_built_orders = set()
return {order: {} for order in sorted(TRANSITION_ORDERS)}
tables: dict[int, dict[tuple[int, ...], tuple[object, object]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for index in range(row_count):
order = int(orders[index])
key_start = int(key_offsets[index])
key_end = int(key_offsets[index + 1])
next_start = int(next_offsets[index])
next_end = int(next_offsets[index + 1])
key = tuple(int(token_id) for token_id in key_token_ids[key_start:key_end])
if len(key) != order or next_end <= next_start:
continue
tables.setdefault(order, {})[key] = (
next_token_ids[next_start:next_end],
next_probabilities[next_start:next_end],
)
return tables
def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]:
assert self.transition_tables is not None
return {
str(order): {
_encode_ngram_key(key): value
for key, value in mapping.items()
}
for order, mapping in self.transition_tables.items()
}
def _deserialize_transition_tables(
self,
payload: dict[str, dict[str, dict[str, float]]],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
order: {} for order in sorted(TRANSITION_ORDERS)
}
for order_text, mapping in payload.items():
order = int(order_text)
tables[order] = {
_decode_ngram_key(key): {
str(token): float(probability)
for token, probability in value.items()
}
for key, value in mapping.items()
}
return tables
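    # Lookup helpers for the packed layout above: the span search relies on
    # rows of one order being contiguous (guaranteed by the serializer) and
    # memoizes each (start, end) row range in the "order_spans" cache.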
def _transition_tensor_order_span(self, order: int) -> tuple[int, int] | None:
if np is None or self.transition_tensor_cache is None:
return None
spans = self.transition_tensor_cache.get("order_spans")
if isinstance(spans, dict) and order in spans:
return spans[order]
orders = self.transition_tensor_cache["orders"]
positions = np.flatnonzero(orders == order)
span = (
(int(positions[0]), int(positions[-1]) + 1)
if positions.size
else None
)
if isinstance(spans, dict):
spans[order] = span
return span
def _transition_tensor_lookup(
self,
order: int,
key_ids: list[int],
) -> tuple[object, object] | None:
if (
np is None
or self.transition_tensor_cache is None
or len(key_ids) != order
):
return None
span = self._transition_tensor_order_span(order)
if span is None:
return None
row_start, row_end = span
key_offsets = self.transition_tensor_cache["key_offsets"]
key_token_ids = self.transition_tensor_cache["key_token_ids"]
next_offsets = self.transition_tensor_cache["next_offsets"]
next_token_ids = self.transition_tensor_cache["next_token_ids"]
next_probabilities = self.transition_tensor_cache["next_probabilities"]
key_start = int(key_offsets[row_start])
key_end = int(key_offsets[row_end])
key_block = np.asarray(key_token_ids[key_start:key_end], dtype=np.int64)
row_count = row_end - row_start
if row_count <= 0 or key_block.size != row_count * order:
return None
keys = key_block.reshape(row_count, order)
query = np.asarray(key_ids, dtype=np.int64)
matches = np.flatnonzero(np.all(keys == query[None, :], axis=1))
if not matches.size:
return None
row = row_start + int(matches[0])
next_start = int(next_offsets[row])
next_end = int(next_offsets[row + 1])
if next_end <= next_start:
return None
return (
next_token_ids[next_start:next_end],
next_probabilities[next_start:next_end],
)
def _eligible_copy_token(self, token: str) -> bool:
rendered = self._render_token(token)
if not rendered.strip():
return False
if self._is_punctuation_piece(rendered):
return False
if not self._starts_new_word(token):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 2
def _allowed_generation_token(
self,
token: str,
generated_tokens: list[str],
context_tokens: list[str] | None = None,
) -> bool:
return self._allowed_generation_token_with_meta(
token,
self._generation_token_meta(token),
generated_tokens,
context_tokens,
)
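    # Surface-form gate: special tokens must be valid protocol tokens, small
    # vocabularies (< 1024 entries) skip the remaining checks, and other
    # pieces are admitted based on how they attach to the previous token
    # (newlines, word joiners, punctuation, and mid-word continuations).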
def _allowed_generation_token_with_meta(
self,
token: str,
meta: GenerationTokenMeta,
generated_tokens: list[str],
context_tokens: list[str] | None = None,
) -> bool:
assert self.embedding_model is not None
assert self.tokenizer is not None
if token == self.tokenizer.unk_token:
return False
if token in self.tokenizer.special_tokens:
return self._allowed_tool_protocol_token(
token,
generated_tokens=generated_tokens,
context_tokens=context_tokens or [],
)
if len(self.embedding_model.id_to_token) < 1024:
return True
if meta.rendered == "\n":
return bool(generated_tokens)
if not meta.stripped:
return False
if meta.word_joiner:
return (
self._can_attach_word_joiner(generated_tokens)
or self._can_start_line_with_word_joiner(token, generated_tokens)
)
if meta.structural_punctuation:
return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token)
if meta.structural_symbol:
return bool(generated_tokens) or meta.starts_new_word
if not meta.starts_new_word:
if not generated_tokens:
return False
previous_rendered = self._render_token(generated_tokens[-1])
return (
bool(previous_rendered)
and any(character.isalnum() for character in previous_rendered)
and bool(meta.alphanumeric)
)
return len(meta.alphanumeric) >= 1 or not meta.punctuation_piece
@staticmethod
def _allowed_tool_protocol_token(
token: str,
*,
generated_tokens: list[str],
context_tokens: list[str],
) -> bool:
if token not in TOOL_PROTOCOL_TOKENS:
return False
if token == "<tool_call>":
            return (
                ReframrModel._context_requests_tool_call(context_tokens)
                and "<tool_call>" not in generated_tokens
                and "<tool_result>" not in generated_tokens
                and "<source>" not in generated_tokens
            )
if token in {"<tool_result>", "<source>"}:
return False
if token == "<final>":
return (
"<tool_result>" in context_tokens
or "<source>" in context_tokens
or "<final>" in context_tokens
)
return True
@staticmethod
def _context_requests_tool_call(context_tokens: list[str]) -> bool:
rendered_terms: list[str] = []
for token in context_tokens:
if token in TOOL_PROTOCOL_TOKENS or token.startswith("<"):
continue
normalized = token.replace("▁", " ").strip().casefold()
if not normalized:
continue
rendered_terms.append(normalized)
pieces = {
"".join(
character
for character in piece
if character.isalnum() or character in {"-", "."}
)
for piece in normalized.split()
}
if pieces & TOOL_CALL_CONTEXT_TERMS:
return True
joined = " ".join(rendered_terms)
compact = "".join(rendered_terms)
return any(
term in joined or term.replace("-", "") in compact
for term in TOOL_CALL_CONTEXT_TERMS
)
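    # Loop suppression: rejects a candidate that would repeat the same token
    # three times in a row, recreate a trigram seen in the last 12 tokens,
    # recreate a recent word bigram, or reuse a non-connector word longer
    # than two characters that already occurs twice in the last 10 rendered
    # words.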
def _would_repeat_recent_pattern(
self,
candidate: str,
generated_tokens: list[str],
recent_rendered_words: list[str] | None = None,
) -> bool:
        if (
            len(generated_tokens) >= 2
            and generated_tokens[-1] == candidate
            and generated_tokens[-2] == candidate
        ):
            return True
if len(generated_tokens) >= 2:
trigram = tuple(generated_tokens[-2:] + [candidate])
recent_tokens = generated_tokens[-12:]
for index in range(max(0, len(recent_tokens) - 4)):
if tuple(recent_tokens[index : index + 3]) == trigram:
return True
rendered_words = recent_rendered_words
if rendered_words is None:
rendered_words = self._recent_rendered_words(generated_tokens)
candidate_meta = self._generation_token_meta(candidate)
candidate_word = candidate_meta.rendered.casefold()
if (
rendered_words
and candidate_meta.starts_new_word
and any(character.isalnum() for character in candidate_word)
):
candidate_bigram = (rendered_words[-1], candidate_word)
recent_window = rendered_words[-10:]
recent_bigrams = {
(recent_window[index], recent_window[index + 1])
for index in range(len(recent_window) - 1)
}
if candidate_bigram in recent_bigrams:
return True
if (
len(candidate_word) > 2
and rendered_words[-10:].count(candidate_word) >= 2
and not candidate_meta.common_connector
):
return True
return False
@staticmethod
def _is_inside_tool_protocol_continuation(generated_tokens: list[str]) -> bool:
return any(token in TOOL_PROTOCOL_TOKENS for token in generated_tokens[-6:])
def _would_repeat_recent_phrase(
self,
candidate: str,
generated_tokens: list[str],
*,
recent_rendered_words: list[str] | None = None,
) -> bool:
if not self._starts_new_word(candidate):
return False
rendered_words = list(
recent_rendered_words
if recent_rendered_words is not None
else self._recent_rendered_words(generated_tokens)
)
candidate_word = self._render_token(candidate).casefold()
if not any(character.isalnum() for character in candidate_word):
return False
rendered_words.append(candidate_word)
recent_window = rendered_words[-48:]
for span in range(4, min(8, len(recent_window)) + 1):
suffix = tuple(recent_window[-span:])
earlier = recent_window[:-span]
for index in range(len(earlier) - span + 1):
if tuple(earlier[index : index + span]) == suffix:
return True
return False
def _recent_phrase_repeat_candidate_words(
self,
recent_rendered_words: list[str],
) -> set[str]:
repeat_candidates: set[str] = set()
base_window = recent_rendered_words[-47:]
max_span = min(8, len(base_window) + 1)
if max_span < 4:
return repeat_candidates
for span in range(4, max_span + 1):
prefix_length = span - 1
suffix_prefix = tuple(base_window[-prefix_length:])
earlier_length = len(base_window) - prefix_length
if earlier_length < span:
continue
for index in range(earlier_length - span + 1):
earlier_segment = base_window[index : index + span]
if tuple(earlier_segment[:-1]) == suffix_prefix:
candidate_word = earlier_segment[-1]
if any(character.isalnum() for character in candidate_word):
repeat_candidates.add(candidate_word)
return repeat_candidates
def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]:
rendered_words: list[str] = []
for token in generated_tokens:
if not self._starts_new_word(token):
continue
rendered = self._render_token(token).casefold()
if any(character.isalnum() for character in rendered):
rendered_words.append(rendered)
return rendered_words
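    # Token selection pipeline: _prepare_generation_candidates applies the
    # eligibility filters, repetition penalties, temperature, top-k, and
    # top-p truncation, then _sample_generation_candidate draws from the
    # surviving pool. When nothing survives, fall back to a greedy scan of
    # the raw distribution with only the hard filters applied.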
def _select_generation_token(
self,
distribution: dict[str, float],
*,
context_tokens: list[str] | None = None,
generated_tokens: list[str] | None = None,
temperature: float = DEFAULT_GENERATION_TEMPERATURE,
top_k: int = DEFAULT_GENERATION_TOP_K,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
preserve_dominant_candidates: bool = False,
avoid_token_sequences: Sequence[Sequence[str]] | None = None,
) -> str:
assert self.tokenizer is not None
generated_tokens = generated_tokens or []
candidates = self._prepare_generation_candidates(
distribution,
context_tokens=context_tokens or [],
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=preserve_dominant_candidates,
avoid_token_sequences=avoid_token_sequences,
)
if candidates:
return self._sample_generation_candidate(
candidates,
context_tokens=context_tokens or [],
generated_tokens=generated_tokens,
stochastic=temperature > 0.0,
preserve_dominant_candidates=preserve_dominant_candidates,
)
for token, _ in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
if token in self.tokenizer.special_tokens and token not in TOOL_PROTOCOL_TOKENS:
continue
if token == self.tokenizer.unk_token:
continue
if not self._allowed_generation_token(token, generated_tokens, context_tokens or []):
continue
if self._would_complete_blacklisted_answer(generated_tokens, token):
continue
return token
return ""
def _select_generation_token_from_array(
self,
probabilities: object,
*,
context_tokens: list[str],
generated_tokens: list[str],
temperature: float = DEFAULT_GENERATION_TEMPERATURE,
top_k: int = DEFAULT_GENERATION_TOP_K,
top_p: float = DEFAULT_GENERATION_TOP_P,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
preserve_dominant_candidates: bool = False,
avoid_token_sequences: Sequence[Sequence[str]] | None = None,
) -> str:
assert np is not None
assert self.tokenizer is not None
assert self.embedding_model is not None
values = np.asarray(probabilities, dtype=np.float64)
if values.size == 0:
return ""
first_pool_size = min(values.size, max(top_k, 64))
if first_pool_size <= 0:
first_pool_size = min(values.size, 64)
expanded_pool_size = min(values.size, max(top_k * 4, 64))
pool_sizes: list[int] = []
for pool_size in (first_pool_size, expanded_pool_size, values.size):
if pool_size > 0 and pool_size not in pool_sizes:
pool_sizes.append(pool_size)
for pool_size in pool_sizes:
if pool_size < values.size:
candidate_indices = np.argpartition(values, -pool_size)[-pool_size:]
candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]]
else:
candidate_indices = np.argsort(values)[::-1]
distribution: dict[str, float] = {}
for raw_index in candidate_indices:
index = int(raw_index)
score = float(values[index])
if score <= 0.0:
continue
token = self.embedding_model.id_to_token[index]
                if token == self.tokenizer.unk_token or (
                    token in self.tokenizer.special_tokens
                    and token not in TOOL_PROTOCOL_TOKENS
                ):
continue
distribution[token] = score
selected = self._select_generation_token(
distribution,
context_tokens=context_tokens,
generated_tokens=generated_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
preserve_dominant_candidates=preserve_dominant_candidates,
avoid_token_sequences=avoid_token_sequences,
)
if selected:
return selected
return ""
def _prepare_generation_candidates(
self,
distribution: dict[str, float],
*,
context_tokens: list[str] | None = None,
generated_tokens: list[str],
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
preserve_dominant_candidates: bool = False,
avoid_token_sequences: Sequence[Sequence[str]] | None = None,
) -> list[tuple[str, float]]:
assert self.tokenizer is not None
assert self.embedding_model is not None
context_tokens = context_tokens or []
generated_word_count = self._generated_word_count(generated_tokens)
clause_words = self._words_since_clause_break(generated_tokens)
recent_rendered_words = self._recent_rendered_words(generated_tokens)
generated_token_ids = self._token_ids_for_generated_tokens(generated_tokens)
inside_tool_protocol = self._is_inside_tool_protocol_continuation(generated_tokens)
phrase_repeat_candidate_words = (
self._recent_phrase_repeat_candidate_words(recent_rendered_words)
if generated_word_count >= MIN_COMPLETE_ANSWER_WORDS and not inside_tool_protocol
else set()
)
prompt_content_tokens = [
token
for token in context_tokens
if token not in self.tokenizer.special_tokens
and self._generation_token_meta(token).starts_new_word
and self._generation_token_meta(token).alphanumeric
and not self._generation_token_meta(token).punctuation_piece
]
initial_prompt_content_token = (
prompt_content_tokens[0]
if len(prompt_content_tokens) > 1
else None
)
best_probability = max(distribution.values(), default=0.0)
has_uppercase_start_candidate = any(
probability > 0.0
and self._generation_token_meta(token).starts_new_word
and self._generation_token_meta(token).rendered[:1].isupper()
for token, probability in distribution.items()
)
adjusted: list[tuple[str, float]] = []
for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
if token in self.tokenizer.special_tokens and token not in TOOL_PROTOCOL_TOKENS:
continue
if token == self.tokenizer.unk_token or probability <= 0.0:
continue
meta = self._generation_token_meta(token)
allowed_by_general_filter = self._allowed_generation_token_with_meta(
token,
meta,
generated_tokens,
context_tokens,
)
if not allowed_by_general_filter:
dominant_learned_continuation = (
preserve_dominant_candidates
and best_probability > 0.0
and probability >= best_probability * 0.99
and self._allowed_answer_sequence_token(token, generated_tokens)
)
if not dominant_learned_continuation:
continue
if self._would_complete_blacklisted_answer_ids(generated_token_ids, token):
continue
repeats_recent_pattern = self._would_repeat_recent_pattern(
token,
generated_tokens,
recent_rendered_words=recent_rendered_words,
)
hard_phrase_loop = (
generated_word_count >= MIN_COMPLETE_ANSWER_WORDS
and not inside_tool_protocol
and meta.starts_new_word
and meta.rendered.casefold() in phrase_repeat_candidate_words
)
if hard_phrase_loop:
continue
if repeats_recent_pattern:
dominant_candidate_allowed = (
preserve_dominant_candidates
and best_probability > 0.0
and probability >= best_probability * 0.80
)
if not dominant_candidate_allowed:
continue
score = probability
if (
temperature >= ANSWER_REPLAY_PREFIX_TEMPERATURE
and not inside_tool_protocol
and self._would_follow_blacklisted_answer_prefix_ids(
generated_token_ids,
token,
)
):
score *= ANSWER_REPLAY_PREFIX_PENALTY
if (
temperature > 0.0
and self._would_follow_avoided_sequence(
generated_tokens,
token,
avoid_token_sequences,
)
):
score *= 0.12
rendered = meta.rendered
punctuation_token = meta.structural_punctuation
starts_new_word = meta.starts_new_word
alphanumeric = meta.alphanumeric
if (
not generated_tokens
and initial_prompt_content_token is not None
and token == initial_prompt_content_token
):
dominant_answer_candidate = (
preserve_dominant_candidates
and best_probability > 0.0
and probability >= best_probability * 0.80
)
if not dominant_answer_candidate:
continue
if (
not generated_tokens
and temperature > 0.0
and has_uppercase_start_candidate
and starts_new_word
and rendered[:1].islower()
and best_probability > 0.0
and probability < best_probability * 0.85
):
continue
if generated_tokens and starts_new_word and alphanumeric:
previous_alphanumeric = self._generation_token_meta(
generated_tokens[-1]
).alphanumeric
if previous_alphanumeric.casefold() == alphanumeric.casefold():
continue
common_connector = meta.common_connector
if (
starts_new_word
and len(alphanumeric) == 1
and not common_connector
):
score *= 0.08
recent_count = generated_tokens[-12:].count(token)
if recent_count > 0 and not common_connector:
score /= repetition_penalty ** (2 * recent_count)
if generated_tokens and token == generated_tokens[-1]:
score /= repetition_penalty**3
if generated_tokens and token in generated_tokens[-4:] and not common_connector:
score *= 0.35
if generated_tokens and not starts_new_word and self._starts_new_word(generated_tokens[-1]):
score *= 0.08
if not generated_tokens and punctuation_token:
if best_probability <= 0.0 or probability < best_probability * 0.80:
score *= 0.01
elif not generated_tokens and not starts_new_word:
score *= 0.02
if (
not generated_tokens
and temperature > 0.0
and has_uppercase_start_candidate
and starts_new_word
and rendered[:1].islower()
):
score *= 0.03
if punctuation_token:
if generated_tokens and self._is_structural_punctuation_token(generated_tokens[-1]):
score *= 0.05
if clause_words >= 6:
score *= 1.0 + min(1.4, 0.18 * (clause_words - 5))
elif generated_word_count >= 12:
score *= 1.1
if score > 0.0:
adjusted.append((token, score))
if not adjusted:
return []
adjusted.sort(key=lambda item: item[1], reverse=True)
if preserve_dominant_candidates:
top_score = adjusted[0][1]
second_score = adjusted[1][1] if len(adjusted) > 1 else 0.0
if top_score >= 0.5 and (
second_score <= 0.0
or top_score >= second_score * 1.2
or top_score - second_score >= 0.08
):
return [(adjusted[0][0], 1.0)]
effective_top_k = top_k
if (
temperature >= CREATIVE_EARLY_POOL_TEMPERATURE
and generated_word_count < CREATIVE_EARLY_POOL_WORD_LIMIT
and not inside_tool_protocol
and top_k > CREATIVE_EARLY_POOL_MAX
):
effective_top_k = CREATIVE_EARLY_POOL_MAX
if effective_top_k > 0:
adjusted = adjusted[:effective_top_k]
if 0.0 < top_p < 1.0:
kept: list[tuple[str, float]] = []
cumulative = 0.0
total = sum(score for _, score in adjusted)
for token, score in adjusted:
normalized = score / total if total else 0.0
kept.append((token, score))
cumulative += normalized
if cumulative >= top_p:
break
adjusted = kept
if temperature <= 0.0:
return [(adjusted[0][0], 1.0)]
exponent = 1.0 / temperature
tempered = [
(token, score**exponent)
for token, score in adjusted
if score > 0.0
]
total = sum(score for _, score in tempered)
if total <= 0.0:
return []
return [(token, score / total) for token, score in tempered]
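    # Sampling with decisive-winner shortcuts: a sufficiently dominant top
    # candidate is returned outright. Otherwise stochastic mode draws its
    # threshold from the module RNG, while deterministic mode derives a
    # reproducible threshold by seeding random.Random with a SHA-256 hash of
    # the context, the generated prefix, and the candidate count.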
def _sample_generation_candidate(
self,
candidates: list[tuple[str, float]],
*,
context_tokens: list[str],
generated_tokens: list[str],
stochastic: bool = False,
preserve_dominant_candidates: bool = False,
) -> str:
if not candidates:
return ""
if len(candidates) == 1:
return candidates[0][0]
top_probability = candidates[0][1]
second_probability = candidates[1][1]
top_has_clear_half_majority = top_probability >= 0.5 and (
second_probability <= 0.0
or top_probability - second_probability >= 0.02
)
if preserve_dominant_candidates and top_has_clear_half_majority:
return candidates[0][0]
decisive_stochastic_winner = stochastic and (
top_probability >= 0.985
or (
top_probability >= 0.96
and second_probability > 0.0
and top_probability >= second_probability * 20.0
)
or (
top_probability >= 0.90
and second_probability > 0.0
and top_probability >= second_probability * 40.0
)
or (
top_probability >= 0.90
and top_probability - second_probability >= 0.75
)
)
decisive_deterministic_winner = not stochastic and (
top_has_clear_half_majority
or (second_probability > 0.0 and top_probability >= second_probability * 2.5)
or (
top_probability >= 0.08
and second_probability > 0.0
and top_probability >= second_probability * 1.35
)
)
if decisive_stochastic_winner or decisive_deterministic_winner:
return candidates[0][0]
if stochastic:
threshold = random.random()
else:
seed_payload = "\u0002".join([*context_tokens, "<generated>", *generated_tokens, str(len(candidates))])
seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big")
threshold = random.Random(seed).random()
cumulative = 0.0
for token, probability in candidates:
cumulative += probability
if threshold <= cumulative:
return token
return candidates[-1][0]
def _top_entries_from_vector(
self,
values: Vector,
limit: int,
) -> list[dict[str, object]]:
if limit <= 0:
return []
ranked = sorted(
enumerate(values),
key=lambda item: item[1],
reverse=True,
)
return [
self._token_entry(index, probability)
for index, probability in ranked[:limit]
if probability > 0.0
]
def _token_entry(
self,
index: int,
probability: float,
) -> dict[str, object]:
assert self.embedding_model is not None
token = self.embedding_model.id_to_token[index]
return {
"token": token,
"text": self._render_token(token),
"probability": probability,
}
def _build_reasoning_summary(
self,
transition_order: int | None,
blend_weights: dict[str, float],
) -> str:
        dominant_source = (
            max(blend_weights.items(), key=lambda item: item[1])[0]
            if blend_weights
            else "base"
        )
if transition_order is not None:
transition_message = f" Transition prior is using order-{transition_order} context."
else:
transition_message = " Transition prior found no matching n-gram."
return (
"Generation is running on analytical state, recurrent traces, and corpus-derived token transitions."
f"{transition_message}"
f" Dominant blend source: {dominant_source}."
)
def _generated_word_count(self, tokens: list[str]) -> int:
count = 0
for token in tokens:
rendered = self._render_token(token)
if not any(character.isalnum() for character in rendered):
continue
if self._starts_new_word(token) or count == 0:
count += 1
return count
def _is_structural_punctuation_text(self, text: str) -> bool:
if len(text) != 1:
return False
if self._is_word_joiner_text(text):
return False
category = unicodedata.category(text)
return category.startswith("P")
def _is_structural_punctuation_token(self, token: str) -> bool:
return self._is_structural_punctuation_text(self._render_token(token))
def _is_structural_symbol_token(self, token: str) -> bool:
rendered = self._render_token(token)
return len(rendered) == 1 and unicodedata.category(rendered).startswith("S")
def _is_word_joiner_token(self, token: str) -> bool:
return self._is_word_joiner_text(self._render_token(token))
def _is_word_joiner_text(self, text: str) -> bool:
if len(text) != 1:
return False
category = unicodedata.category(text)
if category in ("Pc", "Pd", "Lm"):
return True
name = unicodedata.name(text, "")
return "APOSTROPHE" in name or (
"SINGLE" in name and "QUOTATION MARK" in name
)
def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool:
rendered = self._render_token(token)
if len(rendered) != 1 or unicodedata.category(rendered) != "Pd":
return False
if not self._starts_new_word(token):
return False
return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n"
def _can_start_answer_with_structural_punctuation(self, token: str) -> bool:
rendered = self._render_token(token)
if len(rendered) != 1 or not self._starts_new_word(token):
return False
return unicodedata.category(rendered) in ("Ps", "Pi")
def _is_common_connector_token(self, token: str) -> bool:
rendered = self._render_token(token)
return rendered.isalpha() and len(rendered) == 1 and rendered.islower()
def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool:
if not generated_tokens:
return False
rendered = self._render_token(generated_tokens[-1])
if not rendered:
return False
if any(character.isalnum() for character in rendered):
return True
if len(rendered) != 1:
return False
return unicodedata.category(rendered) in ("Ps", "Pi")
def _words_since_clause_break(self, tokens: list[str]) -> int:
assert self.tokenizer is not None
words = 0
for token in reversed(tokens):
if token in self.tokenizer.special_tokens:
continue
rendered = self._render_token(token)
if self._is_structural_punctuation_text(rendered):
break
if self._starts_new_word(token) and not self._is_punctuation_piece(rendered):
words += 1
return words
def _should_stop_generation(self, generated_tokens: list[str]) -> bool:
if not generated_tokens:
return False
if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])):
return False
word_count = self._generated_word_count(generated_tokens)
if word_count >= MIN_COMPLETE_ANSWER_WORDS:
return True
return (
word_count >= MIN_COMPLETE_MULTI_SENTENCE_WORDS
and self._terminal_sentence_count(generated_tokens) >= 2
)
def _terminal_sentence_count(self, tokens: list[str]) -> int:
return sum(
1
for token in tokens
if self._is_terminal_punctuation_text(self._render_token(token))
)
def _is_terminal_punctuation_text(self, text: str) -> bool:
stripped = text.strip()
if not stripped:
return False
terminal_character = stripped[-1]
if not self._is_structural_punctuation_text(terminal_character):
return False
return not self._is_word_joiner_text(terminal_character)
def _should_skip_prompt_overlap_token(self, token: str) -> bool:
rendered = self._render_token(token)
if not rendered.strip():
return True
if (
self.embedding_model is not None
and len(self.embedding_model.id_to_token) >= 1024
and not self._starts_new_word(token)
):
return True
if self._is_structural_punctuation_text(rendered):
return True
return rendered.strip().casefold() in PROMPT_ENVELOPE_TERMS
def _starts_new_word(self, token: str) -> bool:
assert self.tokenizer is not None
if token in self.tokenizer.special_tokens:
return True
if token.startswith(self.tokenizer.word_prefix):
return True
return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token)
def _generation_token_meta(self, token: str) -> GenerationTokenMeta:
cache = self.generation_token_meta_cache
if cache is None:
cache = {}
self.generation_token_meta_cache = cache
cached = cache.get(token)
if cached is not None:
return cached
rendered = self._render_token(token)
meta = GenerationTokenMeta(
rendered=rendered,
stripped=rendered.strip(),
starts_new_word=self._starts_new_word(token),
punctuation_piece=self._is_punctuation_piece(rendered),
structural_punctuation=self._is_structural_punctuation_token(token),
structural_symbol=self._is_structural_symbol_token(token),
word_joiner=self._is_word_joiner_token(token),
alphanumeric="".join(character for character in rendered if character.isalnum()),
common_connector=self._is_common_connector_token(token),
)
cache[token] = meta
return meta
def _decode_tokens(self, tokens: list[str]) -> str:
assert self.tokenizer is not None
return self.tokenizer.decode(
tokens,
preserve_special_tokens=TOOL_PROTOCOL_TOKENS,
)
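    # Truncates generated text to a single well-formed "<tool_call>": cut at
    # any later protocol marker or second call marker, then scan braces with
    # string/escape awareness to isolate the first balanced JSON object.
    # Unbalanced payloads are closed at the last top-level comma (or at the
    # end) before being handed to the repair step.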
@staticmethod
def _normalize_generated_tool_protocol_text(text: str, *, context: str | None = None) -> str:
marker = "<tool_call>"
call_index = text.find(marker)
if call_index < 0:
return text
        cleaned = text
for boundary in ("<tool_result>", "<source>", "<final>"):
boundary_index = cleaned.find(boundary, call_index + len(marker))
if boundary_index >= 0:
cleaned = cleaned[:boundary_index].rstrip()
second_call_index = cleaned.find(marker, call_index + len(marker))
if second_call_index >= 0:
cleaned = cleaned[:second_call_index].rstrip()
brace_start = cleaned.find("{", call_index)
if brace_start < 0:
return cleaned.strip()
depth = 0
in_string = False
escaped = False
last_top_level_comma: int | None = None
for index in range(brace_start, len(cleaned)):
character = cleaned[index]
if escaped:
escaped = False
continue
if in_string and character == "\\":
escaped = True
continue
if character == '"':
in_string = not in_string
continue
if in_string:
continue
if character == "{":
depth += 1
continue
if character == "}":
depth -= 1
if depth <= 0:
candidate = cleaned[: index + 1].strip()
return ReframrModel._repair_tool_call_payload_if_needed(
candidate,
context=context,
)
continue
if character == "," and depth == 1:
last_top_level_comma = index
if depth > 0:
if last_top_level_comma is not None:
candidate = cleaned[:last_top_level_comma].rstrip() + "}"
return ReframrModel._repair_tool_call_payload_if_needed(
candidate,
context=context,
)
candidate = cleaned.rstrip() + "}"
return ReframrModel._repair_tool_call_payload_if_needed(
candidate,
context=context,
)
return ReframrModel._repair_tool_call_payload_if_needed(
cleaned.strip(),
context=context,
)
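    # Payload repair: a web.search payload that parses as JSON gets its
    # query replaced with a context-derived one when the original looks
    # weak; payloads that fail to parse are rebuilt from the stripped body
    # as {"query": ...} for web.search or {"input": ...} for other tools.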
@staticmethod
def _repair_tool_call_payload_if_needed(text: str, *, context: str | None = None) -> str:
marker = "<tool_call>"
if not text.startswith(marker):
return text
brace_start = text.find("{", len(marker))
if brace_start < 0:
return text
tool_name = text[len(marker) : brace_start].strip()
payload_text = text[brace_start:].strip()
try:
payload = json.loads(payload_text)
if isinstance(payload, dict) and tool_name == "web.search":
repaired_query = ReframrModel._repair_search_query_from_context_if_weak(
str(payload.get("query", "")),
context,
)
if repaired_query is not None:
payload["query"] = repaired_query
return f"{marker} {tool_name} {json.dumps(payload, ensure_ascii=False)}"
return text
except (TypeError, json.JSONDecodeError):
pass
body = payload_text.strip()
if body.startswith("{"):
body = body[1:]
if body.endswith("}"):
body = body[:-1]
body = " ".join(body.replace('"', "").split())
if not tool_name or not body:
return text
if tool_name == "web.search":
payload = {
"query": ReframrModel._repair_search_query_from_context_if_weak(
body,
context,
)
or body
}
else:
payload = {"input": body}
return f"{marker} {tool_name} {json.dumps(payload, ensure_ascii=False)}"
@staticmethod
def _repair_search_query_from_context_if_weak(
query: str,
context: str | None,
) -> str | None:
cleaned_query = " ".join(query.replace("{", " ").replace("}", " ").split())
normalized_words = [
word.strip(" \t\r\n:,.;!?\"'()[]{}").casefold()
for word in cleaned_query.split()
if word.strip(" \t\r\n:,.;!?\"'()[]{}")
]
unique_content_words = {
word
for word in normalized_words
if word not in {"query", "web.search", "tool_call"}
}
lowered_query = cleaned_query.casefold()
weak = (
len(unique_content_words) < 3
or lowered_query.startswith("query:")
or "web.search" in lowered_query
or any(
marker in lowered_query
for marker in ("<tool", "<source>", "<final>", "according to")
)
)
if not weak:
return None
context_query = ReframrModel._search_query_from_context(context or "")
return context_query or None
@staticmethod
def _search_query_from_context(context: str) -> str:
if not context:
return ""
before_tool_result = context.split("<tool_result>", 1)[0]
before_final = before_tool_result.split("<final>", 1)[0]
lines = [line.strip() for line in before_final.splitlines() if line.strip()]
if not lines:
lines = [before_final.strip()]
latest_user = ""
for line in lines:
lowered = line.casefold()
if lowered.startswith("user:"):
latest_user = line.split(":", 1)[1].strip()
elif lowered.startswith("question:"):
latest_user = line.split(":", 1)[1].strip()
if not latest_user:
latest_user = lines[-1]
for prefix in ("User:", "Question:", "Prompt:", "Context:"):
if latest_user.casefold().startswith(prefix.casefold()):
latest_user = latest_user[len(prefix) :].strip()
cleaned = " ".join(latest_user.split())
return cleaned.strip(" \t\r\n\"'")
@staticmethod
def _finalize_generated_text(text: str) -> str:
stripped = text.rstrip()
if not stripped:
return stripped
if stripped.startswith("<tool_call>"):
return stripped
stripped = ReframrModel._remove_separator_punctuation_before_boundary(stripped)
if stripped and ReframrModel._is_separator_punctuation(stripped[-1:]):
stripped = stripped[:-1].rstrip()
if not stripped:
return stripped
if (
ReframrModel._is_surface_punctuation(stripped[:1])
or ReframrModel._is_surface_punctuation(stripped[-1:])
):
return stripped
if any(character.isalnum() for character in stripped[-8:]):
return f"{stripped}."
return stripped
@staticmethod
def _remove_separator_punctuation_before_boundary(text: str) -> str:
cleaned: list[str] = []
for character in text:
if (
ReframrModel._is_separator_punctuation(character)
and cleaned
and ReframrModel._is_separator_punctuation(cleaned[-1])
):
cleaned.pop()
cleaned.append(character)
return "".join(cleaned)
@staticmethod
def _is_surface_punctuation(character: str) -> bool:
return len(character) == 1 and unicodedata.category(character).startswith("P")
@staticmethod
def _is_separator_punctuation(character: str) -> bool:
return (
ReframrModel._is_surface_punctuation(character)
and unicodedata.bidirectional(character) == "CS"
)
def _render_token(self, token: str) -> str:
assert self.tokenizer is not None
if token.startswith(self.tokenizer.word_prefix):
return token[len(self.tokenizer.word_prefix) :]
return token
def _require_fit(self) -> None:
if (
self.tokenizer is None
or self.embedding_model is None
or self.memory_units is None
or self.readout_weights is None
or self.ternary_mask is None
or self.associative_keys is None
or (
self.associative_key_norms is None
and self.associative_key_norms_array is None
)
or self.associative_values is None
or self.transition_tables is None
):
raise RuntimeError("Call fit() before using the REFRAMR model.")
def _ensure_numeric_caches(self) -> None:
if np is None:
return
if self.readout_weights_array is None:
self._refresh_numeric_caches()