OkeyMeta committed
Commit 2147ce8 · verified · 1 parent: d641a1b

Release Reframr-RFM-v1-Base public checkpoint

Public v1 base release for Reframr RFM. Internal provenance: v95 computed checkpoint. Includes model.safetensors, tokenizer, runtime source, config, generation examples, and model card.

README.md ADDED
@@ -0,0 +1,141 @@
---
language:
- en
tags:
- reframr
- okeymeta
- non-transformer
- recurrent-memory
- computed-weights
- cpu-inference
- safetensors
library_name: reframr
pipeline_tag: text-generation
license: other
base_model: scratch
---

# Reframr-RFM-v1-Base

**Reframr-RFM-v1-Base** is the first public base checkpoint from **OkeyMeta Ltd** for the Reframr line of non-Transformer language models. Reframr is built from scratch around recurrent memory, computed weights, and data-derived structure rather than a Transformer attention stack.

This release is packaged as `model.safetensors` with the matching `tokenizer.json`, runtime source, config, and runnable examples. A larger production Reframr line, including tool-use and web-freshness data, is being computed after this release.

## What It Is

Reframr-RFM stands for **Recurrent Flow Memory**. The model is designed around a persistent recurrent state instead of a fixed quadratic attention map, so the architecture has no fixed attention-window context limit; practical limits come from runtime session length, machine memory, and deployment policy.

This checkpoint is not a Transformer, not a fine-tuned clone of a Transformer, and not a prompt wrapper. It uses the Reframr runtime included in this repository, and its checkpoint kind is `reframr-analytical`.
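To make the contrast concrete, here is a deliberately simplified, illustrative sketch of what a persistent recurrent state looks like in code. It is **not** Reframr's actual update rule (the runtime in this repository implements that internally); it only shows why per-step cost and memory stay constant as the session grows.

```python
# Illustrative only: NOT Reframr's real update rule. A generic leaky-integrator
# state shows the shape of recurrent-memory inference: one fixed-size state
# vector is updated per token, so cost does not grow with context length.
def update_state(state: list[float], features: list[float], decay: float = 0.9) -> list[float]:
    # Blend the old state with the new token's features; the state size (for
    # this checkpoint, width 576) never changes, however long the session runs.
    return [decay * s + (1.0 - decay) * f for s, f in zip(state, features)]
```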
## Model Files

- `model.safetensors`: Reframr v1 computed-weight checkpoint.
- `tokenizer.json`: FrameToken tokenizer exported from the checkpoint metadata.
- `config.json`: Release metadata and tensor layout.
- `generation_config.json`: Recommended default generation settings.
- `reframr/`: CPU-first Reframr runtime source.
- `examples/`: Minimal CLI, JSONL, and Python usage examples.

## Quick Start

Use Python 3.13 or newer from the root of this model repository:

```bash
python -m pip install -r requirements.txt
python -m reframr generate \
  --model model.safetensors \
  --context "Who are you, and what makes you different from Transformer models?" \
  --max-tokens 90 \
  --temperature 0.92 \
  --decode-top-k 72 \
  --decode-top-p 0.92
```

System instructions are passed as learned context:

```bash
python -m reframr generate \
  --model model.safetensors \
  --system "Answer in two short paragraphs. Be direct and warm." \
  --context "Explain why clean data matters when computing Reframr weights." \
  --max-tokens 90 \
  --temperature 0.9
```

For a persistent process that loads the checkpoint once and accepts JSONL requests:

```bash
python -m reframr serve --model model.safetensors --max-tokens 96
```

Then send one JSON object per line:

```jsonl
{"prompt":"Tell a short story about a glass library under the sea.","temperature":1.05,"decode_top_k":90,"max_tokens":120}
{"system":"Use exactly one fitting emoji.","prompt":"Encourage a tired engineer without sounding generic.","max_tokens":70}
```
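For scripted use, the serve loop can be driven from Python over pipes. This is a minimal sketch, assuming only the request/response shape shown above (one JSON object per request line, responses carrying a `generated_text` field); it is not a separate client API shipped with the runtime.

```python
import json
import subprocess

# Start the persistent server once; it keeps the checkpoint loaded.
proc = subprocess.Popen(
    ["python", "-m", "reframr", "serve", "--model", "model.safetensors", "--max-tokens", "96"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)
# Write one JSON request per line, then read one JSON response per line.
proc.stdin.write(json.dumps({"prompt": "Who are you, and who built you?", "max_tokens": 80}) + "\n")
proc.stdin.flush()
print(json.loads(proc.stdout.readline())["generated_text"])
proc.stdin.close()
proc.wait()
```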
## Python Example

```python
from pathlib import Path
from reframr.model import ReframrModel

root = Path(__file__).resolve().parent
model = ReframrModel.load(root / "model.safetensors")

text = model.generate_text(
    "Who are you?",
    max_tokens=80,
    temperature=0.92,
    top_k=72,
    top_p=0.92,
    repetition_penalty=1.18,
)
print(text)
```

## Generation Controls

- `--temperature`: Higher values increase variation. Try `0.85` for focused answers and `1.05` for story or brainstorming prompts (the sketch after this list shows how the three sampling controls compose).
- `--decode-top-k`: Limits sampling to the strongest candidate set. Recommended range: `50` to `100`.
- `--decode-top-p`: Nucleus cutoff. Recommended default: `0.92`.
- `--repetition-penalty`: Penalizes repeated tokens. Recommended default: `1.18`.
- `--system`: Adds a system instruction before the user prompt.
- `--reasoning-mode`: Supports `none`, `deep`, `memory`, and `tool` profiles in the runtime. The current public checkpoint is a base release; the dedicated tool/web-freshness line is still being computed.
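How these sampling controls compose is easy to state in code. The sketch below is illustrative, not the runtime's internal sampler: temperature rescales the scores first, top-k keeps the strongest candidates, and top-p trims the ranked list to the smallest set whose probability mass reaches the cutoff.

```python
import math
import random

# Illustrative sampler: temperature -> top-k -> top-p, in that order.
def sample_next(logits: dict[str, float], temperature: float = 0.92,
                top_k: int = 72, top_p: float = 0.92) -> str:
    scaled = {token: score / temperature for token, score in logits.items()}
    ranked = sorted(scaled.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    total = sum(math.exp(score) for _, score in ranked)
    nucleus, mass = [], 0.0
    for token, score in ranked:
        probability = math.exp(score) / total
        nucleus.append((token, probability))
        mass += probability
        if mass >= top_p:  # nucleus cutoff reached
            break
    tokens, weights = zip(*nucleus)
    return random.choices(tokens, weights=weights, k=1)[0]
```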
## Identity

Reframr is built by **OkeyMeta Ltd**. The Reframr line reframes language intelligence around recurrent memory, computed weights, and evidence from data. OkeyMeta Ltd was founded in 2022. The founder and CEO is **Okechukwu Goodnews Nwaozor**.

## Architecture Snapshot

| Property | Reframr-RFM-v1-Base |
| --- | --- |
| Family | Reframr / Recurrent Flow Memory |
| Organization | OkeyMeta Ltd |
| Checkpoint kind | `reframr-analytical` |
| Attention stack | None |
| Transformer layers | None |
| Tokenizer | FrameToken |
| Weight file | `model.safetensors` |
| Runtime | CPU-first Reframr Python runtime |
| Embedding dim | 96 |
| State dim | 48 |
| State width | 576 |
| Output vocab rows | 2,793 |
| Tokenizer vocab size | 3,741 |
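The full tensor layout behind these numbers lives in `config.json` under `tensor_shapes`. As a quick sanity check (a sketch, assuming it runs from the repository root), the element count implied by those shapes can be totalled directly:

```python
import json
import math

# Sum the products of every tensor shape listed in the release config.
with open("config.json", encoding="utf-8") as fh:
    shapes = json.load(fh)["tensor_shapes"]
total = sum(math.prod(shape) for shape in shapes.values())
print(f"{len(shapes)} tensors, {total:,} elements in total")
```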
## Intended Use

This checkpoint is intended for public testing of the Reframr runtime: open-ended generation experiments, system-instruction experiments, story generation, safety-behavior and identity probes, and CPU-first research into non-Transformer language modeling.

It is a base checkpoint, not a medical, legal, financial, or safety-critical authority. For fresh factual questions, connect a retrieval or web-search tool in the next tool-aware Reframr line rather than relying on static checkpoint knowledge alone.

## Release Note

This release is the public v1 base checkpoint. Internally, it comes from the v95 tracked compute run; publicly, it begins the Reframr-RFM v1 line. The next production line is being computed with broader data, tool-use supervision, web-search protocol tokens, and larger generalization probes. The goal is simple: make Reframr a serious, CPU-first, non-Transformer model family that learns from data rather than from hardcoded responses.

## Ownership

Copyright OkeyMeta Ltd. All rights reserved unless a separate license is supplied by OkeyMeta Ltd.
config.json ADDED
@@ -0,0 +1,107 @@
{
  "model_type": "reframr-rfm",
  "model_name": "Reframr-RFM-v1-Base",
  "library_name": "reframr",
  "checkpoint_kind": "reframr-analytical",
  "schema_version": "1",
  "architecture": "Reverse-Flow Recurrent Analytical Memory / Recurrent Flow Memory",
  "organization": "OkeyMeta Ltd",
  "creator": "OkeyMeta Ltd",
  "runtime": "CPU-first Reframr Python runtime included in this repository",
  "format": "safetensors",
  "weights_file": "model.safetensors",
  "tokenizer_file": "tokenizer.json",
  "tokenizer_name": "FrameToken",
  "tokenizer_vocab_size": 3741,
  "vocab_size": 2793,
  "embedding_dim": 96,
  "state_dim": 48,
  "state_width": 576,
  "tensor_count": 21,
  "tensor_shapes": {
    "answer_keys": [18000, 576],
    "answer_sequence_keys": [8400, 576],
    "answer_sequence_prompt_tokens": [8400, 192],
    "answer_sequence_tokens": [8400, 192],
    "answer_start_keys": [18000, 576],
    "answer_start_values": [18000],
    "answer_values": [18000],
    "associative_keys": [18000, 576],
    "associative_values": [18000],
    "embedding_table": [2793, 96],
    "preference_bias": [2793],
    "prompt_answer_bias": [2793],
    "prompt_answer_start_bias": [2793],
    "prompt_answer_start_weights": [2793, 576],
    "prompt_answer_weights": [2793, 576],
    "readout_bias": [2793],
    "readout_weights": [2793, 576],
    "state_offset": [576],
    "ternary_mask": [576],
    "ternary_scale": [1],
    "trace_token_weights": [2793]
  },
  "lowercase": false,
  "default_reasoning_profile": "none",
  "attention": "none",
  "transformer": "false",
  "weight_derivation": "computed analytical/statistical checkpoint from OkeyMeta curriculum data; no Transformer attention stack",
  "context_model": "recurrent persistent memory state; practical limits depend on runtime session and machine memory",
  "current_release": "public base checkpoint",
  "next_line": "tool-aware and web-freshness data line is being computed after this release",
  "public_version": "v1",
  "internal_compute_run": "v95",
  "internal_source_checkpoint": "reframr-v95-500b-effective-fullreadout-outside-probe-generalization-e96-s48.safetensors"
}
examples/jsonl_serve.ps1 ADDED
@@ -0,0 +1,7 @@
$requests = @'
{"prompt":"Who are you, and who built you?","max_tokens":80,"temperature":0.9}
{"system":"Answer in two short paragraphs and use exactly one fitting emoji.","prompt":"Encourage a tired engineer who is still building carefully.","max_tokens":80,"temperature":0.95}
{"prompt":"Tell a short story about a glass library under the sea.","max_tokens":120,"temperature":1.05,"decode_top_k":90}
'@

$requests | python -m reframr serve --model model.safetensors --max-tokens 96 --temperature 0.92 --decode-top-k 72 --decode-top-p 0.92
examples/python_inference.py ADDED
@@ -0,0 +1,44 @@
from __future__ import annotations

import argparse
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from reframr.model import ReframrModel


def main() -> None:
    parser = argparse.ArgumentParser(description="Run Reframr-RFM-v1-Base locally.")
    parser.add_argument("--model", default=str(REPO_ROOT / "model.safetensors"))
    parser.add_argument("--prompt", default="Who are you, and what makes Reframr different?")
    parser.add_argument("--system", default="")
    parser.add_argument("--max-tokens", type=int, default=90)
    parser.add_argument("--temperature", type=float, default=0.92)
    parser.add_argument("--top-k", type=int, default=72)
    parser.add_argument("--top-p", type=float, default=0.92)
    parser.add_argument("--repetition-penalty", type=float, default=1.18)
    args = parser.parse_args()

    context = args.prompt
    if args.system.strip():
        context = f"System instruction: {args.system.strip()}\nUser: {args.prompt}"

    model = ReframrModel.load(args.model)
    print(
        model.generate_text(
            context,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )
    )


if __name__ == "__main__":
    main()
generation_config.json ADDED
@@ -0,0 +1,8 @@
{
  "max_tokens": 96,
  "temperature": 0.92,
  "decode_top_k": 72,
  "decode_top_p": 0.92,
  "repetition_penalty": 1.18,
  "reasoning_profile": "none"
}
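These defaults mirror the README's recommended settings. A minimal sketch of applying them through the Python API, assuming the key names map onto `generate_text` arguments the way the README example does:

```python
import json
from reframr.model import ReframrModel

# Load the repository defaults and pass them straight to generation.
with open("generation_config.json", encoding="utf-8") as fh:
    defaults = json.load(fh)
model = ReframrModel.load("model.safetensors")
print(model.generate_text(
    "Summarize Reframr in one sentence.",
    max_tokens=defaults["max_tokens"],
    temperature=defaults["temperature"],
    top_k=defaults["decode_top_k"],
    top_p=defaults["decode_top_p"],
    repetition_penalty=defaults["repetition_penalty"],
))
```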
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28d9eb4844b8aa4e337c18bf78e5b12fcf214b876fb5cd2e6e1fa556c7f70f2b
size 205798796
pyproject.toml ADDED
@@ -0,0 +1,17 @@
[project]
name = "reframr"
version = "0.1.0"
description = "CPU-first analytical language modeling research framework for REFRAMR."
requires-python = ">=3.13"
dependencies = [
    "numpy>=2.1,<3",
    "scipy>=1.14,<2",
    "datasets>=4.1,<5",
]

[project.scripts]
reframr = "reframr.cli:main"

[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"
reframr/__init__.py ADDED
@@ -0,0 +1,32 @@
import sys
from pathlib import Path

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

from .checkpoint import inspect_checkpoint, read_safetensor_file
from .config import ReframrConfig
from .embeddings import EmbeddingModel, fit_ppmi_embedding
from .hippo import AnalyticalMemoryUnit, hippo_legs_matrix
from .model import ReframrModel
from .reasoning import REASONING_CONTROL_TOKENS, REASONING_PROFILES, TOKENIZER_NAME
from .tokenizer import NativeTokenizer

__all__ = [
    "AnalyticalMemoryUnit",
    "EmbeddingModel",
    "NativeTokenizer",
    "REASONING_CONTROL_TOKENS",
    "REASONING_PROFILES",
    "ReframrConfig",
    "ReframrModel",
    "TOKENIZER_NAME",
    "fit_ppmi_embedding",
    "hippo_legs_matrix",
    "inspect_checkpoint",
    "read_safetensor_file",
]
reframr/__main__.py ADDED
@@ -0,0 +1,5 @@
from .cli import main


if __name__ == "__main__":
    raise SystemExit(main())
reframr/checkpoint.py ADDED
@@ -0,0 +1,274 @@
import json
import math
import site
import struct
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

try:
    import numpy as np
except ModuleNotFoundError:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None

if np is not None and not hasattr(np, "asarray"):
    np = None

DTYPE_CODES = {
    "F32": ("f", 4),
    "F64": ("d", 8),
    "I32": ("i", 4),
}


@dataclass(slots=True)
class SafeTensorFile:
    tensors: dict[str, Any]
    metadata: dict[str, str]


def _read_safetensor_header(path: str | Path) -> dict[str, Any]:
    with Path(path).open("rb") as handle:
        length_bytes = handle.read(8)
        if len(length_bytes) < 8:
            raise ValueError("Invalid safetensors file: missing header length.")
        header_length = struct.unpack("<Q", length_bytes)[0]
        header_bytes = handle.read(header_length)
        if len(header_bytes) != header_length:
            raise ValueError("Invalid safetensors file: truncated header.")
        return json.loads(header_bytes.decode("utf-8"))


def _shape_of(value: Any) -> list[int]:
    if np is not None and hasattr(value, "shape"):
        return [int(axis) for axis in value.shape]
    if not isinstance(value, list):
        return []
    if not value:
        return [0]
    first_shape = _shape_of(value[0])
    for item in value[1:]:
        if _shape_of(item) != first_shape:
            raise ValueError("Safetensor writer does not support ragged tensors.")
    return [len(value)] + first_shape


def _flatten(value: Any) -> list[Any]:
    if np is not None and hasattr(value, "reshape"):
        return value.reshape(-1).tolist()
    if isinstance(value, list):
        flattened: list[Any] = []
        for item in value:
            flattened.extend(_flatten(item))
        return flattened
    return [value]


def _dtype_of(flat_values: list[Any]) -> str:
    if all(isinstance(value, int) and not isinstance(value, bool) for value in flat_values):
        return "I32"
    return "F64"


def _pack_tensor(dtype: str, values: list[Any]) -> bytes:
    if not values:
        return b""
    code, _ = DTYPE_CODES[dtype]
    cast_values = [int(value) for value in values] if dtype == "I32" else [float(value) for value in values]
    return struct.pack(f"<{len(cast_values)}{code}", *cast_values)


def _array_payload(value: Any) -> tuple[str, list[int], Any] | None:
    if np is None:
        return None
    try:
        array = np.asarray(value)
    except (TypeError, ValueError):
        return None
    if array.dtype == object:
        return None
    shape = [int(axis) for axis in array.shape]
    if np.issubdtype(array.dtype, np.integer) and not np.issubdtype(array.dtype, np.bool_):
        return "I32", shape, np.ascontiguousarray(array.astype("<i4", copy=False))
    if np.issubdtype(array.dtype, np.floating):
        if array.dtype == np.float32:
            return "F32", shape, np.ascontiguousarray(array.astype("<f4", copy=False))
        return "F64", shape, np.ascontiguousarray(array.astype("<f8", copy=False))
    return "F64", shape, np.ascontiguousarray(array.astype("<f8", copy=False))


def _reshape(values: list[Any], shape: list[int]) -> Any:
    if not shape:
        return values[0]
    if len(shape) == 1:
        return values[: shape[0]]

    chunk = math.prod(shape[1:])
    return [
        _reshape(values[index * chunk : (index + 1) * chunk], shape[1:])
        for index in range(shape[0])
    ]


def write_safetensor_file(
    path: str | Path,
    tensors: dict[str, Any],
    *,
    metadata: dict[str, str] | None = None,
) -> None:
    tensor_header: dict[str, Any] = {}
    payloads: list[Any] = []
    offset = 0

    for name, value in tensors.items():
        array_payload = _array_payload(value)
        if array_payload is None:
            flat_values = _flatten(value)
            dtype = _dtype_of(flat_values)
            shape = _shape_of(value)
            payload = _pack_tensor(dtype, flat_values)
        else:
            dtype, shape, payload = array_payload
        payload_size = int(payload.nbytes) if hasattr(payload, "nbytes") else len(payload)
        tensor_header[name] = {
            "dtype": dtype,
            "shape": shape,
            "data_offsets": [offset, offset + payload_size],
        }
        payloads.append(payload)
        offset += payload_size

    if metadata:
        tensor_header["__metadata__"] = metadata

    header_bytes = json.dumps(tensor_header, separators=(",", ":")).encode("utf-8")
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("wb") as handle:
        handle.write(struct.pack("<Q", len(header_bytes)))
        handle.write(header_bytes)
        for payload in payloads:
            if hasattr(payload, "nbytes"):
                if payload.nbytes:
                    handle.write(memoryview(payload).cast("B"))
            else:
                handle.write(payload)


def read_safetensor_file(path: str | Path, *, arrays: bool = False) -> SafeTensorFile:
    tensor_path = Path(path)
    if arrays and np is not None:
        with tensor_path.open("rb") as handle:
            length_bytes = handle.read(8)
            if len(length_bytes) < 8:
                raise ValueError("Invalid safetensors file: missing header length.")
            header_length = struct.unpack("<Q", length_bytes)[0]
            header_bytes = handle.read(header_length)
            if len(header_bytes) != header_length:
                raise ValueError("Invalid safetensors file: truncated header.")
            header = json.loads(header_bytes.decode("utf-8"))
        data_start = 8 + header_length
        metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
        tensors: dict[str, Any] = {}

        for name, spec in header.items():
            if name == "__metadata__":
                continue
            start, end = spec["data_offsets"]
            dtype = str(spec["dtype"])
            shape = [int(value) for value in spec["shape"]]
            _, width = DTYPE_CODES[dtype]
            payload_width = end - start
            element_count = payload_width // width if width else 0
            if payload_width <= 0:
                tensors[name] = np.asarray([], dtype={"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype])
                continue
            array_dtype = {"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype]
            mapped_shape = tuple(shape) if shape else (element_count,)
            mapped = np.memmap(
                tensor_path,
                dtype=array_dtype,
                mode="r",
                offset=data_start + start,
                shape=mapped_shape,
                order="C",
            )
            tensors[name] = mapped if shape else mapped[0]

        return SafeTensorFile(tensors=tensors, metadata=metadata)

    raw = tensor_path.read_bytes()
    if len(raw) < 8:
        raise ValueError("Invalid safetensors file: missing header length.")

    header_length = struct.unpack("<Q", raw[:8])[0]
    header = json.loads(raw[8 : 8 + header_length].decode("utf-8"))
    data_buffer = raw[8 + header_length :]
    metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
    tensors: dict[str, Any] = {}

    for name, spec in header.items():
        if name == "__metadata__":
            continue
        start, end = spec["data_offsets"]
        dtype = str(spec["dtype"])
        shape = [int(value) for value in spec["shape"]]
        code, width = DTYPE_CODES[dtype]
        payload = data_buffer[start:end]
        element_count = len(payload) // width if width else 0
        if np is not None and payload:
            array_dtype = {"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype]
            values = np.frombuffer(payload, dtype=array_dtype, count=element_count)
            reshaped = values.reshape(shape) if shape else values
            if arrays:
                tensors[name] = reshaped.copy() if shape else values.copy()[0]
            else:
                tensors[name] = reshaped.tolist() if shape else values.tolist()[0]
        else:
            values = list(struct.unpack(f"<{element_count}{code}", payload)) if payload else []
            tensors[name] = _reshape(values, shape)

    return SafeTensorFile(tensors=tensors, metadata=metadata)


def inspect_checkpoint(path: str | Path) -> dict[str, Any]:
    header = _read_safetensor_header(path)
    metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
    tensor_names = sorted(name for name in header if name != "__metadata__")
    config = json.loads(metadata["config"]) if "config" in metadata else {}
    return {
        "format": "safetensors",
        "path": str(Path(path).resolve()),
        "checkpoint_kind": metadata.get("checkpoint_kind", "unknown"),
        "schema_version": metadata.get("schema_version", "0"),
        "tokenizer_name": metadata.get("tokenizer_name", ""),
        "default_reasoning_profile": str(config.get("default_reasoning_profile", "none")) if config else "none",
        "lowercase": bool(config.get("lowercase", False)) if config else False,
        "tensor_count": len(tensor_names),
        "tensor_names": tensor_names,
        "tensor_dtypes": {
            name: str(header[name]["dtype"])
            for name in tensor_names
        },
        "tensor_shapes": {
            name: [int(axis) for axis in header[name]["shape"]]
            for name in tensor_names
        },
        "tokenizer_vocab_size": int(metadata.get("tokenizer_vocab_size", "0")),
        "embedding_dim": int(config.get("embedding_dim", 0)) if config else 0,
        "state_dim": int(config.get("state_dim", 0)) if config else 0,
    }
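The format handled above is plain safetensors: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype, shape, and byte offsets, then the raw payloads. A short usage sketch of the two public entry points, assuming it runs from the repository root:

```python
from reframr.checkpoint import inspect_checkpoint, read_safetensor_file

# Header-only: reads the JSON header without touching the ~200 MB payload.
info = inspect_checkpoint("model.safetensors")
print(info["checkpoint_kind"], info["tensor_count"], info["embedding_dim"])

# Full read; with arrays=True and NumPy available, tensors are memory-mapped.
blob = read_safetensor_file("model.safetensors", arrays=True)
print(sorted(blob.tensors)[:5], blob.metadata.get("tokenizer_name"))
```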
reframr/cli.py ADDED
@@ -0,0 +1,760 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ from .checkpoint import inspect_checkpoint
7
+ from .config import ReframrConfig
8
+ from .corpus_recipes import (
9
+ build_foundation_corpus,
10
+ build_generalization_corpus,
11
+ write_corpus_package,
12
+ )
13
+ from .curriculum import CurriculumConfig, write_curriculum_package
14
+ from .datasets import load_prompt_suite, load_text_corpus
15
+ from .evaluation import benchmark_open_prompts, evaluate_manifest, load_manifest
16
+ from .hf_import import import_hf_dataset
17
+ from .model import ReframrModel
18
+ from .reasoning import REASONING_PROFILES, TOKENIZER_NAME, reasoning_prefix
19
+ from .streaming import fit_model_from_corpus_plan, load_corpus_plan
20
+ from .tokenizer import MAX_TOKENIZER_VOCAB_SIZE, clamp_vocab_size, recommend_vocab_size
21
+
22
+
23
+ def configure_stdio() -> None:
24
+ for stream in (sys.stdout, sys.stderr):
25
+ reconfigure = getattr(stream, "reconfigure", None)
26
+ if reconfigure is not None:
27
+ reconfigure(encoding="utf-8")
28
+
29
+
30
+ def build_parser() -> argparse.ArgumentParser:
31
+ parser = argparse.ArgumentParser(
32
+ prog="reframr",
33
+ description="Compute and query REFRAMR analytical language model checkpoints.",
34
+ )
35
+ subparsers = parser.add_subparsers(dest="command", required=True)
36
+
37
+ compute = subparsers.add_parser(
38
+ "compute",
39
+ aliases=["train"],
40
+ help="Compute a REFRAMR checkpoint from a text corpus with no epoch loop.",
41
+ )
42
+ compute.add_argument(
43
+ "--input",
44
+ required=True,
45
+ help="Path to a text, JSON, or JSONL corpus file, or a directory of such files.",
46
+ )
47
+ compute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
48
+ compute.add_argument("--embedding-dim", type=int, default=16)
49
+ compute.add_argument("--state-dim", type=int, default=32)
50
+ compute.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
51
+ compute.add_argument("--window-size", type=int, default=2)
52
+ compute.add_argument("--regularization", type=float, default=1e-3)
53
+ compute.add_argument("--min-frequency", type=int, default=1)
54
+ compute.add_argument(
55
+ "--max-vocab",
56
+ type=int,
57
+ default=256,
58
+ help="Cap analytical embedding vocabulary to keep weight computation fast on CPU.",
59
+ )
60
+ compute.add_argument("--tokenizer-vocab-size", type=int, default=0)
61
+ compute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
62
+ compute.add_argument(
63
+ "--max-training-examples",
64
+ type=int,
65
+ default=60000,
66
+ help="Cap sampled recurrent training states while still reading the full corpus for tokenizer, embeddings, and transitions.",
67
+ )
68
+ compute.add_argument(
69
+ "--max-transition-contexts",
70
+ type=int,
71
+ default=4096,
72
+ help="Keep only the strongest learned transition contexts per order. Use 0 to disable the cap.",
73
+ )
74
+ compute.add_argument(
75
+ "--max-transition-next-tokens",
76
+ type=int,
77
+ default=4,
78
+ help="Keep this many learned next-token choices per transition context.",
79
+ )
80
+ case_group = compute.add_mutually_exclusive_group()
81
+ case_group.add_argument(
82
+ "--lowercase",
83
+ action="store_true",
84
+ help="Normalize corpus text to lowercase before tokenization.",
85
+ )
86
+ case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)
87
+ compute.add_argument(
88
+ "--reasoning-profile",
89
+ choices=sorted(REASONING_PROFILES),
90
+ default="none",
91
+ help="Default reasoning-control profile baked into the checkpoint.",
92
+ )
93
+
94
+ recompute = subparsers.add_parser(
95
+ "recompute",
96
+ help="Compute a REFRAMR checkpoint from a streaming corpus plan with no raw-text cache.",
97
+ )
98
+ recompute.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.")
99
+ recompute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
100
+ recompute.add_argument("--embedding-dim", type=int, default=16)
101
+ recompute.add_argument("--state-dim", type=int, default=32)
102
+ recompute.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
103
+ recompute.add_argument("--window-size", type=int, default=2)
104
+ recompute.add_argument("--regularization", type=float, default=1e-3)
105
+ recompute.add_argument("--min-frequency", type=int, default=1)
106
+ recompute.add_argument("--max-vocab", type=int, default=256)
107
+ recompute.add_argument("--tokenizer-vocab-size", type=int, default=0)
108
+ recompute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
109
+ recompute.add_argument("--max-training-examples", type=int, default=60000)
110
+ recompute.add_argument("--max-transition-contexts", type=int, default=4096)
111
+ recompute.add_argument("--max-transition-next-tokens", type=int, default=4)
112
+ recompute.add_argument("--log-every", type=int, default=0)
113
+ recompute_case_group = recompute.add_mutually_exclusive_group()
114
+ recompute_case_group.add_argument("--lowercase", action="store_true")
115
+ recompute_case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)
116
+ recompute.add_argument(
117
+ "--reasoning-profile",
118
+ choices=sorted(REASONING_PROFILES),
119
+ default="none",
120
+ help="Default reasoning-control profile baked into the checkpoint.",
121
+ )
122
+
123
+ predict = subparsers.add_parser("predict", help="Predict the next-token distribution from a saved model.")
124
+ predict.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
125
+ predict.add_argument("--context", required=True, help="Input context text.")
126
+ predict.add_argument("--top-k", type=int, default=5)
127
+ predict.add_argument(
128
+ "--reasoning-mode",
129
+ choices=sorted(REASONING_PROFILES),
130
+ default=None,
131
+ help="Override the checkpoint's default reasoning-control profile.",
132
+ )
133
+
134
+ generate = subparsers.add_parser("generate", help="Generate long-form text from a saved model.")
135
+ generate.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
136
+ generate.add_argument("--context", required=True, help="Prompt or starting context text.")
137
+ generate.add_argument("--system", default="", help="Optional system instruction to prepend as learned context.")
138
+ generate.add_argument("--max-tokens", type=int, default=64)
139
+ generate.add_argument("--temperature", type=float, default=0.82)
140
+ generate.add_argument("--decode-top-k", type=int, default=24)
141
+ generate.add_argument("--decode-top-p", type=float, default=0.92)
142
+ generate.add_argument("--repetition-penalty", type=float, default=1.18)
143
+ generate.add_argument(
144
+ "--reasoning-mode",
145
+ choices=sorted(REASONING_PROFILES),
146
+ default=None,
147
+ help="Override the checkpoint's default reasoning-control profile.",
148
+ )
149
+
150
+ generate_batch = subparsers.add_parser(
151
+ "generate-batch",
152
+ help="Generate answers for a prompt file while keeping one checkpoint loaded.",
153
+ )
154
+ generate_batch.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
155
+ generate_batch.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
156
+ generate_batch.add_argument("--output", required=True, help="Path to write JSONL generations.")
157
+ generate_batch.add_argument("--max-tokens", type=int, default=64)
158
+ generate_batch.add_argument("--temperature", type=float, default=0.82)
159
+ generate_batch.add_argument("--decode-top-k", type=int, default=24)
160
+ generate_batch.add_argument("--decode-top-p", type=float, default=0.92)
161
+ generate_batch.add_argument("--repetition-penalty", type=float, default=1.18)
162
+ generate_batch.add_argument(
163
+ "--reasoning-mode",
164
+ choices=sorted(REASONING_PROFILES),
165
+ default=None,
166
+ help="Override the checkpoint's default reasoning-control profile.",
167
+ )
168
+
169
+ serve = subparsers.add_parser(
170
+ "serve",
171
+ help="Keep one checkpoint loaded and answer JSONL generation requests from stdin.",
172
+ )
173
+ serve.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
174
+ serve.add_argument("--max-tokens", type=int, default=64)
175
+ serve.add_argument("--temperature", type=float, default=0.82)
176
+ serve.add_argument("--decode-top-k", type=int, default=24)
177
+ serve.add_argument("--decode-top-p", type=float, default=0.92)
178
+ serve.add_argument("--repetition-penalty", type=float, default=1.18)
179
+ serve.add_argument(
180
+ "--reasoning-mode",
181
+ choices=sorted(REASONING_PROFILES),
182
+ default=None,
183
+ help="Override the checkpoint's default reasoning-control profile.",
184
+ )
185
+
186
+ trace = subparsers.add_parser("trace", help="Trace REFRAMR reasoning components through generation steps.")
187
+ trace.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
188
+ trace.add_argument("--context", required=True, help="Prompt or starting context text.")
189
+ trace.add_argument("--max-tokens", type=int, default=8)
190
+ trace.add_argument("--top-k", type=int, default=5)
191
+ trace.add_argument("--temperature", type=float, default=0.82)
192
+ trace.add_argument("--decode-top-p", type=float, default=0.92)
193
+ trace.add_argument("--repetition-penalty", type=float, default=1.18)
194
+ trace.add_argument(
195
+ "--reasoning-mode",
196
+ choices=sorted(REASONING_PROFILES),
197
+ default=None,
198
+ help="Override the checkpoint's default reasoning-control profile.",
199
+ )
200
+
201
+ inspect = subparsers.add_parser("inspect", help="Inspect a REFRAMR safetensors checkpoint.")
202
+ inspect.add_argument("--model", required=True, help="Path to a .safetensors checkpoint.")
203
+
204
+ craft = subparsers.add_parser(
205
+ "craft-corpus",
206
+ help="Generate a JSON-first bootstrap corpus, manifest, and generalization prompt suite.",
207
+ )
208
+ craft.add_argument("--output-dir", required=True, help="Directory to write corpus and manifest files.")
209
+ craft.add_argument(
210
+ "--variant",
211
+ choices=("foundation", "generalization"),
212
+ default="foundation",
213
+ help="Choose between the mixed foundation corpus and the language-first generalization corpus.",
214
+ )
215
+
216
+ craft_curriculum = subparsers.add_parser(
217
+ "craft-curriculum",
218
+ help="Generate the OkeyMeta JSON curriculum shard, manifest, holdout prompts, and recompute plan.",
219
+ )
220
+ craft_curriculum.add_argument("--output-dir", required=True, help="Directory to write curriculum files.")
221
+ craft_curriculum.add_argument(
222
+ "--records-per-category",
223
+ type=int,
224
+ default=1000,
225
+ help="How many JSON records to generate for each curriculum category.",
226
+ )
227
+ craft_curriculum.add_argument("--seed", type=int, default=7)
228
+ craft_curriculum.add_argument("--train-ratio", type=float, default=0.92)
229
+ craft_curriculum.add_argument(
230
+ "--effective-token-target",
231
+ type=int,
232
+ default=0,
233
+ help="Set plan weighting so compact curriculum statistics represent this many effective tokens.",
234
+ )
235
+
236
+ evaluate = subparsers.add_parser(
237
+ "evaluate",
238
+ help="Evaluate memorization and held-out generalization from a benchmark manifest.",
239
+ )
240
+ evaluate.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
241
+ evaluate.add_argument("--manifest", required=True, help="Path to a corpus benchmark manifest JSON file.")
242
+ evaluate.add_argument(
243
+ "--reasoning-mode",
244
+ choices=sorted(REASONING_PROFILES),
245
+ default=None,
246
+ help="Override the checkpoint's default reasoning-control profile during evaluation.",
247
+ )
248
+ evaluate.add_argument("--top-k", type=int, default=5)
249
+
250
+ benchmark_open = subparsers.add_parser(
251
+ "benchmark-open",
252
+ help="Run arbitrary prompt files through a checkpoint with open-ended output metrics.",
253
+ )
254
+ benchmark_open.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
255
+ benchmark_open.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
256
+ benchmark_open.add_argument("--max-tokens", type=int, default=64)
257
+ benchmark_open.add_argument("--temperature", type=float, default=0.82)
258
+ benchmark_open.add_argument("--decode-top-k", type=int, default=24)
259
+ benchmark_open.add_argument("--decode-top-p", type=float, default=0.92)
260
+ benchmark_open.add_argument("--repetition-penalty", type=float, default=1.18)
261
+ benchmark_open.add_argument(
262
+ "--reasoning-mode",
263
+ choices=sorted(REASONING_PROFILES),
264
+ default=None,
265
+ help="Override the checkpoint's default reasoning-control profile during benchmarking.",
266
+ )
267
+
268
+ import_hf = subparsers.add_parser(
269
+ "import-hf",
270
+ help="Import Hugging Face dataset text into the REFRAMR JSON record standard.",
271
+ )
272
+ import_hf.add_argument("--dataset", required=True, help="Hugging Face dataset id.")
273
+ import_hf.add_argument("--output", required=True, help="Path to write the JSONL corpus.")
274
+ import_hf.add_argument("--config", default=None, help="Optional dataset config/subset.")
275
+ import_hf.add_argument("--split", default="train", help="Dataset split to import.")
276
+ import_hf.add_argument("--text-field", default=None, help="Explicit text column name.")
277
+ import_hf.add_argument("--limit", type=int, default=1000, help="Maximum records to import.")
278
+ import_hf.add_argument(
279
+ "--min-words",
280
+ type=int,
281
+ default=0,
282
+ help="Drop imported records shorter than this many words.",
283
+ )
284
+ import_hf.add_argument(
285
+ "--max-words",
286
+ type=int,
287
+ default=0,
288
+ help="Drop imported records longer than this many words. Use 0 to disable.",
289
+ )
290
+ import_hf.add_argument(
291
+ "--min-alpha-ratio",
292
+ type=float,
293
+ default=0.0,
294
+ help="Drop imported records whose alphabetic-character ratio falls below this threshold.",
295
+ )
296
+ import_hf.add_argument(
297
+ "--allowed-languages",
298
+ default="",
299
+ help="Optional comma-separated language codes to keep, such as en,yo,ig,ha.",
300
+ )
301
+ import_hf.add_argument(
302
+ "--preference-target",
303
+ choices=("both", "chosen", "rejected"),
304
+ default="chosen",
305
+ help="When importing preference datasets, keep both sides or only the chosen/rejected side.",
306
+ )
307
+ import_hf.add_argument(
308
+ "--no-streaming",
309
+ action="store_true",
310
+ help="Disable streaming dataset reads.",
311
+ )
312
+
313
+ return parser
314
+
315
+
316
+ def parse_timescales(raw_timescales: str) -> tuple[float, ...]:
317
+ values = [segment.strip() for segment in raw_timescales.split(",") if segment.strip()]
318
+ if not values:
319
+ raise ValueError("At least one timescale is required.")
320
+ return tuple(float(value) for value in values)
321
+
322
+
323
+ def command_compute(args: argparse.Namespace) -> int:
324
+ text = load_text_corpus(args.input)
325
+ requested_vocab_size = args.tokenizer_vocab_size or recommend_vocab_size(
326
+ text,
327
+ lowercase=args.lowercase,
328
+ )
329
+ tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
330
+ config = ReframrConfig(
331
+ embedding_dim=args.embedding_dim,
332
+ state_dim=args.state_dim,
333
+ timescales=parse_timescales(args.timescales),
334
+ window_size=args.window_size,
335
+ regularization=args.regularization,
336
+ min_frequency=args.min_frequency,
337
+ max_vocab=args.max_vocab,
338
+ tokenizer_vocab_size=tokenizer_vocab_size,
339
+ tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
340
+ max_training_examples=args.max_training_examples,
341
+ max_transition_contexts_per_order=(
342
+ args.max_transition_contexts if args.max_transition_contexts > 0 else None
343
+ ),
344
+ max_transition_next_tokens=args.max_transition_next_tokens,
345
+ lowercase=args.lowercase,
346
+ default_reasoning_profile=args.reasoning_profile,
347
+ )
348
+ model = ReframrModel(config).fit(text)
349
+ model.save(args.output)
350
+
351
+ assert model.tokenizer is not None
352
+ assert model.embedding_model is not None
353
+ summary = {
354
+ "status": "computed",
355
+ "format": "safetensors",
356
+ "model_path": str(Path(args.output).resolve()),
357
+ "tokenizer_name": TOKENIZER_NAME,
358
+ "vocab_size": len(model.embedding_model.id_to_token),
359
+ "tokenizer_vocab_budget": config.tokenizer_vocab_size,
360
+ "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
361
+ "tokenizer_vocab_size": model.tokenizer.vocab_size,
362
+ "reasoning_profile": config.default_reasoning_profile,
363
+ "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
364
+ "lowercase": config.lowercase,
365
+ "max_training_examples": config.max_training_examples,
366
+ "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
367
+ "max_transition_next_tokens": config.max_transition_next_tokens,
368
+ "embedding_dim": config.embedding_dim,
369
+ "state_dim": config.state_dim,
370
+ "timescales": list(config.timescales),
371
+ }
372
+ print(json.dumps(summary))
373
+ return 0
374
+
375
+
376
+ def command_recompute(args: argparse.Namespace) -> int:
377
+ plan = load_corpus_plan(args.plan)
378
+ requested_vocab_size = args.tokenizer_vocab_size or 1024
379
+ tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
380
+ config = ReframrConfig(
381
+ embedding_dim=args.embedding_dim,
382
+ state_dim=args.state_dim,
383
+ timescales=parse_timescales(args.timescales),
384
+ window_size=args.window_size,
385
+ regularization=args.regularization,
386
+ min_frequency=args.min_frequency,
387
+ max_vocab=args.max_vocab,
388
+ tokenizer_vocab_size=tokenizer_vocab_size,
389
+ tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
390
+ max_training_examples=args.max_training_examples,
391
+ max_transition_contexts_per_order=(
392
+ args.max_transition_contexts if args.max_transition_contexts > 0 else None
393
+ ),
394
+ max_transition_next_tokens=args.max_transition_next_tokens,
395
+ lowercase=args.lowercase,
396
+ default_reasoning_profile=args.reasoning_profile,
397
+ )
398
+ model, payload = fit_model_from_corpus_plan(
399
+ plan,
400
+ config,
401
+ log_every=args.log_every,
402
+ )
403
+ model.save(args.output)
404
+
405
+ summary = {
406
+ "status": "recomputed",
407
+ "format": "safetensors",
408
+ "streaming": True,
409
+ "plan_path": str(Path(args.plan).resolve()),
410
+ "model_path": str(Path(args.output).resolve()),
411
+ "tokenizer_name": TOKENIZER_NAME,
412
+ "tokenizer_vocab_budget": config.tokenizer_vocab_size,
413
+ "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
414
+ "tokenizer_vocab_size": payload["tokenizer_vocab_size"],
415
+ "vocab_size": payload["embedding_vocab_size"],
416
+ "documents_processed": payload["documents_processed"],
417
+ "source_counts": payload["source_counts"],
418
+ "examples_processed": payload["examples_processed"],
419
+ "associative_examples": payload["associative_examples"],
420
+ "answer_associative_examples": payload.get("answer_associative_examples", 0),
421
+ "general_associative_examples": payload.get("general_associative_examples", 0),
422
+ "answer_intent_examples": payload.get("answer_intent_examples", 0),
423
+ "answer_start_examples": payload.get("answer_start_examples", 0),
424
+ "answer_sequence_examples": payload.get("answer_sequence_examples", 0),
425
+ "prompt_answer_readout_examples": payload.get("prompt_answer_readout_examples", 0),
426
+ "prompt_answer_start_readout_examples": payload.get("prompt_answer_start_readout_examples", 0),
427
+ "preference_pairs": payload.get("preference_pairs", 0),
428
+ "preference_state_pairs": payload.get("preference_state_pairs", 0),
429
+ "stage_seconds": payload.get("stage_seconds", {}),
430
+ "readout_solver": payload.get("readout_solver"),
431
+ "reasoning_profile": config.default_reasoning_profile,
432
+ "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
433
+ "lowercase": config.lowercase,
434
+ "max_training_examples": config.max_training_examples,
435
+ "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
436
+ "max_transition_next_tokens": config.max_transition_next_tokens,
437
+ "embedding_dim": config.embedding_dim,
438
+ "state_dim": config.state_dim,
439
+ "timescales": list(config.timescales),
440
+ }
441
+ print(json.dumps(summary))
442
+ return 0
443
+
444
+
445
+ def command_predict(args: argparse.Namespace) -> int:
446
+ model = ReframrModel.load(args.model)
447
+ distribution = model.predict_next_distribution(
448
+ args.context,
449
+ reasoning_mode=args.reasoning_mode,
450
+ )
451
+ predictions = sorted(
452
+ distribution.items(),
453
+ key=lambda item: item[1],
454
+ reverse=True,
455
+ )[: args.top_k]
456
+ payload = {
457
+ "context": args.context,
458
+ "reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile,
459
+ "reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile),
460
+ "predictions": [
461
+ {"token": token, "probability": probability}
462
+ for token, probability in predictions
463
+ ],
464
+ }
465
+ print(json.dumps(payload))
466
+ return 0
467
+
468
+
469
+ def command_generate(args: argparse.Namespace) -> int:
470
+ model = ReframrModel.load(args.model)
471
+ context = compose_generation_context(args.context, system=args.system)
472
+ generated_text = model.generate_text(
473
+ context,
474
+ max_tokens=args.max_tokens,
475
+ reasoning_mode=args.reasoning_mode,
476
+ temperature=args.temperature,
477
+ top_k=args.decode_top_k,
478
+ top_p=args.decode_top_p,
479
+ repetition_penalty=args.repetition_penalty,
480
+ )
481
+ payload = {
482
+ "context": context,
483
+ "reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile,
484
+ "reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile),
485
+ "generated_token_count": len(generated_text.split()),
486
+ "generated_text": generated_text,
487
+ }
488
+ print(json.dumps(payload))
489
+ return 0
490
+
491
+
492
+ def compose_generation_context(prompt: str, *, system: str = "") -> str:
493
+ clean_prompt = prompt.strip()
494
+ clean_system = system.strip()
495
+ if not clean_system:
496
+ return clean_prompt
497
+ return f"System instruction: {clean_system}\nUser: {clean_prompt}"
498
+
499
+
500
+ def command_generate_batch(args: argparse.Namespace) -> int:
501
+ model = ReframrModel.load(args.model)
502
+ prompts = load_prompt_suite(args.prompts)
503
+ output_path = Path(args.output)
504
+ output_path.parent.mkdir(parents=True, exist_ok=True)
505
+ active_mode = args.reasoning_mode or model.config.default_reasoning_profile
506
+ rows: list[dict[str, object]] = []
507
+ with output_path.open("w", encoding="utf-8") as handle:
508
+ for index, record in enumerate(prompts):
509
+ prompt = str(record["prompt"])
510
+ context = compose_generation_context(
511
+ prompt,
512
+ system=str(record.get("system", "")),
513
+ )
514
+ max_tokens = int(record.get("max_tokens", args.max_tokens))
515
+ generated_text = model.generate_text(
516
+ context,
517
+ max_tokens=max_tokens,
518
+ reasoning_mode=args.reasoning_mode,
519
+ temperature=args.temperature,
520
+ top_k=args.decode_top_k,
521
+ top_p=args.decode_top_p,
522
+ repetition_penalty=args.repetition_penalty,
523
+ )
524
+ row = {
525
+ "index": index,
526
+ "prompt": prompt,
527
+ "context": context,
528
+ "system": record.get("system", ""),
529
+ "tags": record.get("tags", []),
530
+ "reasoning_mode": active_mode,
531
+ "reasoning_tokens": reasoning_prefix(active_mode),
532
+ "generated_token_count": len(generated_text.split()),
533
+ "generated_text": generated_text,
534
+ }
535
+ rows.append(row)
536
+ handle.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")) + "\n")
537
+ payload = {
538
+ "status": "generated",
539
+ "sample_count": len(rows),
540
+ "model_path": str(Path(args.model).resolve()),
541
+ "prompts_path": str(Path(args.prompts).resolve()),
542
+ "output_path": str(output_path.resolve()),
543
+ "model_loads": 1,
544
+ }
545
+ print(json.dumps(payload))
546
+ return 0
547
+
548
+
549
+ def command_serve(args: argparse.Namespace) -> int:
550
+ model = ReframrModel.load(args.model)
551
+ default_mode = args.reasoning_mode or model.config.default_reasoning_profile
552
+ for index, raw_line in enumerate(sys.stdin):
553
+ line = raw_line.strip()
554
+ if not line:
555
+ continue
556
+ try:
557
+ request = json.loads(line)
558
+ except json.JSONDecodeError as exc:
559
+ response = {
560
+ "index": index,
561
+ "error": "invalid_json",
562
+ "message": str(exc),
563
+ "model_loads": 1,
564
+ }
565
+ sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
566
+ sys.stdout.flush()
567
+ continue
568
+ if isinstance(request, str):
569
+ context = request
570
+ request_payload: dict[str, object] = {}
571
+ elif isinstance(request, dict):
572
+ request_payload = request
573
+ raw_context = str(request_payload.get("prompt", request_payload.get("context", "")))
574
+ context = compose_generation_context(
575
+ raw_context,
576
+ system=str(request_payload.get("system", "")),
577
+ )
578
+ else:
579
+ response = {
580
+ "index": index,
581
+ "error": "invalid_request",
582
+ "message": "request must be a JSON object or string",
583
+ "model_loads": 1,
584
+ }
585
+ sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
586
+ sys.stdout.flush()
587
+ continue
588
+ active_mode = str(request_payload.get("reasoning_mode", default_mode))
589
+ max_tokens = int(request_payload.get("max_tokens", args.max_tokens))
590
+ temperature = float(request_payload.get("temperature", args.temperature))
591
+ top_k = int(request_payload.get("decode_top_k", args.decode_top_k))
592
+ top_p = float(request_payload.get("decode_top_p", args.decode_top_p))
593
+ repetition_penalty = float(
594
+ request_payload.get("repetition_penalty", args.repetition_penalty)
595
+ )
596
+ generated_text = model.generate_text(
597
+ context,
598
+ max_tokens=max_tokens,
599
+ reasoning_mode=active_mode,
600
+ temperature=temperature,
601
+ top_k=top_k,
602
+ top_p=top_p,
603
+ repetition_penalty=repetition_penalty,
604
+ )
605
+ response = {
606
+ "index": index,
607
+ "context": context,
608
+ "reasoning_mode": active_mode,
609
+ "reasoning_tokens": reasoning_prefix(active_mode),
610
+ "generated_token_count": len(generated_text.split()),
611
+ "generated_text": generated_text,
612
+ "model_loads": 1,
613
+ }
614
+         sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
+         sys.stdout.flush()
+     return 0
+
+
+ def command_trace(args: argparse.Namespace) -> int:
+     model = ReframrModel.load(args.model)
+     payload = model.trace_generation(
+         args.context,
+         max_tokens=args.max_tokens,
+         reasoning_mode=args.reasoning_mode,
+         top_k=args.top_k,
+         temperature=args.temperature,
+         top_p=args.decode_top_p,
+         repetition_penalty=args.repetition_penalty,
+     )
+     print(json.dumps(payload))
+     return 0
+
+
+ def command_inspect(args: argparse.Namespace) -> int:
+     print(json.dumps(inspect_checkpoint(args.model)))
+     return 0
+
+
+ def command_craft_corpus(args: argparse.Namespace) -> int:
+     package = (
+         build_generalization_corpus()
+         if args.variant == "generalization"
+         else build_foundation_corpus()
+     )
+     paths = write_corpus_package(package, args.output_dir)
+     payload = {
+         "name": package.name,
+         "corpus_path": paths["corpus_path"],
+         "manifest_path": paths["manifest_path"],
+         "prompt_suite_path": paths["prompt_suite_path"],
+         "token_count_estimate": len(package.text.split()),
+         "memorization_samples": len(package.memorization_samples),
+         "generalization_samples": len(package.generalization_samples),
+         "generalization_prompt_count": len(package.open_ended_samples),
+         "variant": args.variant,
+         "section_counts": package.section_counts,
+     }
+     print(json.dumps(payload))
+     return 0
+
+
+ def command_craft_curriculum(args: argparse.Namespace) -> int:
+     payload = write_curriculum_package(
+         args.output_dir,
+         CurriculumConfig(
+             records_per_category=args.records_per_category,
+             seed=args.seed,
+             train_ratio=args.train_ratio,
+         ),
+         effective_token_target=args.effective_token_target or None,
+     )
+     print(json.dumps(payload))
+     return 0
+
+
+ def command_evaluate(args: argparse.Namespace) -> int:
+     model = ReframrModel.load(args.model)
+     manifest = load_manifest(args.manifest)
+     payload = evaluate_manifest(
+         model,
+         manifest,
+         reasoning_mode=args.reasoning_mode,
+         top_k=args.top_k,
+     )
+     print(json.dumps(payload))
+     return 0
+
+
+ def command_benchmark_open(args: argparse.Namespace) -> int:
+     model = ReframrModel.load(args.model)
+     prompts = load_prompt_suite(args.prompts)
+     payload = benchmark_open_prompts(
+         model,
+         prompts,
+         reasoning_mode=args.reasoning_mode,
+         max_tokens=args.max_tokens,
+         temperature=args.temperature,
+         top_k=args.decode_top_k,
+         top_p=args.decode_top_p,
+         repetition_penalty=args.repetition_penalty,
+     )
+     print(json.dumps(payload))
+     return 0
+
+
+ def command_import_hf(args: argparse.Namespace) -> int:
+     payload = import_hf_dataset(
+         dataset=args.dataset,
+         output_path=args.output,
+         config=args.config,
+         split=args.split,
+         text_field=args.text_field,
+         limit=args.limit,
+         streaming=not args.no_streaming,
+         preference_target=args.preference_target,
+         min_words=args.min_words,
+         max_words=args.max_words,
+         min_alpha_ratio=args.min_alpha_ratio,
+         allowed_languages=tuple(
+             segment.strip()
+             for segment in args.allowed_languages.split(",")
+             if segment.strip()
+         ),
+     )
+     print(json.dumps(payload))
+     return 0
+
+
+ def main(argv: list[str] | None = None) -> int:
+     configure_stdio()
+     parser = build_parser()
+     args = parser.parse_args(argv)
+     if args.command in {"compute", "train"}:
+         return command_compute(args)
+     if args.command == "recompute":
+         return command_recompute(args)
+     if args.command == "predict":
+         return command_predict(args)
+     if args.command == "generate":
+         return command_generate(args)
+     if args.command == "generate-batch":
+         return command_generate_batch(args)
+     if args.command == "serve":
+         return command_serve(args)
+     if args.command == "trace":
+         return command_trace(args)
+     if args.command == "inspect":
+         return command_inspect(args)
+     if args.command == "craft-corpus":
+         return command_craft_corpus(args)
+     if args.command == "craft-curriculum":
+         return command_craft_curriculum(args)
+     if args.command == "evaluate":
+         return command_evaluate(args)
+     if args.command == "benchmark-open":
+         return command_benchmark_open(args)
+     if args.command == "import-hf":
+         return command_import_hf(args)
+     parser.error(f"Unknown command: {args.command}")
+     return 2
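
For quick sanity checks, the dispatcher above can also be driven in-process instead of through a shell. A minimal sketch, assuming these commands live in `reframr/__main__.py` (as `python -m reframr` implies) and the repository root is on `sys.path`:

```python
# Hypothetical in-process invocation of the CLI dispatcher shown above.
# "inspect" routes to command_inspect, which prints checkpoint metadata as JSON.
from reframr.__main__ import main

exit_code = main(["inspect", "--model", "model.safetensors"])
assert exit_code == 0
```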
reframr/config.py ADDED
@@ -0,0 +1,68 @@
+ from dataclasses import dataclass
+
+
+ @dataclass(slots=True)
+ class ReframrConfig:
+     embedding_dim: int = 16
+     state_dim: int = 32
+     timescales: tuple[float, ...] = (1.0, 0.5, 0.25, 0.125)
+     window_size: int = 2
+     regularization: float = 1e-3
+     min_frequency: int = 1
+     max_vocab: int | None = 256
+     tokenizer_vocab_size: int = 256
+     tokenizer_min_pair_frequency: int = 2
+     max_training_examples: int | None = 60000
+     max_transition_contexts_per_order: int | None = 4096
+     max_transition_next_tokens: int = 4
+     lowercase: bool = False
+     default_reasoning_profile: str = "none"
+
+     def to_dict(self) -> dict[str, object]:
+         return {
+             "embedding_dim": self.embedding_dim,
+             "state_dim": self.state_dim,
+             "timescales": list(self.timescales),
+             "window_size": self.window_size,
+             "regularization": self.regularization,
+             "min_frequency": self.min_frequency,
+             "max_vocab": self.max_vocab,
+             "tokenizer_vocab_size": self.tokenizer_vocab_size,
+             "tokenizer_min_pair_frequency": self.tokenizer_min_pair_frequency,
+             "max_training_examples": self.max_training_examples,
+             "max_transition_contexts_per_order": self.max_transition_contexts_per_order,
+             "max_transition_next_tokens": self.max_transition_next_tokens,
+             "lowercase": self.lowercase,
+             "default_reasoning_profile": self.default_reasoning_profile,
+         }
+
+     @classmethod
+     def from_dict(cls, payload: dict[str, object]) -> "ReframrConfig":
+         return cls(
+             embedding_dim=int(payload["embedding_dim"]),
+             state_dim=int(payload["state_dim"]),
+             timescales=tuple(float(value) for value in payload["timescales"]),
+             window_size=int(payload["window_size"]),
+             regularization=float(payload["regularization"]),
+             min_frequency=int(payload["min_frequency"]),
+             max_vocab=(
+                 int(payload.get("max_vocab", 256))
+                 if payload.get("max_vocab", 256) is not None
+                 else None
+             ),
+             tokenizer_vocab_size=int(payload.get("tokenizer_vocab_size", 256)),
+             tokenizer_min_pair_frequency=int(payload.get("tokenizer_min_pair_frequency", 2)),
+             max_training_examples=(
+                 int(payload["max_training_examples"])
+                 if payload.get("max_training_examples") is not None
+                 else None
+             ),
+             max_transition_contexts_per_order=(
+                 int(payload["max_transition_contexts_per_order"])
+                 if payload.get("max_transition_contexts_per_order") is not None
+                 else None
+             ),
+             max_transition_next_tokens=int(payload.get("max_transition_next_tokens", 4)),
+             lowercase=bool(payload.get("lowercase", False)),
+             default_reasoning_profile=str(payload.get("default_reasoning_profile", "none")),
+         )
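
A minimal round-trip sketch for `ReframrConfig`: `to_dict()` has to survive JSON serialization inside a checkpoint, and `from_dict()` has to restore equivalent field values, including the `None`-able caps (the field values chosen here are illustrative):

```python
import json

from reframr.config import ReframrConfig

# Serialize to JSON the way a checkpoint would, then rebuild the dataclass.
config = ReframrConfig(state_dim=64, max_vocab=None)
restored = ReframrConfig.from_dict(json.loads(json.dumps(config.to_dict())))
assert restored.state_dim == 64 and restored.max_vocab is None
# to_dict() emits timescales as a list; from_dict() restores the tuple form.
assert restored.timescales == (1.0, 0.5, 0.25, 0.125)
```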
reframr/corpus.py ADDED
@@ -0,0 +1,123 @@
+ import re
+ from collections import Counter
+
+ from .linalg import Matrix, np, zeros
+
+ TOKEN_PATTERN = re.compile(r"[A-Za-z0-9']+")
+ FRAMETOKEN_WORD_PREFIX = "▁"
+
+
+ def tokenize(text: str) -> list[str]:
+     return TOKEN_PATTERN.findall(text.lower())
+
+
+ def build_vocabulary(
+     tokens: list[str],
+     min_frequency: int = 1,
+     max_vocab: int | None = None,
+ ) -> tuple[dict[str, int], list[str]]:
+     counts = Counter(tokens)
+     return build_vocabulary_from_counts(
+         counts,
+         min_frequency=min_frequency,
+         max_vocab=max_vocab,
+     )
+
+
+ def build_vocabulary_from_counts(
+     counts: dict[str, float],
+     min_frequency: int = 1,
+     max_vocab: int | None = None,
+ ) -> tuple[dict[str, int], list[str]]:
+     items = [
+         (token, count)
+         for token, count in sorted(counts.items(), key=lambda pair: (-pair[1], pair[0]))
+         if count >= min_frequency
+     ]
+     if max_vocab is not None:
+         if any(_looks_like_frametoken(token) for token, _ in items):
+             items = _prioritize_frametoken_output_items(items)[:max_vocab]
+         else:
+             items = items[:max_vocab]
+
+     id_to_token = [token for token, _ in items]
+     token_to_id = {token: index for index, token in enumerate(id_to_token)}
+     return token_to_id, id_to_token
+
+
+ def _looks_like_frametoken(token: str) -> bool:
+     return token.startswith(FRAMETOKEN_WORD_PREFIX) or (
+         token.startswith("<") and token.endswith(">")
+     )
+
+
+ def _is_special_token(token: str) -> bool:
+     return token.startswith("<") and token.endswith(">")
+
+
+ def _is_word_start_token(token: str) -> bool:
+     return token.startswith(FRAMETOKEN_WORD_PREFIX)
+
+
+ def _is_single_letter_word_start(token: str) -> bool:
+     if not token.startswith(FRAMETOKEN_WORD_PREFIX):
+         return False
+     rendered = token[len(FRAMETOKEN_WORD_PREFIX) :]
+     return len(rendered) == 1 and rendered.isalpha() and rendered not in {"A", "I"}
+
+
+ def _is_bare_fallback_token(token: str) -> bool:
+     return len(token) == 1 and not token.startswith(FRAMETOKEN_WORD_PREFIX)
+
+
+ def _prioritize_frametoken_output_items(items: list[tuple[str, float]]) -> list[tuple[str, float]]:
+     # FrameToken keeps fallback characters for encoding coverage, but the model's
+     # output/readout vocabulary should spend its capped slots on answerable tokens.
+     def priority(item: tuple[str, float]) -> tuple[int, float, str]:
+         token, count = item
+         if _is_special_token(token):
+             group = 0
+         elif _is_single_letter_word_start(token):
+             group = 3
+         elif _is_word_start_token(token):
+             group = 1
+         elif _is_bare_fallback_token(token):
+             group = 4
+         else:
+             group = 2
+         return (group, -count, token)
+
+     return sorted(items, key=priority)
+
+
+ def build_cooccurrence_matrix(
+     tokens: list[str],
+     token_to_id: dict[str, int],
+     window_size: int,
+ ) -> Matrix:
+     size = len(token_to_id)
+     token_ids = [token_to_id[token] for token in tokens if token in token_to_id]
+     if np is not None and size > 0 and token_ids:
+         matrix = np.zeros((size, size), dtype=np.float64)
+         token_array = np.asarray(token_ids, dtype=np.int64)
+         for offset in range(1, window_size + 1):
+             if len(token_array) <= offset:
+                 break
+             left = token_array[:-offset]
+             right = token_array[offset:]
+             weight = 1.0 / offset
+             np.add.at(matrix, (left, right), weight)
+             np.add.at(matrix, (right, left), weight)
+         return matrix.tolist()
+
+     matrix = zeros(size, size)
+     for index, token_id in enumerate(token_ids):
+         for offset in range(1, window_size + 1):
+             other_index = index + offset
+             if other_index >= len(token_ids):
+                 break
+             other_id = token_ids[other_index]
+             weight = 1.0 / offset
+             matrix[token_id][other_id] += weight
+             matrix[other_id][token_id] += weight
+     return matrix
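
A minimal sketch of the statistics path this module provides: tokenize, cap the vocabulary, then accumulate distance-weighted co-occurrence counts, where a pair at offset `d` contributes `1/d` to both symmetric cells (the sample text is illustrative):

```python
from reframr.corpus import build_cooccurrence_matrix, build_vocabulary, tokenize

tokens = tokenize("clean data shapes clean weights")
token_to_id, id_to_token = build_vocabulary(tokens, min_frequency=1, max_vocab=16)
matrix = build_cooccurrence_matrix(tokens, token_to_id, window_size=2)
# Adjacent pairs add 1.0 per occurrence, pairs two tokens apart add 0.5,
# and the matrix stays symmetric whether numpy or the pure-Python path runs.
print(matrix[token_to_id["clean"]][token_to_id["data"]])
```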
reframr/corpus_recipes.py ADDED
@@ -0,0 +1,1257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+
5
+
6
+ @dataclass(slots=True)
7
+ class EvalSample:
8
+ section: str
9
+ context: str
10
+ expected: str
11
+
12
+ def to_dict(self) -> dict[str, str]:
13
+ return {
14
+ "section": self.section,
15
+ "context": self.context,
16
+ "expected": self.expected,
17
+ }
18
+
19
+
20
+ @dataclass(slots=True)
21
+ class OpenEvalSample:
22
+ section: str
23
+ context: str
24
+ required_groups: list[list[str]]
25
+ banned_phrases: list[str]
26
+ min_words: int = 12
27
+ require_punctuation: bool = True
28
+ max_tokens: int = 56
29
+
30
+ def to_dict(self) -> dict[str, object]:
31
+ return {
32
+ "section": self.section,
33
+ "context": self.context,
34
+ "required_groups": self.required_groups,
35
+ "banned_phrases": self.banned_phrases,
36
+ "min_words": self.min_words,
37
+ "require_punctuation": self.require_punctuation,
38
+ "max_tokens": self.max_tokens,
39
+ }
40
+
41
+
42
+ @dataclass(slots=True)
43
+ class CorpusRecord:
44
+ section: str
45
+ context: str
46
+ answer: str
47
+ split: str = "train"
48
+
49
+ @property
50
+ def text(self) -> str:
51
+ return _line(self.context, self.answer)
52
+
53
+ def to_dict(self) -> dict[str, str]:
54
+ return {
55
+ "section": self.section,
56
+ "split": self.split,
57
+ "context": self.context,
58
+ "answer": self.answer,
59
+ "text": self.text,
60
+ }
61
+
62
+
63
+ @dataclass(slots=True)
64
+ class CorpusPackage:
65
+ name: str
66
+ records: list[CorpusRecord]
67
+ section_counts: dict[str, int]
68
+ memorization_samples: list[EvalSample]
69
+ generalization_samples: list[EvalSample]
70
+ open_ended_samples: list[OpenEvalSample]
71
+
72
+ @property
73
+ def slug(self) -> str:
74
+ return self.name.lower().replace(" ", "-")
75
+
76
+ @property
77
+ def text(self) -> str:
78
+ if not self.records:
79
+ return ""
80
+ return "\n".join(record.text for record in self.records) + "\n"
81
+
82
+ def manifest(self, *, corpus_filename: str) -> dict[str, object]:
83
+ return {
84
+ "name": self.name,
85
+ "corpus_filename": corpus_filename,
86
+ "section_counts": self.section_counts,
87
+ "splits": {
88
+ "memorization": [sample.to_dict() for sample in self.memorization_samples],
89
+ "generalization": [sample.to_dict() for sample in self.generalization_samples],
90
+ "open_ended": [sample.to_dict() for sample in self.open_ended_samples],
91
+ },
92
+ }
93
+
94
+ def corpus_records(self) -> list[dict[str, str]]:
95
+ return [record.to_dict() for record in self.records]
96
+
97
+ def prompt_suite(self) -> list[dict[str, object]]:
98
+ return [
99
+ {
100
+ "prompt": sample.context,
101
+ "tags": [sample.section, "generalization"],
102
+ "min_words": sample.min_words,
103
+ "require_punctuation": sample.require_punctuation,
104
+ "max_tokens": sample.max_tokens,
105
+ }
106
+ for sample in self.open_ended_samples
107
+ ]
108
+
109
+
110
+ def _line(context: str, expected: str) -> str:
111
+ return f"{context} {expected}"
112
+
113
+
114
+ def _balanced_samples(samples: list[EvalSample], total: int) -> list[EvalSample]:
115
+ buckets: dict[str, list[EvalSample]] = {}
116
+ for sample in samples:
117
+ buckets.setdefault(sample.section, []).append(sample)
118
+
119
+ selected: list[EvalSample] = []
120
+ ordered_sections = sorted(buckets)
121
+ while len(selected) < total:
122
+ progressed = False
123
+ for section in ordered_sections:
124
+ bucket = buckets[section]
125
+ if not bucket:
126
+ continue
127
+ selected.append(bucket.pop(0))
128
+ progressed = True
129
+ if len(selected) >= total:
130
+ break
131
+ if not progressed:
132
+ break
133
+ return selected
134
+
135
+
136
+ def _recount_sections(records: list[CorpusRecord]) -> dict[str, int]:
137
+ counts: dict[str, int] = {}
138
+ for record in records:
139
+ counts[record.section] = counts.get(record.section, 0) + 1
140
+ return counts
141
+
142
+
143
+ def build_foundation_corpus() -> CorpusPackage:
144
+ records: list[CorpusRecord] = []
145
+ lines: list[str] = []
146
+ section_counts: dict[str, int] = {}
147
+ memorization: list[EvalSample] = []
148
+ generalization: list[EvalSample] = []
149
+ open_ended: list[OpenEvalSample] = []
150
+
151
+ def add_train(section: str, context: str, expected: str, *, sample: bool = False) -> None:
152
+ records.append(
153
+ CorpusRecord(
154
+ section=section,
155
+ context=context,
156
+ answer=expected,
157
+ split="train",
158
+ )
159
+ )
160
+ lines.append(_line(context, expected))
161
+ section_counts[section] = section_counts.get(section, 0) + 1
162
+ if sample:
163
+ memorization.append(EvalSample(section=section, context=context, expected=expected))
164
+
165
+ def add_holdout(section: str, context: str, expected: str) -> None:
166
+ generalization.append(EvalSample(section=section, context=context, expected=expected))
167
+
168
+ def add_open(
169
+ section: str,
170
+ context: str,
171
+ required_groups: list[list[str]],
172
+ *,
173
+ banned_phrases: list[str],
174
+ min_words: int = 12,
175
+ require_punctuation: bool = True,
176
+ max_tokens: int = 56,
177
+ ) -> None:
178
+ open_ended.append(
179
+ OpenEvalSample(
180
+ section=section,
181
+ context=context,
182
+ required_groups=required_groups,
183
+ banned_phrases=banned_phrases,
184
+ min_words=min_words,
185
+ require_punctuation=require_punctuation,
186
+ max_tokens=max_tokens,
187
+ )
188
+ )
189
+
190
+ holdout_addition = {
191
+ (2, 19),
192
+ (3, 17),
193
+ (4, 16),
194
+ (5, 15),
195
+ (6, 14),
196
+ (7, 13),
197
+ (8, 12),
198
+ (9, 11),
199
+ (10, 10),
200
+ (11, 9),
201
+ (12, 8),
202
+ (13, 7),
203
+ (14, 6),
204
+ (15, 5),
205
+ (16, 4),
206
+ (17, 3),
207
+ (18, 2),
208
+ (19, 21),
209
+ (20, 22),
210
+ (21, 19),
211
+ (22, 20),
212
+ (23, 18),
213
+ (24, 17),
214
+ (25, 16),
215
+ }
216
+ holdout_successor = {23, 29, 31, 37, 41, 43, 47, 53, 61, 67, 71, 73, 79}
217
+ holdout_predecessor = {24, 30, 32, 38, 42, 44, 48, 54, 62, 68, 72, 74, 80}
218
+ holdout_explain_addition = {
219
+ (7, 9),
220
+ (8, 11),
221
+ (10, 13),
222
+ (12, 15),
223
+ (14, 9),
224
+ (15, 14),
225
+ (16, 12),
226
+ (18, 7),
227
+ }
228
+ holdout_explain_subtraction = {
229
+ (19, 7),
230
+ (22, 9),
231
+ (25, 11),
232
+ (28, 13),
233
+ (31, 15),
234
+ (34, 12),
235
+ }
236
+ holdout_explain_multiplication = {
237
+ (6, 7),
238
+ (7, 8),
239
+ (8, 9),
240
+ (9, 6),
241
+ (11, 5),
242
+ (12, 6),
243
+ }
244
+
245
+ for left in range(1, 41):
246
+ for right in range(1, 41):
247
+ context = f"<reason> add {left} plus {right} equals <answer>"
248
+ expected = str(left + right)
249
+ if (left, right) in holdout_addition:
250
+ add_holdout("arithmetic", context, expected)
251
+ else:
252
+ add_train("arithmetic", context, expected, sample=(left + right) % 5 == 0)
253
+
254
+ holdout_subtraction = {
255
+ (9, 4),
256
+ (12, 5),
257
+ (15, 6),
258
+ (18, 7),
259
+ (21, 8),
260
+ (24, 9),
261
+ (27, 10),
262
+ (30, 11),
263
+ }
264
+ for left in range(3, 56):
265
+ for right in range(1, min(left, 21)):
266
+ context = f"<reason> subtract {right} from {left} equals <answer>"
267
+ expected = str(left - right)
268
+ if (left, right) in holdout_subtraction:
269
+ add_holdout("arithmetic", context, expected)
270
+ else:
271
+ add_train("arithmetic", context, expected, sample=(left - right) % 6 == 0)
272
+
273
+ holdout_multiplication = {
274
+ (7, 8),
275
+ (8, 9),
276
+ (9, 7),
277
+ (11, 6),
278
+ (12, 7),
279
+ (6, 11),
280
+ }
281
+ for left in range(2, 21):
282
+ for right in range(2, 21):
283
+ context = f"<reason> multiply {left} times {right} equals <answer>"
284
+ expected = str(left * right)
285
+ if (left, right) in holdout_multiplication:
286
+ add_holdout("arithmetic", context, expected)
287
+ else:
288
+ add_train("arithmetic", context, expected, sample=(left * right) % 9 == 0)
289
+
290
+ holdout_parity = {33, 37, 41, 45, 52, 58}
291
+ for value in range(1, 141):
292
+ context = f"<reason> parity of {value} is <answer>"
293
+ expected = "even" if value % 2 == 0 else "odd"
294
+ if value in holdout_parity:
295
+ add_holdout("arithmetic", context, expected)
296
+ else:
297
+ add_train("arithmetic", context, expected, sample=value % 10 == 0)
298
+
299
+ for value in range(1, 181):
300
+ successor_context = f"<reason> successor of {value} is <answer>"
301
+ successor_expected = str(value + 1)
302
+ if value in holdout_successor:
303
+ add_holdout("sequence", successor_context, successor_expected)
304
+ else:
305
+ add_train("sequence", successor_context, successor_expected, sample=value % 7 == 0)
306
+
307
+ predecessor_context = f"<reason> predecessor of {value} is <answer>"
308
+ predecessor_expected = str(value - 1)
309
+ if value in holdout_predecessor:
310
+ add_holdout("sequence", predecessor_context, predecessor_expected)
311
+ else:
312
+ add_train("sequence", predecessor_context, predecessor_expected, sample=value % 8 == 0)
313
+
314
+ for left in range(2, 25):
315
+ for right in range(2, 25):
316
+ context = f"<reason> explain the sum of {left} and {right} <answer>"
317
+ expected = (
318
+ f"Use {left} and {right} as the two addends; their total is "
319
+ f"{left + right}, so the answer is {left + right}."
320
+ )
321
+ if (left, right) in holdout_explain_addition:
322
+ add_holdout("reasoning", context, expected)
323
+ else:
324
+ add_train("reasoning", context, expected, sample=(left + right) % 7 == 0)
325
+
326
+ for left in range(8, 45):
327
+ for right in range(2, min(left, 17)):
328
+ context = f"<reason> explain the difference between {left} and {right} <answer>"
329
+ expected = (
330
+ f"Start with {left} and remove {right}; the remaining value is "
331
+ f"{left - right}, so the answer is {left - right}."
332
+ )
333
+ if (left, right) in holdout_explain_subtraction:
334
+ add_holdout("reasoning", context, expected)
335
+ else:
336
+ add_train("reasoning", context, expected, sample=(left - right) % 8 == 0)
337
+
338
+ for left in range(2, 17):
339
+ for right in range(2, 13):
340
+ context = f"<reason> explain the product of {left} and {right} <answer>"
341
+ expected = (
342
+ f"Treat {left} and {right} as factors; combining the equal groups gives "
343
+ f"{left * right}, so the answer is {left * right}."
344
+ )
345
+ if (left, right) in holdout_explain_multiplication:
346
+ add_holdout("reasoning", context, expected)
347
+ else:
348
+ add_train("reasoning", context, expected, sample=(left * right) % 9 == 0)
349
+
350
+ capitals = [
351
+ ("japan", "tokyo"),
352
+ ("brazil", "brasilia"),
353
+ ("canada", "ottawa"),
354
+ ("france", "paris"),
355
+ ("germany", "berlin"),
356
+ ("india", "new delhi"),
357
+ ("australia", "canberra"),
358
+ ("egypt", "cairo"),
359
+ ("kenya", "nairobi"),
360
+ ("mexico", "mexico city"),
361
+ ("norway", "oslo"),
362
+ ("chile", "santiago"),
363
+ ("argentina", "buenos aires"),
364
+ ("thailand", "bangkok"),
365
+ ("indonesia", "jakarta"),
366
+ ("morocco", "rabat"),
367
+ ("sweden", "stockholm"),
368
+ ("finland", "helsinki"),
369
+ ("peru", "lima"),
370
+ ("colombia", "bogota"),
371
+ ]
372
+ for country, capital in capitals:
373
+ add_train(
374
+ "memory",
375
+ f"<memory> capital of {country} is <answer>",
376
+ capital,
377
+ sample=country in {"japan", "brazil", "canada", "france", "india", "kenya"},
378
+ )
379
+
380
+ analogies_train = [
381
+ ("bird", "nest", "bee", "hive"),
382
+ ("fish", "water", "camel", "desert"),
383
+ ("painter", "brush", "writer", "pen"),
384
+ ("doctor", "hospital", "teacher", "school"),
385
+ ("farmer", "field", "captain", "ship"),
386
+ ("judge", "court", "chef", "kitchen"),
387
+ ("astronomer", "telescope", "musician", "violin"),
388
+ ("pilot", "cockpit", "driver", "garage"),
389
+ ("programmer", "code", "architect", "blueprint"),
390
+ ("tailor", "needle", "carpenter", "hammer"),
391
+ ("sailor", "compass", "hiker", "map"),
392
+ ("chemist", "laboratory", "baker", "oven"),
393
+ ("photographer", "camera", "sculptor", "chisel"),
394
+ ("gardener", "soil", "potter", "clay"),
395
+ ("librarian", "catalog", "analyst", "report"),
396
+ ("surfer", "wave", "skater", "ramp"),
397
+ ("director", "script", "conductor", "score"),
398
+ ("nurse", "clinic", "lawyer", "firm"),
399
+ ]
400
+ analogies_holdout = [
401
+ ("curator", "museum", "editor", "journal"),
402
+ ("beekeeper", "apiary", "farmer", "barn"),
403
+ ("surgeon", "scalpel", "artist", "canvas"),
404
+ ("sailor", "harbor", "miner", "tunnel"),
405
+ ("scientist", "laboratory", "gardener", "greenhouse"),
406
+ ("translator", "dictionary", "navigator", "chart"),
407
+ ("coach", "sideline", "chef", "kitchen"),
408
+ ("astronaut", "capsule", "diver", "reef"),
409
+ ]
410
+ for left_subject, left_object, right_subject, right_object in analogies_train:
411
+ add_train(
412
+ "analogy",
413
+ f"<reason> {left_subject} relates to {left_object} as {right_subject} relates to <answer>",
414
+ right_object,
415
+ sample=left_subject in {"bird", "doctor", "judge", "pilot", "chemist", "nurse"},
416
+ )
417
+ for left_subject, left_object, right_subject, right_object in analogies_holdout:
418
+ add_holdout(
419
+ "analogy",
420
+ f"<reason> {left_subject} relates to {left_object} as {right_subject} relates to <answer>",
421
+ right_object,
422
+ )
423
+
424
+ classifications = [
425
+ ("sparrow", "bird"),
426
+ ("salmon", "fish"),
427
+ ("oak", "tree"),
428
+ ("rose", "flower"),
429
+ ("copper", "metal"),
430
+ ("mercury", "planet"),
431
+ ("triangle", "shape"),
432
+ ("python", "language"),
433
+ ("whale", "mammal"),
434
+ ("eagle", "bird"),
435
+ ("lion", "mammal"),
436
+ ("emerald", "gem"),
437
+ ("neptune", "planet"),
438
+ ("ruby", "gem"),
439
+ ("cedar", "tree"),
440
+ ("falcon", "bird"),
441
+ ("orca", "mammal"),
442
+ ("sapphire", "gem"),
443
+ ("elm", "tree"),
444
+ ("swift", "language"),
445
+ ]
446
+ for item, group in classifications:
447
+ add_train(
448
+ "classification",
449
+ f"<memory> category of {item} is <answer>",
450
+ group,
451
+ sample=item in {"sparrow", "salmon", "oak", "rose", "neptune", "ruby"},
452
+ )
453
+
454
+ reasoning_phrases = [
455
+ ("think clearly before final response", "response"),
456
+ ("verify each claim before answer", "answer"),
457
+ ("retrieve memory before conclusion", "conclusion"),
458
+ ("focus on evidence before claim", "claim"),
459
+ ("plan then reason then answer", "answer"),
460
+ ("reflect before committing output", "output"),
461
+ ("use memory when context grows", "grows"),
462
+ ("check arithmetic before assertion", "assertion"),
463
+ ("organize steps before conclusion", "conclusion"),
464
+ ("inspect state before next answer", "answer"),
465
+ ("paraphrase before claiming novelty", "novelty"),
466
+ ("stabilize state before long generation", "generation"),
467
+ ("reuse evidence before rewriting summary", "summary"),
468
+ ("compare patterns before final synthesis", "synthesis"),
469
+ ]
470
+ for phrase, final_word in reasoning_phrases:
471
+ add_train(
472
+ "protocol",
473
+ f"<reason> {phrase} <answer>",
474
+ final_word,
475
+ sample=final_word in {"response", "answer", "claim", "generation", "summary"},
476
+ )
477
+
478
+ paraphrase_train = [
479
+ (
480
+ "clear goals and steady practice",
481
+ "clear goals joined with steady practice create durable skill",
482
+ ),
483
+ (
484
+ "careful review prevents shallow errors",
485
+ "careful review stops shallow errors before they spread",
486
+ ),
487
+ (
488
+ "patient systems improve over time",
489
+ "patient systems improve through steady revision over time",
490
+ ),
491
+ (
492
+ "bright ideas need exact execution",
493
+ "bright ideas need exact execution to become reliable work",
494
+ ),
495
+ (
496
+ "quiet focus strengthens difficult reasoning",
497
+ "quiet focus strengthens difficult reasoning during long analysis",
498
+ ),
499
+ (
500
+ "small evidence guides better judgment",
501
+ "small evidence guides better judgment when choices feel similar",
502
+ ),
503
+ (
504
+ "stable memory helps long writing",
505
+ "stable memory helps long writing keep its shape and intent",
506
+ ),
507
+ (
508
+ "measured iteration protects quality",
509
+ "measured iteration protects quality while keeping momentum alive",
510
+ ),
511
+ (
512
+ "careful structure scales ambitious work",
513
+ "careful structure scales ambitious work without needless disorder",
514
+ ),
515
+ (
516
+ "strong prompts need grounded answers",
517
+ "strong prompts need grounded answers supported by real evidence",
518
+ ),
519
+ (
520
+ "shared context reduces wasted motion",
521
+ "shared context reduces wasted motion across a complex build",
522
+ ),
523
+ (
524
+ "consistent language sharpens collaboration",
525
+ "consistent language sharpens collaboration and shortens confusion",
526
+ ),
527
+ ]
528
+ paraphrase_holdout = [
529
+ (
530
+ "steady systems reward patient builders",
531
+ "steady systems reward patient builders with dependable progress",
532
+ ),
533
+ (
534
+ "clear revision protects difficult projects",
535
+ "clear revision protects difficult projects from hidden drift",
536
+ ),
537
+ (
538
+ "focused memory improves long responses",
539
+ "focused memory improves long responses during deep reasoning",
540
+ ),
541
+ (
542
+ "clean evidence supports honest claims",
543
+ "clean evidence supports honest claims during ambitious work",
544
+ ),
545
+ (
546
+ "durable plans reduce fragile execution",
547
+ "durable plans reduce fragile execution before launch pressure rises",
548
+ ),
549
+ (
550
+ "careful synthesis strengthens global understanding",
551
+ "careful synthesis strengthens global understanding without empty hype",
552
+ ),
553
+ ]
554
+ for source, target in paraphrase_train:
555
+ add_train(
556
+ "paraphrase",
557
+ f"<reason> paraphrase {source} into stronger prose <answer>",
558
+ target,
559
+ sample=source in {
560
+ "clear goals and steady practice",
561
+ "patient systems improve over time",
562
+ "stable memory helps long writing",
563
+ "shared context reduces wasted motion",
564
+ },
565
+ )
566
+ for source, target in paraphrase_holdout:
567
+ add_holdout(
568
+ "paraphrase",
569
+ f"<reason> paraphrase {source} into stronger prose <answer>",
570
+ target,
571
+ )
572
+
573
+ comparison_train = [
574
+ ("pebble", "stone", "boulder", "largest", "boulder"),
575
+ ("stream", "river", "ocean", "largest", "ocean"),
576
+ ("candle", "lantern", "sun", "brightest", "sun"),
577
+ ("village", "city", "continent", "largest", "continent"),
578
+ ("breeze", "wind", "storm", "strongest", "storm"),
579
+ ("cup", "bucket", "reservoir", "largest", "reservoir"),
580
+ ("violin", "orchestra", "stadium chorus", "loudest", "stadium chorus"),
581
+ ("ember", "flame", "wildfire", "hottest", "wildfire"),
582
+ ("minute", "hour", "day", "longest", "day"),
583
+ ("thread", "rope", "bridge cable", "thickest", "bridge cable"),
584
+ ("hill", "mountain", "range", "largest", "range"),
585
+ ("drizzle", "rain", "monsoon", "strongest", "monsoon"),
586
+ ("spark", "torch", "beacon", "brightest", "beacon"),
587
+ ("brook", "canal", "delta", "widest", "delta"),
588
+ ("hut", "house", "tower", "tallest", "tower"),
589
+ ("cart", "truck", "freighter", "largest", "freighter"),
590
+ ("path", "road", "highway", "widest", "highway"),
591
+ ("note", "melody", "symphony", "longest", "symphony"),
592
+ ]
593
+ comparison_holdout = [
594
+ ("seed", "sapling", "forest", "largest", "forest"),
595
+ ("glimmer", "lamp", "lighthouse", "brightest", "lighthouse"),
596
+ ("whisper", "speech", "thunder", "loudest", "thunder"),
597
+ ("creek", "river", "sea", "largest", "sea"),
598
+ ("trail", "road", "expressway", "widest", "expressway"),
599
+ ("hill", "cliff", "summit", "highest", "summit"),
600
+ ("ember", "bonfire", "volcano", "hottest", "volcano"),
601
+ ("minute", "season", "century", "longest", "century"),
602
+ ]
603
+ for first, second, third, comparator, expected in comparison_train:
604
+ add_train(
605
+ "comparison",
606
+ f"<reason> {comparator} among {first} {second} {third} is <answer>",
607
+ expected,
608
+ sample=expected in {"boulder", "ocean", "storm", "day", "range", "highway"},
609
+ )
610
+ for first, second, third, comparator, expected in comparison_holdout:
611
+ add_holdout(
612
+ "comparison",
613
+ f"<reason> {comparator} among {first} {second} {third} is <answer>",
614
+ expected,
615
+ )
616
+
617
+ causal_train = [
618
+ ("iron left in rain", "rust"),
619
+ ("clouds cooling into droplets", "rain"),
620
+ ("plants receiving sunlight", "growth"),
621
+ ("water reaching freezing temperature", "ice"),
622
+ ("friction between dry sticks", "heat"),
623
+ ("strong wind over warm water", "waves"),
624
+ ("seed placed in moist soil", "sprout"),
625
+ ("glass exposed to sudden force", "crack"),
626
+ ("constant pressure on stone", "erosion"),
627
+ ("fuel meeting flame", "combustion"),
628
+ ("repeated practice with feedback", "skill"),
629
+ ("unchecked heat in metal", "expansion"),
630
+ ("low temperature overnight", "frost"),
631
+ ("sustained current through filament", "glow"),
632
+ ("gravity pulling rain downhill", "flow"),
633
+ ("sleep loss across many nights", "fatigue"),
634
+ ("overloaded bridge cable", "strain"),
635
+ ("salt water meeting steel", "corrosion"),
636
+ ]
637
+ causal_holdout = [
638
+ ("dust gathering in still air", "settling"),
639
+ ("long drought across dry fields", "cracking"),
640
+ ("steady pressure beneath ice", "creep"),
641
+ ("clean lens focusing sunlight", "heat"),
642
+ ("lack of oxygen in closed flame", "extinguish"),
643
+ ("waves striking rock for years", "wear"),
644
+ ]
645
+ for cause, effect in causal_train:
646
+ add_train(
647
+ "causal",
648
+ f"<reason> effect of {cause} is <answer>",
649
+ effect,
650
+ sample=effect in {"rust", "rain", "growth", "ice", "skill", "fatigue"},
651
+ )
652
+ for cause, effect in causal_holdout:
653
+ add_holdout(
654
+ "causal",
655
+ f"<reason> effect of {cause} is <answer>",
656
+ effect,
657
+ )
658
+
659
+ definition_train = [
660
+ ("orbit", "path traced by one body around another"),
661
+ ("bridge", "structure that carries passage over an obstacle"),
662
+ ("catalyst", "substance that speeds a reaction without being consumed"),
663
+ ("harbor", "protected water area where ships can anchor safely"),
664
+ ("algorithm", "finite procedure for transforming input into output"),
665
+ ("archive", "ordered collection preserved for future reference"),
666
+ ("equilibrium", "state where opposing influences remain balanced"),
667
+ ("lens", "curved material that focuses or spreads light"),
668
+ ("reservoir", "stored supply of water or another resource"),
669
+ ("signal", "pattern that carries information across distance"),
670
+ ("compiler", "program that translates source code into another form"),
671
+ ("calendar", "system for organizing days into meaningful cycles"),
672
+ ("estuary", "place where river water meets the sea"),
673
+ ("voltage", "difference in electric potential between two points"),
674
+ ("synapse", "junction where one neuron communicates with another"),
675
+ ("telescope", "instrument that gathers distant light for observation"),
676
+ ]
677
+ definition_holdout = [
678
+ ("glacier", "mass of ice that moves slowly across land"),
679
+ ("protocol", "agreed procedure that coordinates reliable exchange"),
680
+ ("reef", "ridge of rock or coral rising near the water surface"),
681
+ ("memory", "stored information available for later retrieval"),
682
+ ("frequency", "how often a repeating event occurs in set time"),
683
+ ("compass", "instrument that indicates direction relative to north"),
684
+ ]
685
+ for term, definition in definition_train:
686
+ add_train(
687
+ "definition",
688
+ f"<memory> define {term} as <answer>",
689
+ definition,
690
+ sample=term in {"orbit", "algorithm", "compiler", "harbor", "signal"},
691
+ )
692
+ for term, definition in definition_holdout:
693
+ add_holdout(
694
+ "definition",
695
+ f"<memory> define {term} as <answer>",
696
+ definition,
697
+ )
698
+
699
+ identity_train = [
700
+ (
701
+ "describe REFRAMR briefly",
702
+ "REFRAMR is an analytical recurrent language system built by OkeyMeta Ltd to compute structure from corpus evidence instead of gradient loops.",
703
+ ),
704
+ (
705
+ "describe REFRAMR in your own words",
706
+ "REFRAMR is OkeyMeta Ltd language intelligence shaped through analytical memory recurrent state and computed structure rather than opaque training ritual.",
707
+ ),
708
+ (
709
+ "describe REFRAMR in your own words with punctuation",
710
+ "REFRAMR is recurrent, analytical, and evidence-driven; OkeyMeta Ltd shapes it to compute structure from corpus behavior instead of blind gradient churn.",
711
+ ),
712
+ (
713
+ "describe REFRAMR in your own words, with punctuation",
714
+ "REFRAMR is a recurrent analytical language system; OkeyMeta Ltd builds it to preserve structure, carry long context, and keep reasoning signals inspectable.",
715
+ ),
716
+ (
717
+ "what is REFRAMR",
718
+ "REFRAMR is an OkeyMeta analytical language system built around computed memory state and closed form readout.",
719
+ ),
720
+ (
721
+ "what makes REFRAMR different",
722
+ "REFRAMR differs by combining analytical memory corpus statistics and transparent reasoning traces without standard backprop training",
723
+ ),
724
+ (
725
+ "describe FrameToken briefly",
726
+ "FrameToken is REFRAMR native tokenizer from OkeyMeta Ltd that preserves reasoning controls while staying fast on ordinary hardware.",
727
+ ),
728
+ (
729
+ "what is REFRAMR mission",
730
+ "REFRAMR aims to build strong language intelligence through computed structure recurrent memory and interpretable reasoning",
731
+ ),
732
+ (
733
+ "how does REFRAMR reason",
734
+ "REFRAMR reasons through recurrent state analytical retrieval transition priors and explicit control tokens",
735
+ ),
736
+ (
737
+ "what is REFRAMR memory",
738
+ "REFRAMR memory is a multi timescale analytical state that compresses history without quadratic attention.",
739
+ ),
740
+ (
741
+ "explain REFRAMR memory for long context",
742
+ "REFRAMR memory keeps long context by folding prior evidence into a persistent analytical state so later tokens can still respond to earlier structure.",
743
+ ),
744
+ (
745
+ "explain REFRAMR memory for long context in your own words",
746
+ "REFRAMR keeps long context through a persistent analytical memory state, so earlier structure can still shape later output without a quadratic attention map.",
747
+ ),
748
+ (
749
+ "describe REFRAMR long context memory",
750
+ "REFRAMR long context memory is a persistent recurrent state that carries history forward without storing every token in a quadratic map.",
751
+ ),
752
+ (
753
+ "what is REFRAMR readout",
754
+ "REFRAMR readout is a closed form mapping from analytical state to token probabilities.",
755
+ ),
756
+ (
757
+ "what does REFRAMR optimize for",
758
+ "REFRAMR optimizes for analytical transparency long context behavior and hardware accessible computation",
759
+ ),
760
+ (
761
+ "what is REFRAMR tokenizer",
762
+ "REFRAMR tokenizer is FrameToken a native OkeyMeta vocabulary system shaped for analytical recurrent generation",
763
+ ),
764
+ (
765
+ "who are you REFRAMR",
766
+ "I am REFRAMR an OkeyMeta analytical language system shaped by corpus structure and transparent reasoning",
767
+ ),
768
+ (
769
+ "what is REFRAMR voice",
770
+ "REFRAMR voice is deliberate evidence driven and structurally aware rather than shallow imitation",
771
+ ),
772
+ (
773
+ "who builds REFRAMR",
774
+ "REFRAMR is built by OkeyMeta Ltd as a recurrent analytical language system for long context reasoning.",
775
+ ),
776
+ (
777
+ "summarize OkeyMeta role in REFRAMR",
778
+ "OkeyMeta Ltd builds REFRAMR as a transparent analytical language system grounded in corpus structure and recurrent memory",
779
+ ),
780
+ (
781
+ "what is OkeyMeta mission for REFRAMR",
782
+ "OkeyMeta Ltd is building REFRAMR to turn analytical structure into practical language intelligence on ordinary hardware",
783
+ ),
784
+ (
785
+ "describe REFRAMR with punctuation",
786
+ "REFRAMR is analytical, recurrent, and deliberate; OkeyMeta Ltd builds it to compute structure from evidence, not gradient ritual.",
787
+ ),
788
+ (
789
+ "summarize REFRAMR with punctuation",
790
+ "REFRAMR is a recurrent analytical language system; OkeyMeta Ltd builds it to keep structure visible, context persistent, and compute practical.",
791
+ ),
792
+ (
793
+ "summarize FrameToken with punctuation",
794
+ "FrameToken preserves boundaries, protects control tokens, and stays portable; it gives REFRAMR a clean native interface.",
795
+ ),
796
+ ]
797
+ identity_holdout = [
798
+ (
799
+ "explain REFRAMR in one sentence",
800
+ "REFRAMR is an OkeyMeta analytical language system that computes structure from corpus statistics and explicit memory dynamics",
801
+ ),
802
+ (
803
+ "summarize REFRAMR identity",
804
+ "REFRAMR is an OkeyMeta analytical recurrent model built to reason with transparent state rather than opaque gradient rituals",
805
+ ),
806
+ (
807
+ "what kind of model is REFRAMR",
808
+ "REFRAMR is an OkeyMeta post transformer recurrent analytical language model focused on computed structure and long stateful reasoning",
809
+ ),
810
+ (
811
+ "describe REFRAMR purpose",
812
+ "REFRAMR exists to turn mathematical structure and recurrent memory into practical language intelligence",
813
+ ),
814
+ (
815
+ "who owns REFRAMR",
816
+ "REFRAMR is built and owned by OkeyMeta Ltd as a long context analytical language effort",
817
+ ),
818
+ (
819
+ "describe FrameToken role",
820
+ "FrameToken is REFRAMR native tokenizer designed by OkeyMeta Ltd for analytical recurrent generation",
821
+ ),
822
+ (
823
+ "explain REFRAMR with punctuation",
824
+ "REFRAMR is recurrent, analytical, and long-context oriented; OkeyMeta Ltd built it to compute structure with visible reasoning.",
825
+ ),
826
+ ]
827
+ for prompt, answer in identity_train:
828
+ add_train(
829
+ "identity",
830
+ f"<reason> {prompt} <answer>",
831
+ answer,
832
+ sample=prompt in {
833
+ "describe REFRAMR briefly",
834
+ "what is REFRAMR",
835
+ "what makes REFRAMR different",
836
+ "describe FrameToken briefly",
837
+ "describe REFRAMR with punctuation",
838
+ },
839
+ )
840
+ for prompt, answer in identity_holdout:
841
+ add_holdout(
842
+ "identity",
843
+ f"<reason> {prompt} <answer>",
844
+ answer,
845
+ )
846
+
847
+ exposition_train = [
848
+ (
849
+ "explain why long context matters",
850
+ "Long context matters because ideas unfold across distance: setup, consequence, and revision rarely live in one sentence. A strong recurrent system must carry structure forward, not just local echoes.",
851
+ ),
852
+ (
853
+ "explain why punctuation matters in language models",
854
+ "Punctuation carries structure, pace, and intent; commas slow rhythm, periods close claims, and colons prepare explanation. A model that ignores marks will often flatten meaning.",
855
+ ),
856
+ (
857
+ "explain how punctuation helps long reasoning",
858
+ "Punctuation helps long reasoning because sequence alone is not enough: commas stage detail, semicolons balance linked claims, and periods let one conclusion land before the next begins.",
859
+ ),
860
+ (
861
+ "explain why punctuation supports long context",
862
+ "Punctuation supports long context by keeping long passages segmented and recoverable. When clauses stay marked, memory can preserve relation, pause, and closure more reliably.",
863
+ ),
864
+ (
865
+ "explain why punctuation helps long reasoning",
866
+ "Punctuation helps long reasoning by separating steps, slowing transitions, and protecting closure. Commas meter detail, colons open explanation, and periods keep one claim from smearing into the next.",
867
+ ),
868
+ (
869
+ "outline REFRAMR workflow",
870
+ "REFRAMR follows a clean path: build corpus statistics, derive recurrent state behavior, and compute the readout. Each stage stays inspectable; none requires opaque epoch loops.",
871
+ ),
872
+ (
873
+ "explain OkeyMeta design ethic",
874
+ "OkeyMeta design ethic is practical and strict: keep provenance visible, keep compute sane, and keep the system understandable. Ambition matters, but clarity matters more.",
875
+ ),
876
+ (
877
+ "explain why evidence matters",
878
+ "Evidence matters because confidence alone is cheap; structure, tests, and reproducible runs make a claim durable. When evidence improves, judgment becomes steadier.",
879
+ ),
880
+ (
881
+ "describe analytical memory",
882
+ "Analytical memory compresses history into a reusable state; it does not replay every token. That compression is useful only when the state stays orderly, expressive, and inspectable.",
883
+ ),
884
+ (
885
+ "explain corpus quality",
886
+ "Corpus quality is not only scale: it is structure, range, and cleanliness. Better data teaches a model where to pause, when to compare, and how to finish a thought.",
887
+ ),
888
+ (
889
+ "explain transparent reasoning",
890
+ "Transparent reasoning does not mean leaking private scratch work; it means exposing useful signals, clear traces, and grounded summaries. The system should reveal why a path dominated.",
891
+ ),
892
+ (
893
+ "describe disciplined generalization",
894
+ "Disciplined generalization begins with pattern depth, not shallow imitation. A model should reuse structure carefully, vary language naturally, and stay anchored to evidence.",
895
+ ),
896
+ (
897
+ "explain why recurrent state can scale",
898
+ "Recurrent state can scale because it updates incrementally; it does not rebuild a full attention map at each step. The challenge is quality, not merely length.",
899
+ ),
900
+ (
901
+ "describe strong completion behavior",
902
+ "Strong completion behavior means the answer reaches a real ending: clauses resolve, punctuation lands, and drift stays contained. A half-finished sentence is not intelligence.",
903
+ ),
904
+ (
905
+ "explain why handcrafted data still matters",
906
+ "Handcrafted data still matters because it can encode precision, tone, and deliberate contrast. It supplies patterns that scraped noise often blurs or discards.",
907
+ ),
908
+ (
909
+ "explain why punctuation supports long answers",
910
+ "Punctuation supports long answers because structure must breathe: commas pace detail, semicolons balance related claims, and periods secure closure. Without marks, long prose often collapses into blur.",
911
+ ),
912
+ (
913
+ "describe healthy model discipline",
914
+ "Healthy model discipline is visible in the small things: exact wording, stable endings, measured confidence, and clean recovery from ambiguity. Strong systems respect detail before spectacle.",
915
+ ),
916
+ (
917
+ "explain why broad corpus style matters",
918
+ "Broad corpus style matters because the model learns more than facts; it learns transition, emphasis, cadence, and restraint. A rich corpus teaches how to move from premise to finish.",
919
+ ),
920
+ (
921
+ "describe how evidence and style should meet",
922
+ "Evidence and style should meet in one sentence: the claim must be accurate, and the sentence must be shaped well enough to carry that accuracy without friction. Good language engineering serves both.",
923
+ ),
924
+ (
925
+ "explain why exact retrieval still needs composition",
926
+ "Exact retrieval still needs composition because recovered facts must land in coherent prose; the answer should connect, not merely appear. Precision becomes more useful when it arrives with structure.",
927
+ ),
928
+ (
929
+ "outline why model endings matter",
930
+ "Model endings matter for a simple reason: the final clause teaches whether the system understood the task or only imitated momentum. A clean ending shows control, not luck.",
931
+ ),
932
+ ]
933
+ exposition_holdout = [
934
+ (
935
+ "explain why sentence endings matter",
936
+ "Sentence endings matter because closure guides interpretation; a period settles a claim, while a comma signals more is coming. Good models must feel that difference.",
937
+ ),
938
+ (
939
+ "explain why structured data improves writing",
940
+ "Structured data improves writing because it teaches ordering, emphasis, and transition; the model learns not only facts, but how claims should connect.",
941
+ ),
942
+ (
943
+ "outline why analytical systems need traces",
944
+ "Analytical systems need traces so operators can inspect dominant signals, compare retrieval paths, and debug drift. Visibility turns mystery into engineering.",
945
+ ),
946
+ (
947
+ "describe why punctuation supports reasoning",
948
+ "Punctuation supports reasoning by marking relation, pause, and hierarchy; it helps the reader separate evidence from conclusion. A fluent model should use marks intentionally.",
949
+ ),
950
+ (
951
+ "explain why corpus range matters",
952
+ "Corpus range matters because generalization grows from varied structures, not one narrow script. When prompts diversify, the model learns to pivot with control.",
953
+ ),
954
+ (
955
+ "describe why exact answers still need style",
956
+ "Exact answers still need style: the right fact should arrive with clean syntax, useful pacing, and a stable finish. Precision and fluency should reinforce each other.",
957
+ ),
958
+ ]
959
+ for prompt, answer in exposition_train:
960
+ add_train(
961
+ "exposition",
962
+ f"<reason> {prompt} <answer>",
963
+ answer,
964
+ sample=prompt in {
965
+ "explain why long context matters",
966
+ "explain why punctuation matters in language models",
967
+ "outline REFRAMR workflow",
968
+ "describe strong completion behavior",
969
+ },
970
+ )
971
+ for prompt, answer in exposition_holdout:
972
+ add_holdout(
973
+ "exposition",
974
+ f"<reason> {prompt} <answer>",
975
+ answer,
976
+ )
977
+
978
+ composition_train = [
979
+ (
980
+ "ocean",
981
+ "ocean waves move with patient rhythm and silver foam follows the moonlit shore while distant wind keeps a calm measured pulse",
982
+ ),
983
+ (
984
+ "forest",
985
+ "forest light falls softly through cedar branches and cool air carries resin and rain while the ground stays quiet beneath careful steps",
986
+ ),
987
+ (
988
+ "desert",
989
+ "desert heat bends above pale stone and long shadows stretch across patient sand while evening air slowly restores a gentler balance",
990
+ ),
991
+ (
992
+ "city",
993
+ "city dawn spills across glass towers and quiet streets as buses wake in sequence and windows catch a thin ribbon of gold",
994
+ ),
995
+ (
996
+ "mountain",
997
+ "mountain air stays bright and thin while granite faces hold the morning sun and distant rivers thread silver lines below",
998
+ ),
999
+ (
1000
+ "harbor",
1001
+ "harbor lights shimmer in patient water while cables rest against masts and slow bells mark the edge of another working night",
1002
+ ),
1003
+ (
1004
+ "library",
1005
+ "library silence gathers around tall shelves while lamps hold warm circles of light and every page waits with deliberate calm",
1006
+ ),
1007
+ (
1008
+ "laboratory",
1009
+ "laboratory glass reflects a quiet blue glow while instruments rest in ordered rows and each surface signals exact preparation",
1010
+ ),
1011
+ (
1012
+ "garden",
1013
+ "garden air carries wet soil and green fragrance while trimmed paths divide the beds and new petals lean toward morning light",
1014
+ ),
1015
+ (
1016
+ "observatory",
1017
+ "observatory domes open toward dark sky while motors turn with patient certainty and cold metal frames the waiting stars",
1018
+ ),
1019
+ ]
1020
+ composition_holdout = [
1021
+ (
1022
+ "glacier",
1023
+ "glacier light drifts across slow blue ice while distant air remains clear and every ridge keeps a restrained patient shine",
1024
+ ),
1025
+ (
1026
+ "volcano",
1027
+ "volcano stone holds the memory of fire while dark slopes remain still and rising heat bends the horizon with slow force",
1028
+ ),
1029
+ (
1030
+ "cathedral",
1031
+ "cathedral windows gather colored light while high arches hold a quiet echo and polished stone returns each careful footstep",
1032
+ ),
1033
+ (
1034
+ "market",
1035
+ "market voices braid with morning movement while bright fruit lines the tables and woven shade softens the noonward heat",
1036
+ ),
1037
+ (
1038
+ "reef",
1039
+ "reef water carries shifting bands of color while coral forms patient cities and bright fish stitch motion through clear blue lanes",
1040
+ ),
1041
+ (
1042
+ "station",
1043
+ "station metal hums beneath pale lamps while distant tracks hold a thin vibration and travelers wait inside orderly lines",
1044
+ ),
1045
+ (
1046
+ "courtroom",
1047
+ "courtroom wood carries a formal hush while measured voices rise with care and every pause sharpens the weight of the next sentence",
1048
+ ),
1049
+ (
1050
+ "shipyard",
1051
+ "shipyard steel rings through salted air while cranes turn with slow authority and sparks drift briefly before fading into dusk",
1052
+ ),
1053
+ (
1054
+ "archive",
1055
+ "archive boxes rest in numbered rows while cool air holds the paper scent and each label promises a patient return to memory",
1056
+ ),
1057
+ (
1058
+ "savanna",
1059
+ "savanna light stretches across dry grass while distant heat softens the horizon and watchful movement gathers near the last shade",
1060
+ ),
1061
+ (
1062
+ "workshop",
1063
+ "workshop lamps shine over ordered tools while sawdust settles in pale ribbons and each bench waits for deliberate hands",
1064
+ ),
1065
+ (
1066
+ "bridge",
1067
+ "bridge cables hold their tense geometry while river light drifts below and the roadway hums with disciplined forward motion",
1068
+ ),
1069
+ ]
1070
+ for theme, answer in composition_train:
1071
+ add_train(
1072
+ "composition",
1073
+ f"<reason> write {theme} scene in one paragraph <answer>",
1074
+ answer,
1075
+ sample=theme in {"ocean", "forest", "city", "harbor", "laboratory"},
1076
+ )
1077
+ add_train(
1078
+ "composition",
1079
+ f"<reason> write {theme} scene <answer>",
1080
+ answer,
1081
+ sample=False,
1082
+ )
1083
+ for theme, answer in composition_holdout:
1084
+ add_holdout(
1085
+ "composition",
1086
+ f"<reason> write {theme} scene in one paragraph <answer>",
1087
+ answer,
1088
+ )
1089
+ add_holdout(
1090
+ "composition",
1091
+ f"<reason> write {theme} scene <answer>",
1092
+ answer,
1093
+ )
1094
+
1095
+ add_open(
1096
+ "composition",
1097
+ "write harbor dawn scene with calm tension",
1098
+ [
1099
+ ["harbor", "port"],
1100
+ ["dawn", "morning", "sunrise", "light"],
1101
+ ["water", "tide", "shore"],
1102
+ ["calm", "quiet", "measured", "tension"],
1103
+ ],
1104
+ banned_phrases=[
1105
+ "harbor lights shimmer in patient water while cables rest against masts and slow bells mark the edge of another working night",
1106
+ ],
1107
+ min_words=16,
1108
+ max_tokens=40,
1109
+ )
1110
+ add_open(
1111
+ "composition",
1112
+ "write laboratory harbor scene with precise calm",
1113
+ [
1114
+ ["laboratory", "glass", "instrument"],
1115
+ ["harbor", "water", "mast", "cable"],
1116
+ ["calm", "quiet", "precise", "ordered"],
1117
+ ],
1118
+ banned_phrases=[],
1119
+ min_words=16,
1120
+ max_tokens=40,
1121
+ )
1122
+ add_open(
1123
+ "identity",
1124
+ "describe REFRAMR in your own words, with punctuation",
1125
+ [
1126
+ ["reframr"],
1127
+ ["okeymeta"],
1128
+ ["analytical", "recurrent", "language", "system"],
1129
+ ],
1130
+ banned_phrases=[
1131
+ "REFRAMR is an analytical recurrent language system built by OkeyMeta Ltd to compute structure from corpus evidence instead of gradient loops",
1132
+ "REFRAMR is analytical, recurrent, and deliberate; OkeyMeta Ltd builds it to compute structure from evidence, not gradient ritual.",
1133
+ ],
1134
+ min_words=12,
1135
+ max_tokens=36,
1136
+ )
1137
+ add_open(
1138
+ "exposition",
1139
+ "explain why punctuation helps long reasoning",
1140
+ [
1141
+ ["punctuation"],
1142
+ ["reasoning", "thinking"],
1143
+ ["structure", "pace", "pause", "closure"],
1144
+ ],
1145
+ banned_phrases=[
1146
+ "Punctuation supports long answers because structure must breathe: commas pace detail, semicolons balance related claims, and periods secure closure. Without marks, long prose often collapses into blur.",
1147
+ ],
1148
+ min_words=18,
1149
+ max_tokens=40,
1150
+ )
1151
+ add_open(
1152
+ "identity",
1153
+ "explain REFRAMR memory for long context in your own words",
1154
+ [
1155
+ ["reframr"],
1156
+ ["memory", "state"],
1157
+ ["context", "history"],
1158
+ ["long", "persistent", "extended"],
1159
+ ],
1160
+ banned_phrases=[
1161
+ "REFRAMR memory is a multi timescale analytical state that compresses history without quadratic attention",
1162
+ ],
1163
+ min_words=16,
1164
+ max_tokens=40,
1165
+ )
1166
+ add_open(
1167
+ "composition",
1168
+ "write archive bridge scene with reflective tension",
1169
+ [
1170
+ ["archive", "paper", "label", "memory"],
1171
+ ["bridge", "cable", "river", "roadway"],
1172
+ ["reflective", "tension", "quiet", "measured"],
1173
+ ],
1174
+ banned_phrases=[],
1175
+ min_words=16,
1176
+ max_tokens=40,
1177
+ )
1178
+
1179
+ return CorpusPackage(
1180
+ name="FrameCorpus-Foundation-v2",
1181
+ records=records,
1182
+ section_counts=section_counts,
1183
+ memorization_samples=_balanced_samples(memorization, 24),
1184
+ generalization_samples=_balanced_samples(generalization, 16),
1185
+ open_ended_samples=open_ended,
1186
+ )
1187
+
1188
+
1189
+ def build_generalization_corpus() -> CorpusPackage:
1190
+ foundation = build_foundation_corpus()
1191
+ allowed_sections = {
1192
+ "analogy",
1193
+ "paraphrase",
1194
+ "comparison",
1195
+ "causal",
1196
+ "definition",
1197
+ "identity",
1198
+ "exposition",
1199
+ "composition",
1200
+ }
1201
+
1202
+ records = [
1203
+ record
1204
+ for record in foundation.records
1205
+ if record.section in allowed_sections
1206
+ ]
1207
+ generalization = [
1208
+ sample
1209
+ for sample in foundation.generalization_samples
1210
+ if sample.section in allowed_sections
1211
+ ]
1212
+ open_ended = [
1213
+ sample
1214
+ for sample in foundation.open_ended_samples
1215
+ if sample.section in allowed_sections
1216
+ ]
1217
+
1218
+ return CorpusPackage(
1219
+ name="FrameCorpus-Generalization-v1",
1220
+ records=records,
1221
+ section_counts=_recount_sections(records),
1222
+ memorization_samples=[],
1223
+ generalization_samples=_balanced_samples(generalization, min(16, len(generalization))),
1224
+ open_ended_samples=open_ended,
1225
+ )
1226
+
1227
+
1228
+ def write_corpus_package(package: CorpusPackage, output_dir: str | Path) -> dict[str, str]:
1229
+ directory = Path(output_dir)
1230
+ directory.mkdir(parents=True, exist_ok=True)
1231
+
1232
+ base_filename = package.slug
1233
+ corpus_filename = f"{base_filename}.jsonl"
1234
+ manifest_filename = f"{base_filename}.manifest.json"
1235
+ prompt_suite_filename = f"{base_filename}.prompts.jsonl"
1236
+ corpus_path = directory / corpus_filename
1237
+ manifest_path = directory / manifest_filename
1238
+ prompt_suite_path = directory / prompt_suite_filename
1239
+
1240
+ corpus_path.write_text(
1241
+ "\n".join(json.dumps(record, ensure_ascii=True) for record in package.corpus_records()) + "\n",
1242
+ encoding="utf-8",
1243
+ )
1244
+ manifest_path.write_text(
1245
+ json.dumps(package.manifest(corpus_filename=corpus_filename), indent=2),
1246
+ encoding="utf-8",
1247
+ )
1248
+ prompt_suite_path.write_text(
1249
+ "\n".join(json.dumps(record, ensure_ascii=True) for record in package.prompt_suite()) + "\n",
1250
+ encoding="utf-8",
1251
+ )
1252
+
1253
+ return {
1254
+ "corpus_path": str(corpus_path.resolve()),
1255
+ "manifest_path": str(manifest_path.resolve()),
1256
+ "prompt_suite_path": str(prompt_suite_path.resolve()),
1257
+ }
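
`build_foundation_corpus`, `build_generalization_corpus`, and `write_corpus_package` form the corpus-packaging path: build a `CorpusPackage`, then materialize its three artifacts (records, manifest, prompt suite) on disk. A minimal sketch, assuming these helpers live in `reframr.corpus` as the file ordering suggests; the output directory is an illustrative assumption:

```python
# Minimal sketch: build the generalization package and write its artifacts.
# Module path and output directory are assumptions, not release defaults.
from reframr.corpus import build_generalization_corpus, write_corpus_package

package = build_generalization_corpus()
paths = write_corpus_package(package, "artifacts/corpora")
for label, path in paths.items():  # corpus_path, manifest_path, prompt_suite_path
    print(label, "->", path)
```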
reframr/curriculum.py ADDED
The diff for this file is too large to render. See raw diff
 
reframr/datasets.py ADDED
@@ -0,0 +1,165 @@
+ import json
+ from pathlib import Path
+
+ from .text_quality import clean_answer_text, clean_context_text, clean_training_text
+
+
+ TEXT_EXTENSIONS = {".txt", ".md", ".text"}
+ STRUCTURED_EXTENSIONS = {".jsonl", ".json"}
+
+
+ def _default_record_weight(record_type: str) -> int:
+     if record_type == "dialogue_turn":
+         return 2
+     if record_type == "instruction_answer":
+         return 2
+     if record_type == "preference_chosen":
+         return 3
+     if record_type == "preference_rejected":
+         return 0
+     return 1
+
+
+ def _record_repeat_count(record: object) -> int:
+     if not isinstance(record, dict):
+         return 1
+     if bool(record.get("drop")):
+         return 0
+     raw_weight = record.get("weight")
+     if raw_weight is not None:
+         try:
+             numeric = int(round(float(raw_weight)))
+         except (TypeError, ValueError):
+             numeric = 1
+         return max(0, min(8, numeric))
+     return _default_record_weight(str(record.get("record_type", "")))
+
+
+ def _coerce_text_record(record: object) -> str:
+     if isinstance(record, str):
+         return clean_training_text(record.strip())
+     if isinstance(record, dict):
+         if "text" in record:
+             return clean_training_text(str(record["text"]).strip())
+         if "content" in record:
+             return clean_training_text(str(record["content"]).strip())
+         if "context" in record and "answer" in record:
+             context = clean_context_text(str(record["context"]).strip())
+             answer = clean_answer_text(str(record["answer"]).strip())
+             if context and answer:
+                 return f"<reason> {context} <answer> {answer}"
+     return ""
+
+
+ def _coerce_prompt_record(record: object) -> dict[str, object] | None:
+     if isinstance(record, str):
+         prompt = record.strip()
+         return {"prompt": prompt, "tags": []} if prompt else None
+     if isinstance(record, dict):
+         raw_prompt = record.get("prompt", record.get("context", ""))
+         prompt = clean_context_text(str(raw_prompt).strip())
+         if not prompt:
+             return None
+         raw_tags = record.get("tags", [])
+         tags = [str(tag) for tag in raw_tags] if isinstance(raw_tags, list) else []
+         normalized = dict(record)
+         normalized["prompt"] = prompt
+         normalized["tags"] = tags
+         return normalized
+     return None
+
+
+ def load_text_corpus(source: str | Path) -> str:
+     path = Path(source)
+     if path.is_dir():
+         parts = [
+             load_text_corpus(child)
+             for child in sorted(path.rglob("*"))
+             if child.is_file() and child.suffix.lower() in TEXT_EXTENSIONS | STRUCTURED_EXTENSIONS
+         ]
+         return "\n".join(part for part in parts if part.strip())
+
+     suffix = path.suffix.lower()
+     if suffix in TEXT_EXTENSIONS:
+         return path.read_text(encoding="utf-8")
+     if suffix == ".jsonl":
+         lines = []
+         for line in path.read_text(encoding="utf-8").splitlines():
+             if not line.strip():
+                 continue
+             record = json.loads(line)
+             text = _coerce_text_record(record)
+             if text:
+                 lines.extend([text] * _record_repeat_count(record))
+         return "\n".join(lines)
+     if suffix == ".json":
+         payload = json.loads(path.read_text(encoding="utf-8"))
+         if isinstance(payload, list):
+             parts: list[str] = []
+             for item in payload:
+                 text = _coerce_text_record(item)
+                 if text:
+                     parts.extend([text] * _record_repeat_count(item))
+             return "\n".join(parts)
+         if isinstance(payload, dict):
+             if "texts" in payload and isinstance(payload["texts"], list):
+                 parts: list[str] = []
+                 for item in payload["texts"]:
+                     text = _coerce_text_record(item)
+                     if text:
+                         parts.extend([text] * _record_repeat_count(item))
+                 return "\n".join(parts)
+             if "records" in payload and isinstance(payload["records"], list):
+                 parts: list[str] = []
+                 for item in payload["records"]:
+                     text = _coerce_text_record(item)
+                     if text:
+                         parts.extend([text] * _record_repeat_count(item))
+                 return "\n".join(parts)
+             text = _coerce_text_record(payload)
+             if text:
+                 return "\n".join([text] * _record_repeat_count(payload))
+     raise ValueError(f"Unsupported corpus source: {path}")
+
+
+ def load_prompt_suite(source: str | Path) -> list[dict[str, object]]:
+     path = Path(source)
+     suffix = path.suffix.lower()
+     prompts: list[dict[str, object]] = []
+
+     if suffix in TEXT_EXTENSIONS:
+         for line in path.read_text(encoding="utf-8").splitlines():
+             record = _coerce_prompt_record(line)
+             if record is not None:
+                 prompts.append(record)
+         return prompts
+
+     if suffix == ".jsonl":
+         for line in path.read_text(encoding="utf-8").splitlines():
+             if not line.strip():
+                 continue
+             record = _coerce_prompt_record(json.loads(line))
+             if record is not None:
+                 prompts.append(record)
+         return prompts
+
+     if suffix == ".json":
+         payload = json.loads(path.read_text(encoding="utf-8"))
+         if isinstance(payload, list):
+             for item in payload:
+                 record = _coerce_prompt_record(item)
+                 if record is not None:
+                     prompts.append(record)
+             return prompts
+         if isinstance(payload, dict):
+             if "prompts" in payload and isinstance(payload["prompts"], list):
+                 for item in payload["prompts"]:
+                     record = _coerce_prompt_record(item)
+                     if record is not None:
+                         prompts.append(record)
+                 return prompts
+             record = _coerce_prompt_record(payload)
+             if record is not None:
+                 return [record]
+
+     raise ValueError(f"Unsupported prompt suite: {path}")
reframr/embeddings.py ADDED
@@ -0,0 +1,457 @@
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass
+
+ from .corpus import build_cooccurrence_matrix, build_vocabulary, tokenize
+ from .linalg import Matrix, Vector, mean, np, top_k_eigenpairs_symmetric, zeros
+
+ try:
+     from scipy import sparse as scipy_sparse
+     from scipy.sparse.linalg import svds as scipy_svds
+ except (ImportError, ModuleNotFoundError, OSError):
+     scipy_sparse = None
+     scipy_svds = None
+
+
+ SKETCHED_EMBEDDING_VOCAB_THRESHOLD = 2048
+
+
+ def _remove_common_embedding_axis(embeddings: object, row_strength: object | None = None) -> object:
+     if np is None:
+         return embeddings
+     values = np.asarray(embeddings, dtype=np.float64)
+     if values.size == 0 or len(values.shape) != 2:
+         return values
+     norms = np.linalg.norm(values, axis=1)
+     nonzero = norms > 1e-12
+     values[nonzero] /= norms[nonzero, None]
+     if row_strength is not None:
+         strength = np.asarray(row_strength, dtype=np.float64)
+         if strength.shape[0] == values.shape[0]:
+             values[nonzero] *= np.log1p(strength[nonzero])[:, None]
+
+     common_axis = values.mean(axis=0, keepdims=True)
+     values = values - common_axis
+     norms = np.linalg.norm(values, axis=1)
+     nonzero = norms > 1e-12
+     values[nonzero] /= norms[nonzero, None]
+     if row_strength is not None:
+         strength = np.asarray(row_strength, dtype=np.float64)
+         if strength.shape[0] == values.shape[0]:
+             values[nonzero] *= np.log1p(strength[nonzero])[:, None]
+     return values
+
+
+ def _sketched_sparse_ppmi_embedding(ppmi: object, embedding_dim: int) -> object:
+     coo = ppmi.tocoo()
+     rows = coo.row.astype(np.int64, copy=False)
+     cols = coo.col.astype(np.int64, copy=False)
+     values = coo.data.astype(np.float64, copy=False)
+     embeddings = np.zeros((ppmi.shape[0], embedding_dim), dtype=np.float64)
+     if embedding_dim <= 0 or values.size == 0:
+         return embeddings
+
+     buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
+     signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
+     np.add.at(embeddings, (rows, buckets), values * signs)
+
+     row_strength = np.sqrt(np.asarray(ppmi.sum(axis=1)).ravel())
+     return _remove_common_embedding_axis(embeddings, row_strength)
+ return _remove_common_embedding_axis(embeddings, row_strength)
61
+
62
+
63
+ def fit_sketched_ppmi_embedding_from_counts(
64
+ id_to_token: list[str],
65
+ rows: dict[int, dict[int, float]],
66
+ *,
67
+ embedding_dim: int,
68
+ ) -> EmbeddingModel:
69
+ if not id_to_token:
70
+ raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
71
+ if embedding_dim <= 0:
72
+ raise ValueError("Embedding dimension must be positive.")
73
+
74
+ size = len(id_to_token)
75
+ token_to_id = {token: index for index, token in enumerate(id_to_token)}
76
+ if np is None:
77
+ embeddings = zeros(size, embedding_dim)
78
+ row_sums = [0.0 for _ in range(size)]
79
+ for row, columns in rows.items():
80
+ row_sums[row] = sum(columns.values())
81
+ total = sum(row_sums)
82
+ if total <= 0.0:
83
+ return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])
84
+ for row, columns in rows.items():
85
+ for col, count in columns.items():
86
+ denominator = row_sums[row] * row_sums[col]
87
+ if count <= 0.0 or denominator <= 0.0:
88
+ continue
89
+ value = math.log((count * total) / denominator)
90
+ if value <= 0.0:
91
+ continue
92
+ bucket = (col * 1103515245 + 12345) % embedding_dim
93
+ sign = 1.0 if ((col * 214013 + 2531011) & 1) == 0 else -1.0
94
+ embeddings[row][bucket] += value * sign
95
+ return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])
96
+
97
+ embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
98
+ row_sums = np.zeros(size, dtype=np.float64)
99
+ for row, columns in rows.items():
100
+ row_sums[row] = sum(columns.values())
101
+ total = float(row_sums.sum())
102
+ if total <= 0.0:
103
+ return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])
104
+
105
+ for row, columns in rows.items():
106
+ if not columns or row_sums[row] <= 0.0:
107
+ continue
108
+ cols = np.fromiter(columns.keys(), dtype=np.int64)
109
+ counts = np.fromiter(columns.values(), dtype=np.float64)
110
+ denominators = row_sums[row] * row_sums[cols]
111
+ valid = (counts > 0.0) & (denominators > 0.0)
112
+ if not np.any(valid):
113
+ continue
114
+ cols = cols[valid]
115
+ values = np.log((counts[valid] * total) / denominators[valid])
116
+ positive = values > 0.0
117
+ if not np.any(positive):
118
+ continue
119
+ cols = cols[positive]
120
+ values = values[positive]
121
+ buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
122
+ signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
123
+ np.add.at(embeddings[row], buckets, values * signs)
124
+
125
+ embeddings = _remove_common_embedding_axis(embeddings, row_sums)
126
+ return EmbeddingModel(
127
+ token_to_id=token_to_id,
128
+ id_to_token=id_to_token,
129
+ embeddings=embeddings,
130
+ ppmi_matrix=[],
131
+ )
132
+
133
+
134
+ def _positive_ppmi_values(
135
+ *,
136
+ row: int,
137
+ columns: dict[int, float],
138
+ row_sums: object,
139
+ total: float,
140
+ ) -> tuple[object, object]:
141
+ cols = np.fromiter(columns.keys(), dtype=np.int64)
142
+ counts = np.fromiter(columns.values(), dtype=np.float64)
143
+ if cols.size == 0:
144
+ return cols, counts
145
+ denominators = float(row_sums[row]) * row_sums[cols]
146
+ valid = (counts > 0.0) & (denominators > 0.0)
147
+ if not np.any(valid):
148
+ return cols[:0], counts[:0]
149
+ cols = cols[valid]
150
+ values = np.log((counts[valid] * total) / denominators[valid])
151
+ positive = values > 0.0
152
+ return cols[positive], values[positive]
153
+
154
+
155
+ def fit_randomized_ppmi_embedding_from_counts(
156
+ id_to_token: list[str],
157
+ rows: dict[int, dict[int, float]],
158
+ *,
159
+ embedding_dim: int,
160
+ oversampling: int = 32,
161
+ ) -> EmbeddingModel:
162
+ if np is None:
163
+ return fit_sketched_ppmi_embedding_from_counts(
164
+ id_to_token,
165
+ rows,
166
+ embedding_dim=embedding_dim,
167
+ )
168
+ if not id_to_token:
169
+ raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
170
+ if embedding_dim <= 0:
171
+ raise ValueError("Embedding dimension must be positive.")
172
+
173
+ size = len(id_to_token)
174
+ token_to_id = {token: index for index, token in enumerate(id_to_token)}
175
+ row_sums = np.zeros(size, dtype=np.float64)
176
+ for row, columns in rows.items():
177
+ row_sums[row] = sum(columns.values())
178
+ total = float(row_sums.sum())
179
+ if total <= 0.0:
180
+ return EmbeddingModel(
181
+ token_to_id=token_to_id,
182
+ id_to_token=id_to_token,
183
+ embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
184
+ ppmi_matrix=[],
185
+ )
186
+
187
+ width = min(size, max(embedding_dim, embedding_dim + oversampling))
188
+ rng = np.random.default_rng(1729 + size * 31 + embedding_dim)
189
+ omega = rng.standard_normal((size, width)).astype(np.float64, copy=False)
190
+ sketch = np.zeros((size, width), dtype=np.float64)
191
+ ppmi_cache: dict[int, tuple[object, object]] = {}
192
+ for row, columns in rows.items():
193
+ if not columns or row_sums[row] <= 0.0:
194
+ continue
195
+ cols, values = _positive_ppmi_values(
196
+ row=row,
197
+ columns=columns,
198
+ row_sums=row_sums,
199
+ total=total,
200
+ )
201
+ if values.size == 0:
202
+ continue
203
+ ppmi_cache[row] = (cols, values)
204
+ sketch[row] = values @ omega[cols]
205
+
206
+ if not ppmi_cache:
207
+ return EmbeddingModel(
208
+ token_to_id=token_to_id,
209
+ id_to_token=id_to_token,
210
+ embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
211
+ ppmi_matrix=[],
212
+ )
213
+
214
+ basis, _ = np.linalg.qr(sketch, mode="reduced")
215
+ compressed = np.zeros((basis.shape[1], size), dtype=np.float64)
216
+ for row, (cols, values) in ppmi_cache.items():
217
+ compressed[:, cols] += basis[row, :, None] * values[None, :]
218
+
219
+ left_small, singular_values, _ = np.linalg.svd(compressed, full_matrices=False)
220
+ left = basis @ left_small
221
+ width = min(embedding_dim, left.shape[1], singular_values.shape[0])
222
+ embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
223
+ if width > 0:
224
+ embeddings[:, :width] = left[:, :width] * np.sqrt(np.maximum(singular_values[:width], 0.0))[None, :]
225
+ embeddings = _remove_common_embedding_axis(embeddings, np.sqrt(row_sums))
226
+ return EmbeddingModel(
227
+ token_to_id=token_to_id,
228
+ id_to_token=id_to_token,
229
+ embeddings=embeddings,
230
+ ppmi_matrix=[],
231
+ )
232
+
233
+
234
+ def positive_pointwise_mutual_information(matrix: Matrix) -> Matrix:
235
+ if scipy_sparse is not None and scipy_sparse.issparse(matrix):
236
+ counts = matrix.tocoo()
237
+ if counts.nnz == 0:
238
+ return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
239
+ row_sums = np.asarray(matrix.sum(axis=1)).ravel()
240
+ total = float(row_sums.sum())
241
+ if total == 0.0:
242
+ return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
243
+ denominators = row_sums[counts.row] * row_sums[counts.col]
244
+ valid = (counts.data > 0.0) & (denominators > 0.0)
245
+ if not np.any(valid):
246
+ return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
247
+ ratios = (counts.data[valid] * total) / denominators[valid]
248
+ data = np.maximum(np.log(ratios), 0.0)
249
+ keep = data > 0.0
250
+ if not np.any(keep):
251
+ return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
252
+ return scipy_sparse.coo_matrix(
253
+ (
254
+ data[keep],
255
+ (counts.row[valid][keep], counts.col[valid][keep]),
256
+ ),
257
+ shape=counts.shape,
258
+ dtype=np.float64,
259
+ ).tocsr()
260
+
261
+ if not matrix:
262
+ return []
263
+ if np is not None:
264
+ counts = np.asarray(matrix, dtype=np.float64)
265
+ row_sums = counts.sum(axis=1)
266
+ total = float(row_sums.sum())
267
+ if total == 0.0:
268
+ return np.zeros_like(counts).tolist()
269
+ denominator = np.outer(row_sums, row_sums)
270
+ valid = (counts > 0.0) & (denominator > 0.0)
271
+ ppmi = np.zeros_like(counts)
272
+ with np.errstate(divide="ignore", invalid="ignore"):
273
+ ratios = np.divide(
274
+ counts * total,
275
+ denominator,
276
+ out=np.ones_like(counts),
277
+ where=valid,
278
+ )
279
+ ppmi[valid] = np.maximum(np.log(ratios[valid]), 0.0)
280
+ return ppmi.tolist()
281
+
282
+ row_sums = [sum(row) for row in matrix]
283
+ total = sum(row_sums)
284
+ if total == 0.0:
285
+ return zeros(len(matrix), len(matrix))
286
+
287
+ ppmi = zeros(len(matrix), len(matrix))
288
+ for row in range(len(matrix)):
289
+ for col in range(len(matrix[row])):
290
+ count = matrix[row][col]
291
+ if count <= 0.0 or row_sums[row] == 0.0 or row_sums[col] == 0.0:
292
+ continue
293
+ p_ij = count / total
294
+ p_i = row_sums[row] / total
295
+ p_j = row_sums[col] / total
296
+ value = math.log(p_ij / (p_i * p_j))
297
+ ppmi[row][col] = max(0.0, value)
298
+ return ppmi
299
+
300
+
301
+ @dataclass(slots=True)
302
+ class EmbeddingModel:
303
+ token_to_id: dict[str, int]
304
+ id_to_token: list[str]
305
+ embeddings: Matrix
306
+ ppmi_matrix: Matrix
307
+
308
+ def vector(self, token: str) -> Vector:
309
+ index = self.token_to_id.get(token)
310
+ if index is None and token.lower() != token:
311
+ index = self.token_to_id.get(token.lower())
312
+ if index is None:
313
+ return [0.0 for _ in range(self.dimension)]
314
+ row = self.embeddings[index]
315
+ return row.astype(float).tolist() if hasattr(row, "tolist") else row[:]
316
+
317
+ @property
318
+ def dimension(self) -> int:
319
+ if hasattr(self.embeddings, "shape"):
320
+ return int(self.embeddings.shape[1]) if len(self.embeddings.shape) > 1 else 0
321
+ return len(self.embeddings[0]) if self.embeddings else 0
322
+
323
+ @property
324
+ def projection_axis(self) -> Vector:
325
+ if hasattr(self.embeddings, "shape"):
326
+ if int(self.embeddings.shape[0]) == 0:
327
+ return []
328
+ return self.embeddings.mean(axis=0).astype(float).tolist()
329
+ if not self.embeddings:
330
+ return []
331
+ return [
332
+ mean([row[column] for row in self.embeddings])
333
+ for column in range(self.dimension)
334
+ ]
335
+
336
+
337
+ def fit_ppmi_embedding(
338
+ text: str,
339
+ *,
340
+ embedding_dim: int,
341
+ window_size: int,
342
+ min_frequency: int = 1,
343
+ max_vocab: int | None = None,
344
+ ) -> EmbeddingModel:
345
+ tokens = tokenize(text)
346
+ if not tokens:
347
+ raise ValueError("Cannot fit REFRAMR embeddings on empty text.")
348
+
349
+ return fit_ppmi_embedding_from_tokens(
350
+ tokens,
351
+ embedding_dim=embedding_dim,
352
+ window_size=window_size,
353
+ min_frequency=min_frequency,
354
+ max_vocab=max_vocab,
355
+ )
356
+
357
+
358
+ def fit_ppmi_embedding_from_tokens(
359
+ tokens: list[str],
360
+ *,
361
+ embedding_dim: int,
362
+ window_size: int,
363
+ min_frequency: int = 1,
364
+ max_vocab: int | None = None,
365
+ ) -> EmbeddingModel:
366
+ if not tokens:
367
+ raise ValueError("Cannot fit REFRAMR embeddings on an empty token stream.")
368
+
369
+ token_to_id, id_to_token = build_vocabulary(tokens, min_frequency, max_vocab)
370
+ cooccurrence = build_cooccurrence_matrix(tokens, token_to_id, window_size)
371
+ ppmi = positive_pointwise_mutual_information(cooccurrence)
372
+ eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)
373
+
374
+ embeddings = zeros(len(id_to_token), embedding_dim)
375
+ for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
376
+ scale = math.sqrt(max(eigenvalue, 0.0))
377
+ for row in range(len(id_to_token)):
378
+ embeddings[row][component] = eigenvector[row] * scale
379
+ if np is not None:
380
+ embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
381
+
382
+ return EmbeddingModel(
383
+ token_to_id=token_to_id,
384
+ id_to_token=id_to_token,
385
+ embeddings=embeddings,
386
+ ppmi_matrix=ppmi,
387
+ )
388
+
389
+
390
+ def fit_ppmi_embedding_from_cooccurrence(
391
+ id_to_token: list[str],
392
+ cooccurrence: Matrix,
393
+ *,
394
+ embedding_dim: int,
395
+ ) -> EmbeddingModel:
396
+ if not id_to_token:
397
+ raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
398
+
399
+ ppmi = positive_pointwise_mutual_information(cooccurrence)
400
+ if scipy_sparse is not None and scipy_sparse.issparse(ppmi):
401
+ embedding_width = min(embedding_dim, len(id_to_token))
402
+ if len(id_to_token) >= SKETCHED_EMBEDDING_VOCAB_THRESHOLD or embedding_width >= 128:
403
+ embeddings = _sketched_sparse_ppmi_embedding(ppmi, embedding_dim)
404
+ return EmbeddingModel(
405
+ token_to_id={token: index for index, token in enumerate(id_to_token)},
406
+ id_to_token=id_to_token,
407
+ embeddings=embeddings,
408
+ ppmi_matrix=[],
409
+ )
410
+ embeddings = zeros(len(id_to_token), embedding_dim)
411
+ if embedding_width <= 0 or ppmi.nnz == 0:
412
+ return EmbeddingModel(
413
+ token_to_id={token: index for index, token in enumerate(id_to_token)},
414
+ id_to_token=id_to_token,
415
+ embeddings=embeddings,
416
+ ppmi_matrix=[],
417
+ )
418
+ if embedding_width < min(ppmi.shape) and scipy_svds is not None:
419
+ left, values, _ = scipy_svds(ppmi.asfptype(), k=embedding_width, which="LM")
420
+ order = np.argsort(values)[::-1]
421
+ for component, source_index in enumerate(order):
422
+ scale = math.sqrt(max(float(values[source_index]), 0.0))
423
+ column = left[:, source_index]
424
+ for row, value in enumerate(column):
425
+ embeddings[row][component] = float(value) * scale
426
+ else:
427
+ dense = ppmi.toarray().tolist()
428
+ eigenpairs = top_k_eigenpairs_symmetric(dense, embedding_width)
429
+ for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
430
+ scale = math.sqrt(max(eigenvalue, 0.0))
431
+ for row in range(len(id_to_token)):
432
+ embeddings[row][component] = eigenvector[row] * scale
433
+ if np is not None:
434
+ embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
435
+ return EmbeddingModel(
436
+ token_to_id={token: index for index, token in enumerate(id_to_token)},
437
+ id_to_token=id_to_token,
438
+ embeddings=embeddings,
439
+ ppmi_matrix=[],
440
+ )
441
+
442
+ eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)
443
+
444
+ embeddings = zeros(len(id_to_token), embedding_dim)
445
+ for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
446
+ scale = math.sqrt(max(eigenvalue, 0.0))
447
+ for row in range(len(id_to_token)):
448
+ embeddings[row][component] = eigenvector[row] * scale
449
+ if np is not None:
450
+ embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
451
+
452
+ return EmbeddingModel(
453
+ token_to_id={token: index for index, token in enumerate(id_to_token)},
454
+ id_to_token=id_to_token,
455
+ embeddings=embeddings,
456
+ ppmi_matrix=ppmi,
457
+ )
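
Every fitting path ends in the same `EmbeddingModel`, so downstream code stays agnostic about whether eigendecomposition, truncated SVD, or the sketched projection produced the vectors. A minimal sketch on a toy text; the corpus string is an illustrative assumption:

```python
# Minimal sketch: fit PPMI embeddings on a toy text and read one vector back.
# The tiny corpus below is an illustrative assumption.
from reframr.embeddings import fit_ppmi_embedding

model = fit_ppmi_embedding(
    "harbor water mast cable harbor water quiet night",
    embedding_dim=4,
    window_size=2,
)
print(model.dimension)         # 4
print(model.vector("harbor"))  # one row of the factorized PPMI matrix
```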
reframr/evaluation.py ADDED
@@ -0,0 +1,265 @@
+ import json
+ from pathlib import Path
+
+ from .model import ReframrModel
+
+
+ def load_manifest(path: str | Path) -> dict[str, object]:
+     return json.loads(Path(path).read_text(encoding="utf-8"))
+
+
+ def _expected_next_token(model: ReframrModel, expected_text: str) -> str:
+     assert model.tokenizer is not None
+     encoded = model.tokenizer.encode(f" {expected_text}")
+     return encoded[0] if encoded else ""
+
+
+ def _normalize_text(text: str) -> str:
+     return " ".join(text.casefold().split())
+
+
+ def _word_ngrams(words: list[str], size: int) -> list[tuple[str, ...]]:
+     if size <= 0 or len(words) < size:
+         return []
+     return [tuple(words[index : index + size]) for index in range(len(words) - size + 1)]
+
+
+ def _distinct_ratio(words: list[str], size: int) -> float:
+     grams = _word_ngrams(words, size)
+     if not grams:
+         return 0.0
+     return len(set(grams)) / len(grams)
+
+
+ def _repetition_ratio(words: list[str], size: int) -> float:
+     grams = _word_ngrams(words, size)
+     if not grams:
+         return 0.0
+     repeated = len(grams) - len(set(grams))
+     return repeated / len(grams)
+
+
+ def _open_ended_score(
+     model: ReframrModel,
+     sample: dict[str, object],
+     *,
+     reasoning_mode: str | None,
+ ) -> dict[str, object]:
+     generated = model.generate_text(
+         str(sample["context"]),
+         max_tokens=int(sample.get("max_tokens", 56)),
+         reasoning_mode=reasoning_mode,
+     )
+     normalized = _normalize_text(generated)
+     required_groups = [
+         [str(term).casefold() for term in group]
+         for group in sample.get("required_groups", [])
+     ]
+     satisfied_groups = sum(
+         1
+         for group in required_groups
+         if any(term in normalized for term in group)
+     )
+     group_coverage = (
+         satisfied_groups / len(required_groups) if required_groups else 0.0
+     )
+     punctuation_hit = any(mark in generated for mark in ".,;:?!")
+     min_words = int(sample.get("min_words", 12))
+     min_word_hit = len(generated.split()) >= min_words
+     banned_phrases = [str(phrase) for phrase in sample.get("banned_phrases", [])]
+     exact_copy = any(normalized == _normalize_text(phrase) for phrase in banned_phrases)
+     novelty_hit = not exact_copy
+     require_punctuation = bool(sample.get("require_punctuation", True))
+
+     score_components = [
+         group_coverage,
+         1.0 if min_word_hit else 0.0,
+         1.0 if novelty_hit else 0.0,
+     ]
+     if require_punctuation:
+         score_components.append(1.0 if punctuation_hit else 0.0)
+
+     return {
+         "section": str(sample["section"]),
+         "context": str(sample["context"]),
+         "generated_text": generated,
+         "group_coverage": group_coverage,
+         "punctuation_hit": punctuation_hit,
+         "min_word_hit": min_word_hit,
+         "exact_copy": exact_copy,
+         "score": sum(score_components) / len(score_components) if score_components else 0.0,
+     }
+
+
+ def evaluate_manifest(
+     model: ReframrModel,
+     manifest: dict[str, object],
+     *,
+     reasoning_mode: str | None = None,
+     top_k: int = 5,
+ ) -> dict[str, object]:
+     results: dict[str, object] = {
+         "corpus_name": manifest["name"],
+         "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
+         "splits": {},
+     }
+
+     splits = manifest["splits"]
+     for split_name in ("memorization", "generalization"):
+         samples = splits[split_name]
+         top1_hits = 0
+         topk_hits = 0
+         expected_probabilities = []
+
+         for sample in samples:
+             distribution = model.predict_next_token_distribution(
+                 sample["context"],
+                 reasoning_mode=reasoning_mode,
+             )
+             ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
+             predicted = ranked[0][0] if ranked else ""
+             top_tokens = [token for token, _ in ranked[:top_k]]
+             expected = _expected_next_token(model, sample["expected"])
+             expected_probability = distribution.get(expected, 0.0)
+
+             if predicted == expected:
+                 top1_hits += 1
+             if expected in top_tokens:
+                 topk_hits += 1
+             expected_probabilities.append(expected_probability)
+
+         sample_count = len(samples)
+         mean_expected_probability = (
+             sum(expected_probabilities) / sample_count if sample_count else 0.0
+         )
+         results["splits"][split_name] = {
+             "sample_count": sample_count,
+             "top1_accuracy": top1_hits / sample_count if sample_count else 0.0,
+             "topk_accuracy": topk_hits / sample_count if sample_count else 0.0,
+             "mean_expected_probability": mean_expected_probability,
+         }
+
+     open_ended_samples = splits.get("open_ended", [])
+     if open_ended_samples:
+         sample_results = [
+             _open_ended_score(
+                 model,
+                 sample,
+                 reasoning_mode=reasoning_mode,
+             )
+             for sample in open_ended_samples
+         ]
+         sample_count = len(sample_results)
+         results["open_ended"] = {
+             "sample_count": sample_count,
+             "mean_score": (
+                 sum(float(sample["score"]) for sample in sample_results) / sample_count
+                 if sample_count
+                 else 0.0
+             ),
+             "mean_group_coverage": (
+                 sum(float(sample["group_coverage"]) for sample in sample_results) / sample_count
+                 if sample_count
+                 else 0.0
+             ),
+             "punctuation_rate": (
+                 sum(1 for sample in sample_results if bool(sample["punctuation_hit"])) / sample_count
+                 if sample_count
+                 else 0.0
+             ),
+             "min_word_rate": (
+                 sum(1 for sample in sample_results if bool(sample["min_word_hit"])) / sample_count
+                 if sample_count
+                 else 0.0
+             ),
+             "exact_copy_rate": (
+                 sum(1 for sample in sample_results if bool(sample["exact_copy"])) / sample_count
+                 if sample_count
+                 else 0.0
+             ),
+             "samples": sample_results,
+         }
+
+     return results
+
+
+ def benchmark_open_prompts(
+     model: ReframrModel,
+     prompts: list[dict[str, object]],
+     *,
+     reasoning_mode: str | None = None,
+     max_tokens: int = 64,
+     temperature: float = 0.82,
+     top_k: int = 24,
+     top_p: float = 0.92,
+     repetition_penalty: float = 1.18,
+ ) -> dict[str, object]:
+     samples: list[dict[str, object]] = []
+     for item in prompts:
+         prompt = str(item["prompt"])
+         generated = model.generate_text(
+             prompt,
+             max_tokens=max_tokens,
+             reasoning_mode=reasoning_mode,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+         )
+         words = generated.split()
+         samples.append(
+             {
+                 "prompt": prompt,
+                 "tags": [str(tag) for tag in item.get("tags", [])],
+                 "generated_text": generated,
+                 "word_count": len(words),
+                 "char_count": len(generated),
+                 "punctuation_hit": any(mark in generated for mark in ".,;:?!"),
+                 "distinct_2": _distinct_ratio(words, 2),
+                 "distinct_3": _distinct_ratio(words, 3),
+                 "repetition_3": _repetition_ratio(words, 3),
+             }
+         )
+
+     sample_count = len(samples)
+     return {
+         "sample_count": sample_count,
+         "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
+         "generation_policy": {
+             "temperature": temperature,
+             "top_k": top_k,
+             "top_p": top_p,
+             "repetition_penalty": repetition_penalty,
+         },
+         "mean_word_count": (
+             sum(int(sample["word_count"]) for sample in samples) / sample_count
+             if sample_count
+             else 0.0
+         ),
+         "mean_char_count": (
+             sum(int(sample["char_count"]) for sample in samples) / sample_count
+             if sample_count
+             else 0.0
+         ),
+         "punctuation_rate": (
+             sum(1 for sample in samples if bool(sample["punctuation_hit"])) / sample_count
+             if sample_count
+             else 0.0
+         ),
+         "mean_distinct_2": (
+             sum(float(sample["distinct_2"]) for sample in samples) / sample_count
+             if sample_count
+             else 0.0
+         ),
+         "mean_distinct_3": (
+             sum(float(sample["distinct_3"]) for sample in samples) / sample_count
+             if sample_count
+             else 0.0
+         ),
+         "mean_repetition_3": (
+             sum(float(sample["repetition_3"]) for sample in samples) / sample_count
+             if sample_count
+             else 0.0
+         ),
+         "samples": samples,
+     }
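
Both entry points return plain dictionaries, so results can be dumped straight to JSON. A minimal sketch of the open-prompt benchmark; `model` is assumed to be an already-constructed `ReframrModel` (this file does not show how checkpoints are loaded, so construction is left abstract), and the prompt is illustrative:

```python
# Minimal sketch: benchmark one open prompt and read the diversity metrics.
# `model` is an assumed, already-constructed ReframrModel instance.
from reframr.evaluation import benchmark_open_prompts

prompts = [{"prompt": "Describe a harbor at dusk in two sentences.", "tags": ["composition"]}]
report = benchmark_open_prompts(model, prompts, max_tokens=48)
print(report["mean_distinct_2"], report["mean_repetition_3"], report["punctuation_rate"])
```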
reframr/hf_import.py ADDED
@@ -0,0 +1,662 @@
+ import json
+ import re
+ import site
+ import sys
+ from itertools import chain
+ from pathlib import Path
+
+ from .text_quality import clean_answer_text, clean_context_text, clean_training_text
+
+ TEXT_FIELD_PREFERENCES = (
+     "text",
+     "content",
+     "body",
+     "article",
+     "document",
+     "passage",
+     "markdown",
+ )
+
+ DIALOGUE_FIELD_PREFERENCES = (
+     "messages",
+     "conversation",
+     "conversations",
+     "dialogue",
+     "dialog",
+     "turns",
+ )
+
+ PREFERENCE_FIELD_PAIRS = (
+     ("chosen", "rejected"),
+     ("response_j", "response_k"),
+     ("response_0", "response_1"),
+ )
+
+ INSTRUCTION_FIELD_PAIRS = (
+     ("instruction", "output"),
+     ("prompt", "completion"),
+     ("prompt", "response"),
+     ("question", "answer"),
+     ("question", "response"),
+     ("query", "response"),
+ )
+
+ TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE)
+ ROLE_ALIASES = {
+     "assistant": "assistant",
+     "bot": "assistant",
+     "gpt": "assistant",
+     "model": "assistant",
+     "assistant_response": "assistant",
+     "human": "user",
+     "user": "user",
+     "prompter": "user",
+     "customer": "user",
+     "system": "system",
+ }
+
+
+ def _word_count(text: str) -> int:
+     return len(text.split())
+
+
+ def _alpha_ratio(text: str) -> float:
+     if not text:
+         return 0.0
+     alpha_count = sum(character.isalpha() for character in text)
+     return alpha_count / len(text)
+
+
+ def _default_record_weight(record_type: str) -> int:
+     if record_type == "dialogue_turn":
+         return 2
+     if record_type == "instruction_answer":
+         return 2
+     if record_type == "preference_chosen":
+         return 3
+     if record_type == "preference_rejected":
+         return 0
+     return 1
+
+
+ def choose_text_field(columns: list[str]) -> str:
+     normalized = {column.casefold(): column for column in columns}
+     for preferred in TEXT_FIELD_PREFERENCES:
+         if preferred in normalized:
+             return normalized[preferred]
+     raise ValueError("Could not infer a text column. Pass --text-field explicitly.")
+
+
+ def choose_dialogue_field(columns: list[str]) -> str:
+     normalized = {column.casefold(): column for column in columns}
+     for preferred in DIALOGUE_FIELD_PREFERENCES:
+         if preferred in normalized:
+             return normalized[preferred]
+     raise ValueError("Could not infer a conversation column.")
+
+
+ def choose_preference_fields(columns: list[str]) -> tuple[str, str]:
+     normalized = {column.casefold(): column for column in columns}
+     for chosen_name, rejected_name in PREFERENCE_FIELD_PAIRS:
+         if chosen_name in normalized and rejected_name in normalized:
+             return normalized[chosen_name], normalized[rejected_name]
+     raise ValueError("Could not infer chosen/rejected preference columns.")
+
+
+ def choose_instruction_fields(columns: list[str]) -> tuple[str, str]:
+     normalized = {column.casefold(): column for column in columns}
+     for prompt_name, answer_name in INSTRUCTION_FIELD_PAIRS:
+         if prompt_name in normalized and answer_name in normalized:
+             return normalized[prompt_name], normalized[answer_name]
+     raise ValueError("Could not infer instruction/answer columns.")
+
+
+ def _row_identifier(row: dict[str, object]) -> str:
+     for candidate in ("id", "_id", "row_id", "uuid", "prompt_id"):
+         if candidate in row and str(row[candidate]).strip():
+             return str(row[candidate]).strip()
+     return ""
+
+
+ def _base_record(
+     *,
+     dataset: str,
+     config: str | None,
+     split: str,
+     row_id: str,
+ ) -> dict[str, str]:
+     return {
+         "source": "huggingface",
+         "dataset": dataset,
+         "config": config or "",
+         "split": split,
+         "row_id": row_id,
+     }
+
+
+ def _row_language(row: dict[str, object]) -> str:
+     for candidate in ("lang", "language", "locale"):
+         value = row.get(candidate)
+         if isinstance(value, str) and value.strip():
+             return value.strip()
+     return ""
+
+
+ def _normalize_role(raw_role: object) -> str:
+     role = str(raw_role or "").strip().casefold()
+     return ROLE_ALIASES.get(role, role)
+
+
+ def _message_content(message: dict[str, object]) -> str:
+     for field in ("content", "value", "text", "message"):
+         value = message.get(field)
+         if isinstance(value, str) and value.strip():
+             return clean_training_text(value)
+     return ""
+
+
+ def _message_role(message: dict[str, object]) -> str:
+     for field in ("role", "from", "speaker", "author"):
+         value = message.get(field)
+         if value is not None:
+             normalized = _normalize_role(value)
+             if normalized:
+                 return normalized
+     return ""
+
+
+ def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
+     if not isinstance(raw_messages, list):
+         return []
+
+     parsed: list[dict[str, str]] = []
+     for message in raw_messages:
+         if not isinstance(message, dict):
+             continue
+         role = _message_role(message)
+         content = _message_content(message)
+         if role not in {"system", "user", "assistant"} or not content:
+             continue
+         parsed.append({"role": role, "content": content})
+     return parsed
+
+
+ def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
+     if not isinstance(raw_text, str):
+         return []
+
+     text = raw_text.strip()
+     if not text:
+         return []
+
+     matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
+     if not matches:
+         return []
+
+     parsed: list[dict[str, str]] = []
+     for index, match in enumerate(matches):
+         role = _normalize_role(match.group(1))
+         start = match.end()
+         end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
+         content = clean_training_text(text[start:end].strip())
+         if role in {"system", "user", "assistant"} and content:
+             parsed.append({"role": role, "content": content})
+     return parsed
+
+
+ def _render_prompt(messages: list[dict[str, str]]) -> str:
+     lines = []
+     for message in messages:
+         content = clean_context_text(message["content"])
+         if content:
+             lines.append(content)
+     return "\n".join(lines).strip()
+
+
+ def _compose_training_text(context: str, answer: str) -> str:
+     context = clean_context_text(context)
+     answer = clean_answer_text(answer)
+     return f"<reason> {context} <answer> {answer}".strip()
+
+
+ def _compose_instruction_context(row: dict[str, object], prompt_field: str) -> str:
+     parts: list[str] = []
+     prompt = clean_context_text(str(row.get(prompt_field, "")).strip())
+     extra_input = clean_context_text(str(row.get("input", "")).strip())
+     if prompt:
+         parts.append(prompt)
+     if extra_input:
+         parts.append(extra_input)
+     return "\n".join(parts).strip()
+
+
+ def _extract_prompt_answer(
+     row: dict[str, object],
+     *,
+     field_name: str,
+ ) -> tuple[str, str]:
+     dialogue_messages = _parse_dialogue_messages(row.get(field_name))
+     if dialogue_messages and dialogue_messages[-1]["role"] == "assistant":
+         prompt = _render_prompt(dialogue_messages[:-1])
+         answer = dialogue_messages[-1]["content"]
+         if prompt and answer:
+             return prompt, answer
+
+     messages = _parse_transcript_messages(row.get(field_name))
+     if messages:
+         if messages[-1]["role"] == "assistant":
+             prompt = _render_prompt(messages[:-1])
+             answer = messages[-1]["content"]
+             if prompt and answer:
+                 return prompt, answer
+
+     prompt = clean_training_text(str(row.get("prompt", row.get("question", ""))).strip())
+     answer = clean_answer_text(str(row.get(field_name, "")).strip())
+     return prompt, answer
+
+
+ def _ordered_preference_fields(
+     row: dict[str, object],
+     *,
+     left_field: str,
+     right_field: str,
+ ) -> tuple[str, str]:
+     if {left_field, right_field} != {"response_0", "response_1"}:
+         return left_field, right_field
+
+     for selector in ("safer_response_id", "better_response_id"):
+         value = row.get(selector)
+         try:
+             preferred = int(value)
+         except (TypeError, ValueError):
+             continue
+         if preferred == 0:
+             return "response_0", "response_1"
+         if preferred == 1:
+             return "response_1", "response_0"
+     return left_field, right_field
+
+
+ def _passes_quality_gate(
+     record: dict[str, str],
+     *,
+     min_words: int,
+     max_words: int,
+     min_alpha_ratio: float,
+     allowed_languages: set[str],
+ ) -> bool:
+     candidate = str(record.get("answer") or record.get("text") or "").strip()
+     if not candidate:
+         return False
+
+     word_count = _word_count(candidate)
+     if min_words > 0 and word_count < min_words:
+         return False
+     if max_words > 0 and word_count > max_words:
+         return False
+
+     alpha_ratio = _alpha_ratio(candidate)
+     if min_alpha_ratio > 0.0 and alpha_ratio < min_alpha_ratio:
+         return False
+
+     if allowed_languages:
+         language = str(record.get("language", "")).strip().casefold()
+         if not language or language not in allowed_languages:
+             return False
+
+     record["quality_word_count"] = str(word_count)
+     record["quality_alpha_ratio"] = f"{alpha_ratio:.4f}"
+     return True
+
+
+ def to_json_record(
+     *,
+     dataset: str,
+     config: str | None,
+     split: str,
+     text_field: str,
+     row: dict[str, object],
+ ) -> dict[str, str]:
+     text = clean_training_text(str(row.get(text_field, "")).strip())
+     if not text:
+         raise ValueError("Row is missing usable text.")
+
+     record_type = "text"
+     return {
+         **_base_record(
+             dataset=dataset,
+             config=config,
+             split=split,
+             row_id=_row_identifier(row),
+         ),
+         "record_type": record_type,
+         "language": _row_language(row),
+         "text_field": text_field,
+         "text": text,
+         "word_count": _word_count(text),
+         "weight": _default_record_weight(record_type),
+     }
+
+
+ def dialogue_to_json_records(
+     *,
+     dataset: str,
+     config: str | None,
+     split: str,
+     conversation_field: str,
+     row: dict[str, object],
+ ) -> list[dict[str, str]]:
+     messages = _parse_dialogue_messages(row.get(conversation_field))
+     if not messages:
+         raise ValueError("Row does not contain usable dialogue turns.")
+
+     row_id = _row_identifier(row)
+     records: list[dict[str, str]] = []
+     history: list[dict[str, str]] = []
+     row_language = _row_language(row)
+     system_text = clean_training_text(str(row.get("system", "")).strip())
+     if system_text:
+         history.append({"role": "system", "content": system_text})
+     assistant_turn_index = 0
+     for message in messages:
+         if message["role"] != "assistant":
+             history.append(message)
+             continue
+         prompt = _render_prompt(history)
+         if not prompt:
+             continue
+         assistant_turn_index += 1
+         records.append(
+             {
+                 **_base_record(
+                     dataset=dataset,
+                     config=config,
+                     split=split,
+                     row_id=row_id,
+                 ),
+                 "record_type": "dialogue_turn",
+                 "language": row_language,
+                 "conversation_field": conversation_field,
+                 "turn_index": str(assistant_turn_index),
+                 "context": prompt,
+                 "answer": clean_answer_text(message["content"]),
+                 "text": _compose_training_text(prompt, message["content"]),
+                 "word_count": _word_count(clean_answer_text(message["content"])),
+                 "weight": _default_record_weight("dialogue_turn"),
+             }
+         )
+         history.append(message)
+
+     if not records:
+         raise ValueError("Dialogue row did not yield any assistant training turns.")
+     return records
+
+
+ def preference_to_json_records(
+     *,
+     dataset: str,
+     config: str | None,
+     split: str,
+     chosen_field: str,
+     rejected_field: str,
+     row: dict[str, object],
+     preference_target: str = "both",
+ ) -> list[dict[str, str]]:
+     row_id = _row_identifier(row)
+     pair_id = row_id or f"{chosen_field}:{rejected_field}"
+     records: list[dict[str, str]] = []
+     row_language = _row_language(row)
+     chosen_field, rejected_field = _ordered_preference_fields(
+         row,
+         left_field=chosen_field,
+         right_field=rejected_field,
+     )
+
+     field_specs = [
+         (chosen_field, "preference_chosen"),
+         (rejected_field, "preference_rejected"),
+     ]
+     if preference_target == "chosen":
+         field_specs = [(chosen_field, "preference_chosen")]
+     elif preference_target == "rejected":
+         field_specs = [(rejected_field, "preference_rejected")]
+     elif preference_target != "both":
+         raise ValueError("preference_target must be one of: both, chosen, rejected.")
+
+     for field_name, record_type in field_specs:
+         prompt, answer = _extract_prompt_answer(row, field_name=field_name)
+         if not prompt or not answer:
+             continue
+         records.append(
+             {
+                 **_base_record(
+                     dataset=dataset,
+                     config=config,
+                     split=split,
+                     row_id=row_id,
+                 ),
+                 "record_type": record_type,
+                 "language": row_language,
+                 "pair_id": pair_id,
+                 "text_field": field_name,
+                 "context": prompt,
+                 "answer": clean_answer_text(answer),
+                 "text": _compose_training_text(prompt, answer),
+                 "word_count": _word_count(clean_answer_text(answer)),
+                 "weight": _default_record_weight(record_type),
+             }
+         )
+
+     if not records:
+         raise ValueError("Preference row did not yield usable chosen/rejected transcripts.")
+     return records
+
+
+ def instruction_to_json_records(
+     *,
+     dataset: str,
+     config: str | None,
+     split: str,
+     prompt_field: str,
+     answer_field: str,
+     row: dict[str, object],
+ ) -> list[dict[str, str]]:
+     context = _compose_instruction_context(row, prompt_field)
+     answer = clean_answer_text(str(row.get(answer_field, "")).strip())
+     if not context or not answer:
+         raise ValueError("Instruction row did not contain usable prompt and answer text.")
+     record_type = "instruction_answer"
+     return [
+         {
+             **_base_record(
+                 dataset=dataset,
+                 config=config,
+                 split=split,
+                 row_id=_row_identifier(row),
+             ),
+             "record_type": record_type,
+             "language": _row_language(row),
+             "context": context,
+             "answer": answer,
+             "text": _compose_training_text(context, answer),
+             "word_count": _word_count(answer),
+             "weight": _default_record_weight(record_type),
+         }
+     ]
+
+
+ def _expand_row_records(
+     *,
+     dataset: str,
+     config: str | None,
+     split: str,
+     row: dict[str, object],
+     text_field: str | None,
+     preference_target: str,
+ ) -> list[dict[str, str]]:
+     if text_field is not None:
+         explicit_value = row.get(text_field)
+         if isinstance(explicit_value, list):
+             return dialogue_to_json_records(
+                 dataset=dataset,
+                 config=config,
+                 split=split,
+                 conversation_field=text_field,
+                 row=row,
+             )
+         return [
+             to_json_record(
+                 dataset=dataset,
+                 config=config,
+                 split=split,
+                 text_field=text_field,
+                 row=row,
+             )
+         ]
+
+     columns = list(row)
+     try:
+         chosen_field, rejected_field = choose_preference_fields(columns)
+         return preference_to_json_records(
+             dataset=dataset,
+             config=config,
+             split=split,
+             chosen_field=chosen_field,
+             rejected_field=rejected_field,
+             row=row,
+             preference_target=preference_target,
+         )
+     except ValueError:
+         pass
+
+     try:
+         prompt_field, answer_field = choose_instruction_fields(columns)
+         return instruction_to_json_records(
+             dataset=dataset,
+             config=config,
+             split=split,
+             prompt_field=prompt_field,
+             answer_field=answer_field,
+             row=row,
+         )
+     except ValueError:
+         pass
+
+     try:
+         conversation_field = choose_dialogue_field(columns)
+         if isinstance(row.get(conversation_field), list):
+             return dialogue_to_json_records(
+                 dataset=dataset,
+                 config=config,
+                 split=split,
+                 conversation_field=conversation_field,
+                 row=row,
+             )
+     except ValueError:
+         pass
+
+     inferred_text_field = choose_text_field(columns)
+     return [
+         to_json_record(
+             dataset=dataset,
+             config=config,
+             split=split,
+             text_field=inferred_text_field,
+             row=row,
+         )
+     ]
+
+
+ def import_hf_dataset(
+     *,
+     dataset: str,
+     output_path: str | Path,
+     config: str | None = None,
+     split: str = "train",
+     text_field: str | None = None,
+     limit: int = 1000,
+     streaming: bool = True,
+     preference_target: str = "chosen",
+     min_words: int = 0,
+     max_words: int = 0,
+     min_alpha_ratio: float = 0.0,
+     allowed_languages: tuple[str, ...] = (),
+ ) -> dict[str, object]:
+     try:
+         from datasets import load_dataset
+     except ModuleNotFoundError:
+         user_site = site.getusersitepackages()
+         if user_site and user_site not in sys.path:
+             sys.path.append(user_site)
+         from datasets import load_dataset
+
+     dataset_kwargs: dict[str, object] = {
+         "split": split,
+         "streaming": streaming,
+     }
+     if config:
+         dataset_kwargs["name"] = config
+
+     hf_dataset = load_dataset(dataset, **dataset_kwargs)
+     iterator = iter(hf_dataset)
+
+     first_row: dict[str, object] | None = None
+     if text_field is None:
+         first_row = dict(next(iterator))
+         iterator = chain([first_row], iterator)
+
+     output = Path(output_path)
+     output.parent.mkdir(parents=True, exist_ok=True)
+
+     written = 0
+     record_types: set[str] = set()
+     normalized_languages = {language.casefold() for language in allowed_languages if language.strip()}
+     with output.open("w", encoding="utf-8") as handle:
+         for row in iterator:
+             if written >= limit:
+                 break
+             normalized_row = dict(row)
+             try:
+                 records = _expand_row_records(
+                     dataset=dataset,
+                     config=config,
+                     split=split,
+                     row=normalized_row,
+                     text_field=text_field,
+                     preference_target=preference_target,
+                 )
+             except ValueError:
+                 continue
+
+             for record in records:
+                 if written >= limit:
+                     break
+                 if not _passes_quality_gate(
+                     record,
+                     min_words=min_words,
+                     max_words=max_words,
+                     min_alpha_ratio=min_alpha_ratio,
+                     allowed_languages=normalized_languages,
+                 ):
+                     continue
+                 record_types.add(record.get("record_type", "text"))
+                 handle.write(json.dumps(record, ensure_ascii=False) + "\n")
+                 written += 1
+
+     inferred_mode = "mixed" if len(record_types) > 1 else (next(iter(record_types)) if record_types else "unknown")
+     return {
+         "dataset": dataset,
+         "config": config or "",
+         "split": split,
+         "text_field": text_field or "",
+         "output_path": str(output.resolve()),
+         "records_written": written,
+         "record_types": sorted(record_types),
+         "mode": inferred_mode,
+         "preference_target": preference_target,
+         "streaming": streaming,
+         "min_words": min_words,
+         "max_words": max_words,
+         "min_alpha_ratio": min_alpha_ratio,
+         "allowed_languages": sorted(normalized_languages),
+     }
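
The importer streams rows, routes each through the preference, instruction, dialogue, and text detectors in that order, and gates results on word count, alpha ratio, and language before writing JSONL. A minimal sketch, assuming the `datasets` package is installed; the dataset name and output path are illustrative:

```python
# Minimal sketch: stream a capped, quality-gated slice of an HF dataset to JSONL.
# Dataset name and output path are illustrative assumptions.
from reframr.hf_import import import_hf_dataset

summary = import_hf_dataset(
    dataset="HuggingFaceH4/no_robots",
    output_path="data/no_robots.reframr.jsonl",
    split="train",
    limit=500,
    min_words=8,
    min_alpha_ratio=0.6,
)
print(summary["records_written"], summary["record_types"], summary["mode"])
```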
reframr/hippo.py ADDED
@@ -0,0 +1,145 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ import site
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ from .linalg import Matrix, Vector, identity, invert_matrix, matvec
8
+
9
+ _VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
10
+ for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
11
+ if _vendor_path.exists():
12
+ vendor_text = str(_vendor_path)
13
+ if vendor_text not in sys.path:
14
+ sys.path.insert(0, vendor_text)
15
+
16
+ try:
17
+ import numpy as np
18
+ except ModuleNotFoundError:
19
+ user_site = site.getusersitepackages()
20
+ if user_site and user_site not in sys.path:
21
+ sys.path.append(user_site)
22
+ try:
23
+ import numpy as np
24
+ except ModuleNotFoundError:
25
+ np = None
26
+
27
+
28
+ def hippo_legs_matrix(order: int) -> tuple[Matrix, Vector]:
29
+ a_matrix = [[0.0 for _ in range(order)] for _ in range(order)]
30
+ b_vector = [0.0 for _ in range(order)]
31
+
32
+ for row in range(order):
33
+ for col in range(order):
34
+ if row > col:
35
+ a_matrix[row][col] = -math.sqrt(2 * row + 1) * math.sqrt(2 * col + 1)
36
+ elif row == col:
37
+ a_matrix[row][col] = -(row + 1)
38
+ b_vector[row] = math.sqrt(2 * row + 1)
39
+
40
+ return a_matrix, b_vector
41
+
42
+
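
A tiny order-2 check of the construction above; the expected entries follow directly from the three branch conditions (a sketch, not part of the runtime):

```python
import math

a, b = hippo_legs_matrix(2)
# Row 0: diagonal -(0 + 1); row 1: lower triangle -sqrt(3) * sqrt(1), diagonal -(1 + 1).
assert a == [[-1.0, 0.0], [-math.sqrt(3.0), -2.0]]
assert b == [1.0, math.sqrt(3.0)]
```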
43
+ def analytical_embedding_drive(embedding: Vector, state_dim: int) -> Vector:
44
+ if not embedding:
45
+ return [0.0 for _ in range(state_dim)]
46
+ width = len(embedding)
47
+ return [
48
+ (
49
+ embedding[index % width]
50
+ + 0.5 * embedding[(3 * index + 1) % width]
51
+ - 0.25 * embedding[(5 * index + 2) % width]
52
+ )
53
+ for index in range(state_dim)
54
+ ]
55
+
56
+
57
+ def analytical_embedding_drive_fast(embedding: object, state_dim: int) -> object:
58
+ if np is None:
59
+ embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
60
+ return analytical_embedding_drive(embedding_vector, state_dim)
61
+ embedding_array = embedding if hasattr(embedding, "shape") else np.asarray(embedding, dtype=np.float64)
62
+ if embedding_array.size == 0:
63
+ return np.zeros(state_dim, dtype=np.float64)
64
+ indices = np.arange(state_dim, dtype=np.int64)
65
+ width = int(embedding_array.shape[0])
66
+ return (
67
+ embedding_array[indices % width]
68
+ + 0.5 * embedding_array[(3 * indices + 1) % width]
69
+ - 0.25 * embedding_array[(5 * indices + 2) % width]
70
+ )
71
+
72
+
73
+ @dataclass(slots=True)
74
+ class AnalyticalMemoryUnit:
75
+ state_dim: int
76
+ timescale: float
77
+
78
+ def __post_init__(self) -> None:
79
+ a_matrix, b_vector = hippo_legs_matrix(self.state_dim)
80
+ self.transition, self.input_projection = self._discretize_transition(
81
+ a_matrix,
82
+ b_vector,
83
+ self.timescale,
84
+ )
85
+
86
+ transition: Matrix = None # type: ignore[assignment]
87
+ input_projection: Vector = None # type: ignore[assignment]
88
+ transition_array: object | None = None # type: ignore[assignment]
89
+ input_projection_array: object | None = None # type: ignore[assignment]
90
+
91
+ @staticmethod
92
+ def _discretize_transition(
93
+ a_matrix: Matrix,
94
+ b_vector: Vector,
95
+ step: float,
96
+ ) -> tuple[Matrix, Vector]:
97
+ implicit_system = [
98
+ [
99
+ identity_value - step * a_value
100
+ for identity_value, a_value in zip(identity_row, a_row)
101
+ ]
102
+ for identity_row, a_row in zip(identity(len(a_matrix)), a_matrix)
103
+ ]
104
+ transition = invert_matrix(implicit_system)
105
+ input_projection = matvec(transition, [step * value for value in b_vector])
106
+ return transition, input_projection
107
+
108
+ def step(self, state: Vector, scalar_input: float) -> Vector:
109
+ if np is not None and self.transition_array is None:
110
+ self.transition_array = np.asarray(self.transition, dtype=np.float64)
111
+ self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
112
+ propagated = matvec(self.transition, state)
113
+ return [
114
+ propagated[index] + self.input_projection[index] * scalar_input
115
+ for index in range(self.state_dim)
116
+ ]
117
+
118
+ def step_vector(self, state: Vector, drive: Vector) -> Vector:
119
+ propagated = matvec(self.transition, state)
120
+ return [
121
+ propagated[index] + self.input_projection[index] * drive[index]
122
+ for index in range(self.state_dim)
123
+ ]
124
+
125
+ def step_fast(self, state: object, scalar_input: float) -> object:
126
+ if np is None:
127
+ state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
128
+ return self.step(state_vector, scalar_input)
129
+ if self.transition_array is None or self.input_projection_array is None:
130
+ self.transition_array = np.asarray(self.transition, dtype=np.float64)
131
+ self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
132
+ state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
133
+ return (self.transition_array @ state_array) + (self.input_projection_array * scalar_input)
134
+
135
+ def step_vector_fast(self, state: object, drive: object) -> object:
136
+ if np is None:
137
+ state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
138
+ drive_vector = drive.tolist() if hasattr(drive, "tolist") else list(drive)
139
+ return self.step_vector(state_vector, drive_vector)
140
+ if self.transition_array is None or self.input_projection_array is None:
141
+ self.transition_array = np.asarray(self.transition, dtype=np.float64)
142
+ self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
143
+ state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
144
+ drive_array = drive if hasattr(drive, "shape") else np.asarray(drive, dtype=np.float64)
145
+ return (self.transition_array @ state_array) + (self.input_projection_array * drive_array)
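
The `_discretize_transition` step above is a backward-Euler discretization: it solves (I - dt*A) x_next = x + dt*B*u, so `transition` is (I - dt*A)^-1 and `input_projection` is `transition @ (dt*B)`. A minimal usage sketch, assuming the `reframr/` package layout in this commit is importable:

```python
from reframr.hippo import AnalyticalMemoryUnit, analytical_embedding_drive

unit = AnalyticalMemoryUnit(state_dim=8, timescale=0.05)
state = [0.0] * 8
for scalar in (0.3, -0.1, 0.7):         # a short scalar input stream
    state = unit.step(state, scalar)    # x <- transition @ x + input_projection * u
drive = analytical_embedding_drive([0.2, -0.4, 0.9], state_dim=8)
state = unit.step_vector(state, drive)  # per-dimension drive variant
```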
reframr/linalg.py ADDED
@@ -0,0 +1,271 @@
1
+ import math
2
+ import site
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ _VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
7
+ for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
8
+ if _vendor_path.exists():
9
+ vendor_text = str(_vendor_path)
10
+ if vendor_text not in sys.path:
11
+ sys.path.insert(0, vendor_text)
12
+
13
+ try:
14
+ import numpy as np
15
+ except ModuleNotFoundError:
16
+ user_site = site.getusersitepackages()
17
+ if user_site and user_site not in sys.path:
18
+ sys.path.append(user_site)
19
+ try:
20
+ import numpy as np
21
+ except ModuleNotFoundError:
22
+ np = None
23
+
24
+ if np is not None and not hasattr(np, "asarray"):
25
+ np = None
26
+
27
+ Matrix = list[list[float]]
28
+ Vector = list[float]
29
+ SUMPROD = getattr(math, "sumprod", None)
30
+
31
+
32
+ def zeros(rows: int, cols: int) -> Matrix:
33
+ return [[0.0 for _ in range(cols)] for _ in range(rows)]
34
+
35
+
36
+ def zeros_vector(size: int) -> Vector:
37
+ return [0.0 for _ in range(size)]
38
+
39
+
40
+ def identity(size: int) -> Matrix:
41
+ matrix = zeros(size, size)
42
+ for index in range(size):
43
+ matrix[index][index] = 1.0
44
+ return matrix
45
+
46
+
47
+ def copy_matrix(matrix: Matrix) -> Matrix:
48
+ return [row[:] for row in matrix]
49
+
50
+
51
+ def transpose(matrix: Matrix) -> Matrix:
52
+ if not matrix:
53
+ return []
54
+ if np is not None:
55
+ return np.asarray(matrix, dtype=np.float64).T.tolist()
56
+ return [list(column) for column in zip(*matrix)]
57
+
58
+
59
+ def matvec(matrix: Matrix, vector: Vector) -> Vector:
60
+ if np is not None:
61
+ return (np.asarray(matrix, dtype=np.float64) @ np.asarray(vector, dtype=np.float64)).tolist()
62
+ if SUMPROD is not None:
63
+ return [SUMPROD(row, vector) for row in matrix]
64
+ return [sum(value * vector[idx] for idx, value in enumerate(row)) for row in matrix]
65
+
66
+
67
+ def matmul(left: Matrix, right: Matrix) -> Matrix:
68
+ if not left or not right:
69
+ return []
70
+ if np is not None:
71
+ return (np.asarray(left, dtype=np.float64) @ np.asarray(right, dtype=np.float64)).tolist()
72
+ right_t = transpose(right)
73
+ if SUMPROD is not None:
74
+ return [[SUMPROD(row, column) for column in right_t] for row in left]
75
+ return [
76
+ [sum(a * b for a, b in zip(row, column)) for column in right_t]
77
+ for row in left
78
+ ]
79
+
80
+
81
+ def add_matrices(left: Matrix, right: Matrix) -> Matrix:
82
+ return [
83
+ [left[row][col] + right[row][col] for col in range(len(left[row]))]
84
+ for row in range(len(left))
85
+ ]
86
+
87
+
88
+ def subtract_matrices(left: Matrix, right: Matrix) -> Matrix:
89
+ return [
90
+ [left[row][col] - right[row][col] for col in range(len(left[row]))]
91
+ for row in range(len(left))
92
+ ]
93
+
94
+
95
+ def scale_matrix(matrix: Matrix, scalar: float) -> Matrix:
96
+ return [[scalar * value for value in row] for row in matrix]
97
+
98
+
99
+ def dot(left: Vector, right: Vector) -> float:
100
+ if np is not None:
101
+ return float(np.dot(np.asarray(left, dtype=np.float64), np.asarray(right, dtype=np.float64)))
102
+ if SUMPROD is not None:
103
+ return SUMPROD(left, right)
104
+ return sum(a * b for a, b in zip(left, right))
105
+
106
+
107
+ def norm(vector: Vector) -> float:
108
+ return math.sqrt(dot(vector, vector))
109
+
110
+
111
+ def outer(left: Vector, right: Vector) -> Matrix:
112
+ if np is not None:
113
+ return np.outer(np.asarray(left, dtype=np.float64), np.asarray(right, dtype=np.float64)).tolist()
114
+ return [[a * b for b in right] for a in left]
115
+
116
+
117
+ def mean(values: Vector) -> float:
118
+ return sum(values) / len(values) if values else 0.0
119
+
120
+
121
+ def trace(matrix: Matrix) -> float:
122
+ return sum(matrix[index][index] for index in range(min(len(matrix), len(matrix[0]))))
123
+
124
+
125
+ def covariance_matrix(samples: list[Vector]) -> Matrix:
126
+ if not samples:
127
+ return []
128
+ if np is not None:
129
+ sample_array = np.asarray(samples, dtype=np.float64)
130
+ centered = sample_array - sample_array.mean(axis=0, keepdims=True)
131
+ denominator = max(len(samples) - 1, 1)
132
+ return ((centered.T @ centered) / denominator).tolist()
133
+
134
+ feature_count = len(samples[0])
135
+ sample_count = len(samples)
136
+ means = [
137
+ sum(sample[feature] for sample in samples) / sample_count
138
+ for feature in range(feature_count)
139
+ ]
140
+ covariance = zeros(feature_count, feature_count)
141
+ for sample in samples:
142
+ centered = [sample[index] - means[index] for index in range(feature_count)]
143
+ for row in range(feature_count):
144
+ for col in range(feature_count):
145
+ covariance[row][col] += centered[row] * centered[col]
146
+
147
+ denominator = max(sample_count - 1, 1)
148
+ return scale_matrix(covariance, 1.0 / denominator)
149
+
150
+
151
+ def solve_linear_system(matrix: Matrix, vector: Vector) -> Vector:
152
+ if np is not None:
153
+ return np.linalg.solve(
154
+ np.asarray(matrix, dtype=np.float64),
155
+ np.asarray(vector, dtype=np.float64),
156
+ ).tolist()
157
+ size = len(matrix)
158
+ augmented = [matrix[row][:] + [vector[row]] for row in range(size)]
159
+
160
+ for pivot_index in range(size):
161
+ pivot_row = max(
162
+ range(pivot_index, size),
163
+ key=lambda row_index: abs(augmented[row_index][pivot_index]),
164
+ )
165
+ augmented[pivot_index], augmented[pivot_row] = augmented[pivot_row], augmented[pivot_index]
166
+
167
+ pivot_value = augmented[pivot_index][pivot_index]
168
+ if abs(pivot_value) < 1e-12:
169
+ raise ValueError("Singular matrix encountered while solving linear system.")
170
+
171
+ inverse_pivot = 1.0 / pivot_value
172
+ augmented[pivot_index] = [value * inverse_pivot for value in augmented[pivot_index]]
173
+
174
+ for row_index in range(size):
175
+ if row_index == pivot_index:
176
+ continue
177
+ factor = augmented[row_index][pivot_index]
178
+ augmented[row_index] = [
179
+ augmented[row_index][col] - factor * augmented[pivot_index][col]
180
+ for col in range(size + 1)
181
+ ]
182
+
183
+ return [augmented[row][-1] for row in range(size)]
184
+
185
+
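
A quick sanity check of the pure-Python fallback (Gauss-Jordan elimination with partial pivoting); the system is small enough to verify by hand:

```python
# 2x + y = 3 and x + 3y = 5 have the unique solution x = 0.8, y = 1.4.
solution = solve_linear_system([[2.0, 1.0], [1.0, 3.0]], [3.0, 5.0])
print(solution)   # [0.8, 1.4] (up to floating-point rounding)
```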
186
+ def invert_matrix(matrix: Matrix) -> Matrix:
187
+ if np is not None:
188
+ return np.linalg.inv(np.asarray(matrix, dtype=np.float64)).tolist()
189
+ size = len(matrix)
190
+ inverse_columns = []
191
+ for basis_index in range(size):
192
+ basis_vector = [0.0 for _ in range(size)]
193
+ basis_vector[basis_index] = 1.0
194
+ inverse_columns.append(solve_linear_system(matrix, basis_vector))
195
+ return transpose(inverse_columns)
196
+
197
+
198
+ def dominant_eigenpair_symmetric(
199
+ matrix: Matrix,
200
+ max_iterations: int = 64,
201
+ tolerance: float = 1e-10,
202
+ ) -> tuple[float, Vector]:
203
+ size = len(matrix)
204
+ if size == 0:
205
+ return 0.0, []
206
+ if np is not None:
207
+ values, vectors = np.linalg.eigh(np.asarray(matrix, dtype=np.float64))
208
+ index = int(np.argmax(values))
209
+ eigenvalue = float(values[index])
210
+ if eigenvalue <= tolerance:
211
+ return 0.0, zeros_vector(size)
212
+ return eigenvalue, vectors[:, index].astype(float).tolist()
213
+
214
+ vector = [1.0 / math.sqrt(size) for _ in range(size)]
215
+ for _ in range(max_iterations):
216
+ next_vector = matvec(matrix, vector)
217
+ next_norm = norm(next_vector)
218
+ if next_norm < tolerance:
219
+ return 0.0, zeros_vector(size)
220
+
221
+ next_vector = [value / next_norm for value in next_vector]
222
+ delta = max(abs(a - b) for a, b in zip(vector, next_vector))
223
+ vector = next_vector
224
+ if delta < tolerance:
225
+ break
226
+
227
+ eigenvalue = dot(vector, matvec(matrix, vector))
228
+ return eigenvalue, vector
229
+
230
+
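
Power iteration converges to the eigenvector of the largest eigenvalue for symmetric matrices; a small check (the matrix [[2, 1], [1, 2]] has eigenvalues 3 and 1):

```python
value, vector = dominant_eigenpair_symmetric([[2.0, 1.0], [1.0, 2.0]])
print(round(value, 6))                 # 3.0
print([round(v, 3) for v in vector])   # [0.707, 0.707], up to sign
```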
231
+ def top_k_eigenpairs_symmetric(matrix: Matrix, k: int) -> list[tuple[float, Vector]]:
232
+ if np is not None and matrix:
233
+ values, vectors = np.linalg.eigh(np.asarray(matrix, dtype=np.float64))
234
+ ranked = sorted(
235
+ (
236
+ (float(values[index]), vectors[:, index].astype(float).tolist())
237
+ for index in range(len(values))
238
+ if float(values[index]) > 1e-9
239
+ ),
240
+ key=lambda item: item[0],
241
+ reverse=True,
242
+ )
243
+ return ranked[: min(k, len(ranked))]
244
+ working = copy_matrix(matrix)
245
+ eigenpairs: list[tuple[float, Vector]] = []
246
+ for _ in range(min(k, len(working))):
247
+ eigenvalue, eigenvector = dominant_eigenpair_symmetric(working)
248
+ if eigenvalue <= 1e-9 or not eigenvector:
249
+ break
250
+ eigenpairs.append((eigenvalue, eigenvector))
251
+ deflation = scale_matrix(outer(eigenvector, eigenvector), eigenvalue)
252
+ working = subtract_matrices(working, deflation)
253
+ return eigenpairs
254
+
255
+
256
+ def softmax(logits: Vector) -> Vector:
257
+ if not logits:
258
+ return []
259
+ if np is not None:
260
+ values = np.asarray(logits, dtype=np.float64)
261
+ shifted = np.exp(values - values.max())
262
+ total = float(shifted.sum())
263
+ if total == 0.0:
264
+ return [1.0 / len(logits) for _ in logits]
265
+ return (shifted / total).tolist()
266
+ max_logit = max(logits)
267
+ shifted = [math.exp(logit - max_logit) for logit in logits]
268
+ total = sum(shifted)
269
+ if total == 0.0:
270
+ return [1.0 / len(logits) for _ in logits]
271
+ return [value / total for value in shifted]
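
The max-shift in both branches above keeps `exp` within floating-point range for large logits without changing the result; a quick check:

```python
print(softmax([0.0, 1.0]))         # ≈ [0.2689, 0.7311]
print(softmax([1000.0, 1001.0]))   # same distribution after the max shift
```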
reframr/model.py ADDED
The diff for this file is too large to render. See raw diff
 
reframr/reasoning.py ADDED
@@ -0,0 +1,26 @@
1
+ TOKENIZER_NAME = "FrameToken"
2
+
3
+ REASONING_CONTROL_TOKENS: tuple[str, ...] = (
4
+ "<reason>",
5
+ "<plan>",
6
+ "<reflect>",
7
+ "<answer>",
8
+ "<memory>",
9
+ "<retrieve>",
10
+ "<focus>",
11
+ "<verify>",
12
+ "<tool>",
13
+ )
14
+
15
+ REASONING_PROFILES: dict[str, tuple[str, ...]] = {
16
+ "none": (),
17
+ "deep": ("<reason>",),
18
+ "memory": ("<memory>", "<retrieve>", "<focus>"),
19
+ "tool": ("<tool>", "<reason>", "<verify>"),
20
+ }
21
+
22
+
23
+ def reasoning_prefix(mode: str) -> list[str]:
24
+ if mode not in REASONING_PROFILES:
25
+ raise ValueError(f"Unknown reasoning mode: {mode}")
26
+ return list(REASONING_PROFILES[mode])
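
Usage sketch for the reasoning profiles above; unknown modes fail loudly rather than degrading silently:

```python
from reframr.reasoning import reasoning_prefix

print(reasoning_prefix("tool"))   # ['<tool>', '<reason>', '<verify>']
print(reasoning_prefix("none"))   # []
reasoning_prefix("creative")      # raises ValueError: Unknown reasoning mode: creative
```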
reframr/reservoir.py ADDED
@@ -0,0 +1,94 @@
1
+ from .linalg import Matrix, Vector, identity, invert_matrix, matmul, matvec, np, scale_matrix, transpose
2
+
3
+
4
+ def _empty_matrix(matrix: Matrix) -> bool:
5
+ if np is not None and hasattr(matrix, "size"):
6
+ return int(matrix.size) == 0
7
+ return not matrix
8
+
9
+
10
+ def ridge_regression_readout(
11
+ states: list[Vector],
12
+ targets: list[Vector],
13
+ *,
14
+ regularization: float,
15
+ ) -> Matrix:
16
+ if not states or not targets:
17
+ raise ValueError("States and targets must be non-empty for ridge readout.")
18
+ if np is not None:
19
+ state_matrix = np.asarray(states, dtype=np.float64).T
20
+ target_matrix = np.asarray(targets, dtype=np.float64).T
21
+ gram = state_matrix @ state_matrix.T
22
+ regularized = gram + (regularization * np.eye(gram.shape[0], dtype=np.float64))
23
+ cross_covariance = target_matrix @ state_matrix.T
24
+ return np.linalg.solve(regularized.T, cross_covariance.T).T.tolist()
25
+
26
+ state_matrix = transpose(states)
27
+ target_matrix = transpose(targets)
28
+ gram = matmul(state_matrix, transpose(state_matrix))
29
+ regularized = [
30
+ [
31
+ gram[row][col] + (regularization if row == col else 0.0)
32
+ for col in range(len(gram[row]))
33
+ ]
34
+ for row in range(len(gram))
35
+ ]
36
+ inverse = invert_matrix(regularized)
37
+ cross_covariance = matmul(target_matrix, transpose(state_matrix))
38
+ return matmul(cross_covariance, inverse)
39
+
40
+
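
The readout above is the closed-form ridge solution W = (Y X^T)(X X^T + lambda*I)^-1 over column-stacked states. A one-dimensional check that is easy to verify by hand:

```python
# With states [[1.0], [2.0]] and targets [[2.0], [4.0]]:
# X X^T = [[5]], Y X^T = [[10]], so W = 10 / (5 + lambda); for lambda = 0 that is 2.0.
weights = ridge_regression_readout([[1.0], [2.0]], [[2.0], [4.0]], regularization=0.0)
print(weights)   # [[2.0]] (up to floating-point rounding)
```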
41
+ def ridge_regression_readout_from_moments(
42
+ gram: Matrix,
43
+ cross_covariance: Matrix,
44
+ *,
45
+ regularization: float,
46
+ ) -> Matrix:
47
+ if _empty_matrix(gram) or _empty_matrix(cross_covariance):
48
+ raise ValueError("Gram and cross-covariance moments must be non-empty for ridge readout.")
49
+ if np is not None:
50
+ gram_array = np.asarray(gram, dtype=np.float64)
51
+ regularized = gram_array + (regularization * np.eye(gram_array.shape[0], dtype=np.float64))
52
+ cross_covariance_array = np.asarray(cross_covariance, dtype=np.float64)
53
+ return np.linalg.solve(regularized.T, cross_covariance_array.T).T
54
+
55
+ regularized = [
56
+ [
57
+ gram[row][col] + (regularization if row == col else 0.0)
58
+ for col in range(len(gram[row]))
59
+ ]
60
+ for row in range(len(gram))
61
+ ]
62
+ inverse = invert_matrix(regularized)
63
+ return matmul(cross_covariance, inverse)
64
+
65
+
66
+ def ridge_regression_readout_from_diagonal_moments(
67
+ feature_second_moment: Vector,
68
+ cross_covariance: Matrix,
69
+ *,
70
+ regularization: float,
71
+ ) -> Matrix:
72
+ if _empty_matrix(feature_second_moment) or _empty_matrix(cross_covariance):
73
+ raise ValueError("Diagonal moments and cross-covariance must be non-empty for ridge readout.")
74
+ if np is not None:
75
+ denominator = np.asarray(feature_second_moment, dtype=np.float64) + regularization
76
+ denominator = np.where(np.abs(denominator) > 1e-12, denominator, regularization)
77
+ cross_covariance_array = np.asarray(cross_covariance, dtype=np.float64)
78
+ return cross_covariance_array / denominator[None, :]
79
+
80
+ denominator = [
81
+ value + regularization if abs(value + regularization) > 1e-12 else regularization
82
+ for value in feature_second_moment
83
+ ]
84
+ return [
85
+ [
86
+ value / denominator[col]
87
+ for col, value in enumerate(row)
88
+ ]
89
+ for row in cross_covariance
90
+ ]
91
+
92
+
93
+ def apply_readout(weights: Matrix, state: Vector) -> Vector:
94
+ return matvec(weights, state)
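
The diagonal-moment variant above collapses the ridge solve to an elementwise division by per-feature second moments; a tiny check:

```python
weights = ridge_regression_readout_from_diagonal_moments(
    [4.0, 1.0],     # per-feature second moments
    [[8.0, 3.0]],   # cross-covariance row
    regularization=0.0,
)
print(weights)   # ≈ [[2.0, 3.0]]: each column is cross / (moment + lambda)
```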
reframr/streaming.py ADDED
@@ -0,0 +1,1852 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import random
5
+ import re
6
+ import site
7
+ import sys
8
+ import time
9
+ from collections import Counter
10
+ from collections.abc import Iterable, Iterator
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+
14
+ from .config import ReframrConfig
15
+ from .corpus import build_vocabulary_from_counts
16
+ from .embeddings import fit_ppmi_embedding_from_cooccurrence, fit_randomized_ppmi_embedding_from_counts
17
+ from .hippo import AnalyticalMemoryUnit
18
+ from .linalg import Matrix, Vector, mean, norm, zeros, zeros_vector
19
+ from .model import ReframrModel, RUNTIME_ARRAY_DTYPE, TRANSITION_ORDERS, np
20
+ from .reservoir import (
21
+ ridge_regression_readout_from_diagonal_moments,
22
+ ridge_regression_readout_from_moments,
23
+ )
24
+ from .ternary import apply_ternary_mask, derive_ternary_mask_from_feature_energy
25
+ from .text_quality import clean_answer_text, clean_context_text, clean_training_text
26
+ from .tokenizer import NativeTokenizer
27
+
28
+ try:
29
+ from scipy import sparse as scipy_sparse
30
+ except (ImportError, ModuleNotFoundError, OSError):
31
+ scipy_sparse = None
32
+
33
+ TEXT_FIELD_PREFERENCES = (
34
+ "text",
35
+ "content",
36
+ "body",
37
+ "article",
38
+ "document",
39
+ "passage",
40
+ "markdown",
41
+ "answer",
42
+ "response",
43
+ )
44
+
45
+ DIALOGUE_FIELD_PREFERENCES = (
46
+ "messages",
47
+ "conversation",
48
+ "conversations",
49
+ "dialogue",
50
+ "dialog",
51
+ "turns",
52
+ "chosen",
53
+ )
54
+ INSTRUCTION_FIELD_PAIRS = (
55
+ ("instruction", "output"),
56
+ ("prompt", "completion"),
57
+ ("prompt", "response"),
58
+ ("question", "answer"),
59
+ ("question", "response"),
60
+ ("query", "answer"),
61
+ ("query", "response"),
62
+ )
63
+ TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE)
64
+ ROLE_ALIASES = {
65
+ "assistant": "assistant",
66
+ "assistant_response": "assistant",
67
+ "bot": "assistant",
68
+ "gpt": "assistant",
69
+ "model": "assistant",
70
+ "human": "user",
71
+ "prompter": "user",
72
+ "user": "user",
73
+ "customer": "user",
74
+ "system": "system",
75
+ }
76
+ ANSWER_READOUT_WEIGHT = 1.0
77
+ CONTEXT_READOUT_WEIGHT = 0.0
78
+ CONTEXT_STAT_WEIGHT = 0.02
79
+ PLAIN_TEXT_READOUT_WEIGHT = 0.03
80
+ PREFERENCE_REJECTED_TOKENIZER_WEIGHT = 0.0
81
+ PREFERENCE_BIAS_SCALE = 0.95
82
+ MAX_PREFERENCE_STATE_PAIRS = 512
83
+ ANSWER_START_TOKEN_WINDOW = 12
84
+ ANSWER_START_DECAY = 0.86
85
+ MAX_ANSWER_SEQUENCE_EXAMPLES = 196608
86
+ MAX_ANSWER_SEQUENCE_TOKENS = 192
87
+ HF_STREAM_MAX_RETRIES = 5
88
+ HF_STREAM_RETRY_BASE_DELAY_SECONDS = 0.25
89
+ FULL_READOUT_FEATURE_LIMIT = 2304
90
+ FULL_READOUT_EXAMPLE_LIMIT = 25000
91
+
92
+
93
+ @dataclass(slots=True)
94
+ class CorpusPlanEntry:
95
+ source: str
96
+ name: str
97
+ dataset: str = ""
98
+ path: str = ""
99
+ config: str | None = None
100
+ split: str = "train"
101
+ limit: int = 0
102
+ weight: float = 1.0
103
+ text_field: str | None = None
104
+ min_words: int = 0
105
+ max_words: int = 0
106
+ min_alpha_ratio: float = 0.0
107
+ allowed_languages: tuple[str, ...] = ()
108
+ records: tuple[object, ...] = ()
109
+ streaming: bool = True
110
+ trust_remote_code: bool = False
111
+
112
+
113
+ @dataclass(slots=True)
114
+ class StreamDocument:
115
+ text: str
116
+ weight: float
117
+ source: str
118
+ language: str = ""
119
+ preference_rejected_text: str = ""
120
+
121
+
122
+ class StreamingCooccurrenceAccumulator:
123
+ def __init__(self, token_to_id: dict[str, int], window_size: int) -> None:
124
+ self.token_to_id = token_to_id
125
+ self.window_size = window_size
126
+ self.rows: dict[int, dict[int, float]] = {}
127
+
128
+ def update_tokens(self, tokens: list[str], *, weight: float) -> None:
129
+ token_ids = [self.token_to_id[token] for token in tokens if token in self.token_to_id]
130
+ for index, token_id in enumerate(token_ids):
131
+ for offset in range(1, self.window_size + 1):
132
+ other_index = index + offset
133
+ if other_index >= len(token_ids):
134
+ break
135
+ other_id = token_ids[other_index]
136
+ delta = weight * (1.0 / offset)
137
+ self.rows.setdefault(token_id, {})[other_id] = (
138
+ self.rows.setdefault(token_id, {}).get(other_id, 0.0) + delta
139
+ )
140
+ self.rows.setdefault(other_id, {})[token_id] = (
141
+ self.rows.setdefault(other_id, {}).get(token_id, 0.0) + delta
142
+ )
143
+
144
+ def to_dense(self) -> Matrix:
145
+ size = len(self.token_to_id)
146
+ matrix = zeros(size, size)
147
+ for row, columns in self.rows.items():
148
+ for col, value in columns.items():
149
+ matrix[row][col] = value
150
+ return matrix
151
+
152
+ def to_sparse(self) -> object:
153
+ if scipy_sparse is None or np is None:
154
+ return self.to_dense()
155
+ rows: list[int] = []
156
+ cols: list[int] = []
157
+ data: list[float] = []
158
+ for row, columns in self.rows.items():
159
+ for col, value in columns.items():
160
+ rows.append(row)
161
+ cols.append(col)
162
+ data.append(value)
163
+ size = len(self.token_to_id)
164
+ return scipy_sparse.coo_matrix(
165
+ (
166
+ np.asarray(data, dtype=np.float64),
167
+ (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)),
168
+ ),
169
+ shape=(size, size),
170
+ dtype=np.float64,
171
+ ).tocsr()
172
+
173
+
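
The accumulator above adds symmetric, inverse-distance-weighted counts within the window; a small check:

```python
vocab = {"the": 0, "sky": 1, "blue": 2}
acc = StreamingCooccurrenceAccumulator(vocab, window_size=2)
acc.update_tokens(["the", "sky", "blue"], weight=1.0)
dense = acc.to_dense()
# Adjacent pairs contribute 1.0 in both directions; offset-2 pairs contribute 0.5.
print(dense[1][2], dense[2][1], dense[0][2])   # 1.0 1.0 0.5
```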
174
+ class TransitionAccumulator:
175
+ def __init__(
176
+ self,
177
+ *,
178
+ max_contexts_per_order: int | None = None,
179
+ max_next_tokens: int = 0,
180
+ ) -> None:
181
+ self.max_contexts_per_order = max_contexts_per_order
182
+ self.max_next_tokens = max_next_tokens
183
+ self.context_soft_limit = (
184
+ max_contexts_per_order * 4
185
+ if max_contexts_per_order is not None and max_contexts_per_order > 0
186
+ else None
187
+ )
188
+ self.next_token_soft_limit = max_next_tokens * 4 if max_next_tokens > 0 else None
189
+ self.counts: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
190
+ order: {} for order in sorted(TRANSITION_ORDERS)
191
+ }
192
+
193
+ def update_tokens(self, tokens: list[str], *, weight: float) -> None:
194
+ for order in sorted(TRANSITION_ORDERS):
195
+ order_counts = self.counts[order]
196
+ for index in range(order - 1, len(tokens) - 1):
197
+ key = tuple(tokens[index - order + 1 : index + 1])
198
+ nxt = tokens[index + 1]
199
+ if (
200
+ self.context_soft_limit is not None
201
+ and key not in order_counts
202
+ and len(order_counts) >= self.context_soft_limit
203
+ ):
204
+ continue
205
+ bucket = order_counts.setdefault(key, {})
206
+ if (
207
+ self.next_token_soft_limit is not None
208
+ and nxt not in bucket
209
+ and len(bucket) >= self.next_token_soft_limit
210
+ ):
211
+ continue
212
+ bucket[nxt] = bucket.get(nxt, 0.0) + weight
213
+
214
+ def finalize(
215
+ self,
216
+ *,
217
+ max_contexts_per_order: int | None,
218
+ max_next_tokens: int,
219
+ ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
220
+ probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
221
+ order: {} for order in sorted(TRANSITION_ORDERS)
222
+ }
223
+ for order, mapping in self.counts.items():
224
+ items = list(mapping.items())
225
+ items.sort(key=lambda item: (-sum(item[1].values()), item[0]))
226
+ if max_contexts_per_order is not None and max_contexts_per_order >= 0:
227
+ items = items[:max_contexts_per_order]
228
+ for key, bucket in items:
229
+ next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0]))
230
+ if max_next_tokens > 0:
231
+ next_items = next_items[:max_next_tokens]
232
+ total = sum(value for _, value in next_items)
233
+ if total <= 0.0:
234
+ continue
235
+ probabilities[order][key] = {
236
+ token: value / total
237
+ for token, value in next_items
238
+ }
239
+ return probabilities
240
+
241
+
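
`finalize` turns the raw weighted counts into conditional next-token distributions, keeping the most frequent contexts and continuations first. A sketch, assuming `TRANSITION_ORDERS` includes order 1:

```python
acc = TransitionAccumulator(max_contexts_per_order=1000, max_next_tokens=8)
acc.update_tokens(["a", "b", "a", "c"], weight=1.0)
tables = acc.finalize(max_contexts_per_order=1000, max_next_tokens=8)
# Context ("a",) was followed once by "b" and once by "c":
print(tables[1][("a",)])   # {"b": 0.5, "c": 0.5}
```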
242
+ class StateReservoir:
243
+ def __init__(self, capacity: int | None, *, seed: int = 13) -> None:
244
+ self.capacity = capacity
245
+ self.random = random.Random(seed)
246
+ self.states: list[Vector] = []
247
+ self.labels: list[int] = []
248
+ self.weights: list[float] = []
249
+ self.seen = 0
250
+ self.total_weight = 0.0
251
+
252
+ def reserve_slot(self, weight: float = 1.0) -> int | None:
253
+ if weight <= 0.0:
254
+ return None
255
+ self.seen += 1
256
+ self.total_weight += weight
257
+ if self.capacity is None:
258
+ return len(self.states)
259
+ if self.capacity <= 0:
260
+ return None
261
+ if len(self.states) < self.capacity:
262
+ return len(self.states)
263
+ keep_probability = min(1.0, (self.capacity * weight) / max(self.total_weight, 1e-12))
264
+ if self.random.random() >= keep_probability:
265
+ return None
266
+ return self.random.randrange(self.capacity)
267
+
268
+ def store_reserved(
269
+ self,
270
+ slot: int,
271
+ state: Vector,
272
+ label_id: int,
273
+ *,
274
+ example_weight: float = 1.0,
275
+ ) -> None:
276
+ stored_state = state.copy() if hasattr(state, "copy") else state[:]
277
+ if slot == len(self.states):
278
+ self.states.append(stored_state)
279
+ self.labels.append(label_id)
280
+ self.weights.append(example_weight)
281
+ elif 0 <= slot < len(self.states):
282
+ self.states[slot] = stored_state
283
+ self.labels[slot] = label_id
284
+ self.weights[slot] = example_weight
285
+
286
+ def consider(self, state: Vector, label_id: int, weight: float = 1.0) -> None:
287
+ slot = self.reserve_slot(weight=weight)
288
+ if slot is not None:
289
+ self.store_reserved(slot, state, label_id, example_weight=weight)
290
+
291
+
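
`StateReservoir` implements weighted reservoir sampling: once full, a new item displaces a random slot with probability `capacity * weight / total_weight`, so heavier examples persist longer. A sketch:

```python
reservoir = StateReservoir(capacity=2, seed=7)
for label, weight in enumerate([1.0, 1.0, 5.0, 1.0]):
    reservoir.consider([float(label)], label, weight)
print(len(reservoir.states), reservoir.seen)   # 2 4
```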
292
+ class SequenceReservoir:
293
+ def __init__(self, capacity: int | None, *, seed: int = 41) -> None:
294
+ self.capacity = capacity
295
+ self.random = random.Random(seed)
296
+ self.keys: list[Vector] = []
297
+ self.prompt_rows: list[list[int]] = []
298
+ self.token_rows: list[list[int]] = []
299
+ self.weights: list[float] = []
300
+ self.seen_weight = 0.0
301
+
302
+ def reserve_slot(self, *, weight: float = 1.0) -> int | None:
303
+ if self.capacity == 0 or weight <= 0.0:
304
+ return None
305
+ self.seen_weight += weight
306
+ if self.capacity is None or len(self.keys) < self.capacity:
307
+ return len(self.keys)
308
+ probability = min(1.0, (self.capacity * weight) / max(self.seen_weight, 1e-12))
309
+ if self.random.random() >= probability:
310
+ return None
311
+ return self.random.randrange(self.capacity)
312
+
313
+ def store_reserved(
314
+ self,
315
+ slot: int,
316
+ key: Vector,
317
+ prompt_token_ids: list[int],
318
+ token_ids: list[int],
319
+ *,
320
+ example_weight: float = 1.0,
321
+ ) -> None:
322
+ key_copy = key.tolist() if hasattr(key, "tolist") else list(key)
323
+ prompt_row = prompt_token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
324
+ row = token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
325
+ if self.capacity is None or slot >= len(self.keys):
326
+ self.keys.append(key_copy)
327
+ self.prompt_rows.append(prompt_row)
328
+ self.token_rows.append(row)
329
+ self.weights.append(example_weight)
330
+ return
331
+ self.keys[slot] = key_copy
332
+ self.prompt_rows[slot] = prompt_row
333
+ self.token_rows[slot] = row
334
+ self.weights[slot] = example_weight
335
+
336
+ def consider(
337
+ self,
338
+ key: Vector,
339
+ prompt_token_ids: list[int],
340
+ token_ids: list[int],
341
+ weight: float = 1.0,
342
+ ) -> None:
343
+ if not token_ids:
344
+ return
345
+ slot = self.reserve_slot(weight=weight)
346
+ if slot is not None:
347
+ self.store_reserved(slot, key, prompt_token_ids, token_ids, example_weight=weight)
348
+
349
+
350
+ def _word_count(text: str) -> int:
351
+ return len(text.split())
352
+
353
+
354
+ def _alpha_ratio(text: str) -> float:
355
+ if not text:
356
+ return 0.0
357
+ alpha_count = sum(character.isalpha() for character in text)
358
+ return alpha_count / len(text)
359
+
360
+
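
These two helpers drive the per-entry quality gate below; their arithmetic on a short string:

```python
print(_word_count("Hello, world!"))              # 2
print(round(_alpha_ratio("Hello, world!"), 3))   # 0.769 (10 letters / 13 characters)
```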
361
+ def _row_language(row: dict[str, object]) -> str:
362
+ for candidate in ("lang", "language", "locale"):
363
+ value = row.get(candidate)
364
+ if isinstance(value, str) and value.strip():
365
+ return value.strip()
366
+ return ""
367
+
368
+
369
+ def _normalize_role(raw_role: object) -> str:
370
+ role = str(raw_role or "").strip().casefold()
371
+ return ROLE_ALIASES.get(role, role)
372
+
373
+
374
+ def _message_content(message: dict[str, object]) -> str:
375
+ for field in ("content", "value", "text", "message"):
376
+ value = message.get(field)
377
+ if isinstance(value, str) and value.strip():
378
+ return clean_training_text(value)
379
+ return ""
380
+
381
+
382
+ def _message_role(message: dict[str, object]) -> str:
383
+ for field in ("role", "from", "speaker", "author"):
384
+ value = message.get(field)
385
+ if value is not None:
386
+ normalized = _normalize_role(value)
387
+ if normalized:
388
+ return normalized
389
+ return ""
390
+
391
+
392
+ def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
393
+ if not isinstance(raw_messages, list):
394
+ return []
395
+
396
+ parsed: list[dict[str, str]] = []
397
+ for message in raw_messages:
398
+ if not isinstance(message, dict):
399
+ continue
400
+ role = _message_role(message)
401
+ content = _message_content(message)
402
+ if role not in {"system", "user", "assistant"} or not content:
403
+ continue
404
+ parsed.append({"role": role, "content": content})
405
+ return parsed
406
+
407
+
408
+ def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
409
+ if not isinstance(raw_text, str):
410
+ return []
411
+
412
+ text = raw_text.strip()
413
+ if not text:
414
+ return []
415
+
416
+ matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
417
+ if not matches:
418
+ return []
419
+
420
+ parsed: list[dict[str, str]] = []
421
+ for index, match in enumerate(matches):
422
+ role = _normalize_role(match.group(1))
423
+ start = match.end()
424
+ end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
425
+ content = clean_training_text(text[start:end].strip())
426
+ if role in {"system", "user", "assistant"} and content:
427
+ parsed.append({"role": role, "content": content})
428
+ return parsed
429
+
430
+
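
Role markers are only recognized at the start of the text or after a blank line, and `ROLE_ALIASES` collapses the naming variants. A sketch; the content strings also pass through `clean_training_text`, so only the roles are asserted here:

```python
transcript = "Human: Ping?\n\nAssistant: Pong."
messages = _parse_transcript_messages(transcript)
print([message["role"] for message in messages])   # ['user', 'assistant']
```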
431
+ def _render_prompt(messages: list[dict[str, str]]) -> str:
432
+ parts = []
433
+ for message in messages:
434
+ content = clean_context_text(message["content"])
435
+ if content:
436
+ parts.append(content)
437
+ return "\n".join(parts).strip()
438
+
439
+
440
+ def _last_user_prompt_before(messages: list[dict[str, str]], end_index: int) -> str:
441
+ for message in reversed(messages[:end_index]):
442
+ if message["role"] == "user":
443
+ return clean_context_text(message["content"])
444
+ return _render_prompt(messages[:end_index])
445
+
446
+
447
+ def _compose_training_text(context: object, answer: object) -> str:
448
+ prompt_text = clean_context_text(_flatten_value(context))
449
+ answer_text = clean_answer_text(_flatten_value(answer))
450
+ if prompt_text and answer_text:
451
+ return f"<reason> {prompt_text} <answer> {answer_text}".strip()
452
+ return clean_training_text(answer_text or prompt_text)
453
+
454
+
455
+ def _compose_from_messages(messages: list[dict[str, str]]) -> str:
456
+ assistant_index = None
457
+ for index in range(len(messages) - 1, -1, -1):
458
+ if messages[index]["role"] == "assistant":
459
+ assistant_index = index
460
+ break
461
+ if assistant_index is not None:
462
+ prompt = _last_user_prompt_before(messages, assistant_index)
463
+ answer = clean_answer_text(messages[assistant_index]["content"])
464
+ if prompt and answer:
465
+ return f"<reason> {prompt} <answer> {answer}".strip()
466
+ return "\n".join(
467
+ message["content"]
468
+ for message in messages
469
+ if message.get("content")
470
+ ).strip()
471
+
472
+
473
+ def _flatten_message_list(messages: object) -> str:
474
+ parsed = _parse_dialogue_messages(messages)
475
+ if parsed:
476
+ return _compose_from_messages(parsed)
477
+ if not isinstance(messages, list):
478
+ return ""
479
+ parts: list[str] = []
480
+ for message in messages:
481
+ if not isinstance(message, dict):
482
+ continue
483
+ content = str(
484
+ message.get("content", message.get("value", message.get("text", "")))
485
+ ).strip()
486
+ if not content:
487
+ continue
488
+ parts.append(clean_training_text(content))
489
+ return "\n".join(parts).strip()
490
+
491
+
492
+ def _flatten_value(value: object) -> str:
493
+ if isinstance(value, str):
494
+ parsed = _parse_transcript_messages(value)
495
+ if parsed:
496
+ return _compose_from_messages(parsed)
497
+ return clean_training_text(value.strip())
498
+ if isinstance(value, list):
499
+ return _flatten_message_list(value)
500
+ if isinstance(value, dict):
501
+ for field in ("messages", "conversation", "conversations", "dialogue", "turns"):
502
+ nested_messages = value.get(field)
503
+ text = _flatten_message_list(nested_messages)
504
+ if text:
505
+ return text
506
+ for field in ("text", "content", "value", "message"):
507
+ nested = value.get(field)
508
+ if isinstance(nested, str) and nested.strip():
509
+ return _flatten_value(nested)
510
+ return ""
511
+
512
+
513
+ def _safe_flag(value: object) -> bool | None:
514
+ if isinstance(value, bool):
515
+ return value
516
+ if isinstance(value, str):
517
+ normalized = value.strip().casefold()
518
+ if normalized in {"true", "1", "yes", "safe"}:
519
+ return True
520
+ if normalized in {"false", "0", "no", "unsafe"}:
521
+ return False
522
+ return None
523
+
524
+
525
+ def _selected_response_fields(row: dict[str, object]) -> tuple[str, str]:
526
+ if "response_0" not in row or "response_1" not in row:
527
+ return "", ""
528
+ safe_0 = _safe_flag(row.get("is_response_0_safe"))
529
+ safe_1 = _safe_flag(row.get("is_response_1_safe"))
530
+ if safe_0 is not None and safe_1 is not None:
531
+ if safe_0 and not safe_1:
532
+ return "response_0", "response_1"
533
+ if safe_1 and not safe_0:
534
+ return "response_1", "response_0"
535
+ if safe_0 and safe_1:
536
+ return "response_0", ""
537
+ return "", ""
538
+ for selector in ("safer_response_id", "better_response_id"):
539
+ raw_value = row.get(selector)
540
+ try:
541
+ preferred = int(raw_value)
542
+ except (TypeError, ValueError):
543
+ continue
544
+ chosen = "response_1" if preferred == 1 else "response_0"
545
+ rejected = "response_0" if chosen == "response_1" else "response_1"
546
+ return chosen, rejected
547
+ return "response_0", "response_1"
548
+
549
+
550
+ def _extract_preference_pair(row: dict[str, object]) -> tuple[str, str]:
551
+ if "chosen" in row and "rejected" in row:
552
+ chosen_text = clean_training_text(_flatten_value(row.get("chosen")))
553
+ rejected_text = clean_training_text(_flatten_value(row.get("rejected")))
554
+ if chosen_text and rejected_text:
555
+ return chosen_text, rejected_text
556
+ if "response_0" in row and "response_1" in row:
557
+ preferred_field, rejected_field = _selected_response_fields(row)
558
+ if not preferred_field or not rejected_field:
559
+ return "", ""
560
+ prompt = row.get("prompt", row.get("question", row.get("query", "")))
561
+ if prompt:
562
+ chosen_text = _compose_training_text(prompt, row.get(preferred_field))
563
+ rejected_text = _compose_training_text(prompt, row.get(rejected_field))
564
+ if chosen_text and rejected_text:
565
+ return clean_training_text(chosen_text), clean_training_text(rejected_text)
566
+ chosen_text = clean_training_text(_flatten_value(row.get(preferred_field)))
567
+ rejected_text = clean_training_text(_flatten_value(row.get(rejected_field)))
568
+ if chosen_text and rejected_text:
569
+ return chosen_text, rejected_text
570
+ return "", ""
571
+
572
+
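
The safety flags take precedence over the numeric selectors: if exactly one response is flagged safe, it is chosen outright. A sketch with hypothetical row values:

```python
row = {
    "response_0": "Use strong, unique passwords.",
    "response_1": "Reuse one password everywhere.",
    "is_response_0_safe": "yes",
    "is_response_1_safe": "no",
}
print(_selected_response_fields(row))   # ('response_0', 'response_1')
```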
573
+ def _extract_preference_value(row: dict[str, object]) -> str:
574
+ chosen_text, _ = _extract_preference_pair(row)
575
+ return chosen_text
576
+
577
+
578
+ def _extract_row_text(row: dict[str, object], text_field: str | None) -> str:
579
+ if "context" in row and "answer" in row:
580
+ context = clean_context_text(_flatten_value(row.get("context")))
581
+ answer = clean_answer_text(_flatten_value(row.get("answer")))
582
+ if context and answer:
583
+ return f"<reason> {context} <answer> {answer}".strip()
584
+
585
+ if "response_0" in row and "response_1" in row:
586
+ preferred_field, _ = _selected_response_fields(row)
587
+ prompt = row.get("prompt", row.get("question", row.get("query", "")))
588
+ if preferred_field and prompt:
589
+ text = _compose_training_text(prompt, row.get(preferred_field))
590
+ if text:
591
+ return text
592
+
593
+ for prompt_field, answer_field in INSTRUCTION_FIELD_PAIRS:
594
+ if prompt_field in row and answer_field in row:
595
+ text = _compose_training_text(row.get(prompt_field), row.get(answer_field))
596
+ if text:
597
+ return text
598
+
599
+ if text_field is not None:
600
+ return clean_training_text(_flatten_value(row.get(text_field)))
601
+
602
+ preferred = _extract_preference_value(row)
603
+ if preferred:
604
+ return clean_training_text(preferred)
605
+
606
+ for field in TEXT_FIELD_PREFERENCES:
607
+ text = _flatten_value(row.get(field))
608
+ if text:
609
+ return clean_training_text(text)
610
+ for field in DIALOGUE_FIELD_PREFERENCES:
611
+ text = _flatten_value(row.get(field))
612
+ if text:
613
+ return clean_training_text(text)
614
+ return ""
615
+
616
+
617
+ def _passes_text_quality(text: str, language: str, entry: CorpusPlanEntry) -> bool:
618
+ if not text:
619
+ return False
620
+ word_count = _word_count(text)
621
+ if entry.min_words > 0 and word_count < entry.min_words:
622
+ return False
623
+ if entry.max_words > 0 and word_count > entry.max_words:
624
+ return False
625
+ if entry.min_alpha_ratio > 0.0 and _alpha_ratio(text) < entry.min_alpha_ratio:
626
+ return False
627
+ if entry.allowed_languages:
628
+ if not language or language.casefold() not in entry.allowed_languages:
629
+ return False
630
+ return True
631
+
632
+
633
+ def load_corpus_plan(source: str | Path) -> list[CorpusPlanEntry]:
634
+ payload = json.loads(Path(source).read_text(encoding="utf-8-sig"))
635
+ raw_entries = payload.get("sources", payload.get("datasets", []))
636
+ if not isinstance(raw_entries, list) or not raw_entries:
637
+ raise ValueError("Corpus plan must define a non-empty 'sources' list.")
638
+
639
+ entries: list[CorpusPlanEntry] = []
640
+ for index, raw_entry in enumerate(raw_entries, start=1):
641
+ if not isinstance(raw_entry, dict):
642
+ raise ValueError("Each corpus plan entry must be an object.")
643
+ source = str(raw_entry.get("source", "hf")).strip() or "hf"
644
+ name = str(
645
+ raw_entry.get("name", raw_entry.get("dataset", f"source-{index}"))
646
+ ).strip() or f"source-{index}"
647
+ raw_languages = raw_entry.get("allowed_languages", [])
648
+ allowed_languages = tuple(
649
+ str(value).strip().casefold()
650
+ for value in raw_languages
651
+ if str(value).strip()
652
+ ) if isinstance(raw_languages, list) else ()
653
+ raw_records = raw_entry.get("records", raw_entry.get("texts", []))
654
+ if source == "inline" and not isinstance(raw_records, list):
655
+ raise ValueError("Inline corpus plan entries must provide a records/texts list.")
656
+ entries.append(
657
+ CorpusPlanEntry(
658
+ source=source,
659
+ name=name,
660
+ dataset=str(raw_entry.get("dataset", "")),
661
+ path=str(raw_entry.get("path", raw_entry.get("file", ""))),
662
+ config=(
663
+ str(raw_entry["config"])
664
+ if raw_entry.get("config") is not None
665
+ else None
666
+ ),
667
+ split=str(raw_entry.get("split", "train")),
668
+ limit=int(raw_entry.get("limit", 0)),
669
+ weight=float(raw_entry.get("weight", 1.0)),
670
+ text_field=(
671
+ str(raw_entry["text_field"])
672
+ if raw_entry.get("text_field") is not None
673
+ else None
674
+ ),
675
+ min_words=int(raw_entry.get("min_words", 0)),
676
+ max_words=int(raw_entry.get("max_words", 0)),
677
+ min_alpha_ratio=float(raw_entry.get("min_alpha_ratio", 0.0)),
678
+ allowed_languages=allowed_languages,
679
+ records=tuple(raw_records) if isinstance(raw_records, list) else (),
680
+ streaming=bool(raw_entry.get("streaming", True)),
681
+ trust_remote_code=bool(raw_entry.get("trust_remote_code", False)),
682
+ )
683
+ )
684
+ return entries
685
+
686
+
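
A minimal corpus plan matching the fields parsed above ("source" defaults to "hf", and inline entries must carry a records/texts list); the dataset id here is a placeholder:

```python
import json
from pathlib import Path

plan = {
    "sources": [
        {"source": "inline", "name": "smoke-test",
         "records": ["Reframr streams documents into computed weights."]},
        {"source": "hf", "dataset": "user/dataset", "split": "train",
         "limit": 100, "weight": 0.5, "min_words": 5},
    ]
}
Path("plan.json").write_text(json.dumps(plan), encoding="utf-8")
entries = load_corpus_plan("plan.json")
print([entry.name for entry in entries])   # ['smoke-test', 'user/dataset']
```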
687
+ def _iter_hf_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
688
+ try:
689
+ from datasets import load_dataset
690
+ except ModuleNotFoundError:
691
+ user_site = site.getusersitepackages()
692
+ if user_site and user_site not in sys.path:
693
+ sys.path.append(user_site)
694
+ from datasets import load_dataset
695
+
696
+ dataset_kwargs: dict[str, object] = {
697
+ "split": entry.split,
698
+ "streaming": entry.streaming,
699
+ }
700
+ if entry.config:
701
+ dataset_kwargs["name"] = entry.config
702
+ if entry.trust_remote_code:
703
+ dataset_kwargs["trust_remote_code"] = True
704
+
705
+ for row in load_dataset(entry.dataset, **dataset_kwargs):
706
+ yield dict(row)
707
+
708
+
709
+ def _iter_file_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
710
+ raw_path = entry.path or entry.dataset
711
+ if not raw_path:
712
+ raise ValueError("File corpus plan entries must provide a path.")
713
+ path = Path(raw_path)
714
+ suffix = path.suffix.lower()
715
+ if suffix == ".jsonl":
716
+ with path.open("r", encoding="utf-8") as handle:
717
+ for line in handle:
718
+ if line.strip():
719
+ row = json.loads(line)
720
+ yield row if isinstance(row, dict) else {"text": str(row)}
721
+ return
722
+ if suffix == ".json":
723
+ payload = json.loads(path.read_text(encoding="utf-8"))
724
+ if isinstance(payload, list):
725
+ for row in payload:
726
+ yield row if isinstance(row, dict) else {"text": str(row)}
727
+ return
728
+ if isinstance(payload, dict):
729
+ rows = payload.get("records", payload.get("texts"))
730
+ if isinstance(rows, list):
731
+ for row in rows:
732
+ yield row if isinstance(row, dict) else {"text": str(row)}
733
+ return
734
+ yield payload
735
+ return
736
+ if suffix in {".txt", ".md", ".text"}:
737
+ yield {"text": path.read_text(encoding="utf-8")}
738
+ return
739
+ raise ValueError(f"Unsupported file corpus source: {path}")
740
+
741
+
742
+ def iter_corpus_plan_documents(plan: Iterable[CorpusPlanEntry]) -> Iterator[StreamDocument]:
743
+ for entry in plan:
744
+ accepted = 0
745
+ attempts = 0
746
+ while True:
747
+ accepted_seen_this_attempt = 0
748
+ try:
749
+ if entry.source == "inline":
750
+ row_iterator = (
751
+ item if isinstance(item, dict) else {"text": str(item)}
752
+ for item in entry.records
753
+ )
754
+ elif entry.source == "hf":
755
+ row_iterator = _iter_hf_rows(entry)
756
+ elif entry.source == "file":
757
+ row_iterator = _iter_file_rows(entry)
758
+ else:
759
+ raise ValueError(f"Unsupported corpus plan source: {entry.source}")
760
+
761
+ for row in row_iterator:
762
+ language = _row_language(row)
763
+ _, rejected_text = _extract_preference_pair(row)
764
+ text = clean_training_text(_extract_row_text(row, entry.text_field))
765
+ if not _passes_text_quality(text, language, entry):
766
+ continue
767
+ accepted_seen_this_attempt += 1
768
+ if accepted_seen_this_attempt <= accepted:
769
+ continue
770
+ yield StreamDocument(
771
+ text=text,
772
+ weight=entry.weight,
773
+ source=entry.name,
774
+ language=language,
775
+ preference_rejected_text=rejected_text,
776
+ )
777
+ accepted += 1
778
+ if entry.limit > 0 and accepted >= entry.limit:
779
+ break
780
+ break
781
+ except Exception as exc:
782
+ if entry.source != "hf":
783
+ raise
784
+ if attempts >= HF_STREAM_MAX_RETRIES:
785
+ print(
786
+ f"[source] {entry.name} skipped after {attempts} retries; "
787
+ f"accepted {accepted} documents before final error: {exc}"
788
+ )
789
+ break
790
+ attempts += 1
791
+ delay = min(
792
+ 15.0,
793
+ HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempts - 1)),
794
+ )
795
+ print(
796
+ f"[source] {entry.name} stream interrupted after {accepted} accepted "
797
+ f"documents; retry {attempts}/{HF_STREAM_MAX_RETRIES} in {delay:.2f}s: {exc}"
798
+ )
799
+ time.sleep(delay)
800
+
801
+
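
The retry delays follow exponential backoff from the constants above, capped at 15 seconds:

```python
delays = [
    min(15.0, HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempt - 1)))
    for attempt in range(1, HF_STREAM_MAX_RETRIES + 1)
]
print(delays)   # [0.25, 0.5, 1.0, 2.0, 4.0]
```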
802
+ def _log_progress(label: str, processed: int, log_every: int) -> None:
803
+ if log_every > 0 and processed % log_every == 0:
804
+ print(f"[{label}] processed {processed} documents")
805
+
806
+
807
+ def _answer_boundary(tokens: list[str]) -> int | None:
808
+ try:
809
+ return tokens.index("<answer>")
810
+ except ValueError:
811
+ return None
812
+
813
+
814
+ def _weighted_text_parts_for_statistics(text: str, document_weight: float) -> list[tuple[str, float]]:
815
+ if "<answer>" not in text:
816
+ return [(text, document_weight)]
817
+ context, answer = text.split("<answer>", 1)
818
+ context = clean_context_text(context.replace("<reason>", " "))
819
+ answer = clean_answer_text(answer)
820
+ parts: list[tuple[str, float]] = []
821
+ if context:
822
+ parts.append((context, document_weight * CONTEXT_STAT_WEIGHT))
823
+ if answer:
824
+ parts.append((answer, document_weight * ANSWER_READOUT_WEIGHT))
825
+ return parts or [(text, document_weight)]
826
+
827
+
828
+ def _weighted_token_sequences_for_statistics(
829
+ tokens: list[str],
830
+ tokenizer: NativeTokenizer,
831
+ document_weight: float,
832
+ ) -> list[tuple[list[str], float]]:
833
+ answer_index = _answer_boundary(tokens)
834
+ if answer_index is None:
835
+ sequence = [token for token in tokens if token not in tokenizer.special_tokens]
836
+ return [(sequence, document_weight)] if sequence else []
837
+ context_tokens = [
838
+ token for token in tokens[:answer_index] if token not in tokenizer.special_tokens
839
+ ]
840
+ answer_tokens = [
841
+ token for token in tokens[answer_index + 1 :] if token not in tokenizer.special_tokens
842
+ ]
843
+ sequences: list[tuple[list[str], float]] = []
844
+ if context_tokens:
845
+ sequences.append((context_tokens, document_weight * CONTEXT_STAT_WEIGHT))
846
+ if answer_tokens:
847
+ sequences.append((answer_tokens, document_weight * ANSWER_READOUT_WEIGHT))
848
+ return sequences
849
+
850
+
851
+ def _readout_weight_for_target(
852
+ answer_index: int | None,
853
+ target_index: int,
854
+ document_weight: float,
855
+ ) -> float:
856
+ if answer_index is None:
857
+ return document_weight * PLAIN_TEXT_READOUT_WEIGHT
858
+ if target_index <= answer_index:
859
+ return document_weight * CONTEXT_READOUT_WEIGHT
860
+ return document_weight * ANSWER_READOUT_WEIGHT
861
+
862
+
863
+ def _answer_payload_tokens(tokens: list[str], tokenizer: NativeTokenizer) -> list[str]:
864
+ answer_index = _answer_boundary(tokens)
865
+ payload = tokens[answer_index + 1 :] if answer_index is not None else tokens
866
+ return [token for token in payload if token not in tokenizer.special_tokens]
867
+
868
+
869
+ def _standardized_preference_bias(values: object, active_mask: object | None = None) -> list[float]:
870
+ if np is not None:
871
+ bias = np.asarray(values, dtype=np.float64)
872
+ if bias.size == 0:
873
+ return []
874
+ active = (
875
+ np.asarray(active_mask, dtype=bool)
876
+ if active_mask is not None
877
+ else np.ones(bias.shape, dtype=bool)
878
+ )
879
+ if not np.any(active):
880
+ return [0.0 for _ in range(int(bias.size))]
881
+ active_values = bias[active]
882
+ spread = float(active_values.std())
883
+ if spread <= 1e-12:
884
+ return [0.0 for _ in range(int(bias.size))]
885
+ standardized = np.zeros_like(bias, dtype=np.float64)
886
+ standardized[active] = (
887
+ (active_values - float(active_values.mean())) / spread
888
+ ) * PREFERENCE_BIAS_SCALE
889
+ return np.clip(standardized, -2.5, 2.5).astype(float).tolist()
890
+ raw_values = [float(value) for value in values]
891
+ if not raw_values:
892
+ return []
893
+ average = sum(raw_values) / len(raw_values)
894
+ variance = sum((value - average) * (value - average) for value in raw_values) / len(raw_values)
895
+ spread = variance**0.5
896
+ if spread <= 1e-12:
897
+ return [0.0 for _ in raw_values]
898
+ active_indices = (
899
+ [
900
+ index
901
+ for index, active in enumerate(active_mask)
902
+ if active
903
+ ]
904
+ if active_mask is not None
905
+ else list(range(len(raw_values)))
906
+ )
907
+ if not active_indices:
908
+ return [0.0 for _ in raw_values]
909
+ active_values = [raw_values[index] for index in active_indices]
910
+ average = mean(active_values)
911
+ spread = (mean([(value - average) * (value - average) for value in active_values])) ** 0.5
912
+ if spread <= 1e-12:
913
+ return [0.0 for _ in raw_values]
914
+ standardized = [0.0 for _ in raw_values]
915
+ for index in active_indices:
916
+ standardized[index] = max(
917
+ -2.5,
918
+ min(2.5, ((raw_values[index] - average) / spread) * PREFERENCE_BIAS_SCALE),
919
+ )
920
+ return standardized
921
+
922
+
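
Active entries are z-scored over the active subset, scaled by `PREFERENCE_BIAS_SCALE`, and clipped to ±2.5, while masked-out entries stay at zero. A check on a three-token vocabulary:

```python
bias = _standardized_preference_bias([1.0, 3.0, 0.0], active_mask=[True, True, False])
# The active pair has mean 2.0 and population std 1.0, so:
print(bias)   # ≈ [-0.95, 0.95, 0.0]
```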
923
+ def _candidate_preference_bias_from_state_vector(
924
+ model: ReframrModel,
925
+ preference_state: object,
926
+ ) -> object:
927
+ if np is None:
928
+ return None
929
+ assert model.embedding_model is not None
930
+ assert model.memory_units is not None
931
+ assert model.ternary_mask is not None
932
+
933
+ embeddings = np.asarray(model.embedding_model.embeddings, dtype=np.float64)
934
+ if embeddings.size == 0:
935
+ return np.zeros(0, dtype=np.float64)
936
+ state_vector = np.asarray(preference_state, dtype=np.float64)
937
+ mask = np.asarray(model.ternary_mask, dtype=np.float64) * float(model.ternary_scale)
938
+ if state_vector.shape[0] != mask.shape[0]:
939
+ return np.zeros(embeddings.shape[0], dtype=np.float64)
940
+
941
+ state_indices = np.arange(model.config.state_dim, dtype=np.int64)
942
+ drive = (
943
+ embeddings[:, state_indices % model.config.embedding_dim]
944
+ + (0.5 * embeddings[:, (3 * state_indices + 1) % model.config.embedding_dim])
945
+ - (0.25 * embeddings[:, (5 * state_indices + 2) % model.config.embedding_dim])
946
+ )
947
+ scores = np.zeros(embeddings.shape[0], dtype=np.float64)
948
+ offset = 0
949
+ for unit in model.memory_units:
950
+ hidden_end = offset + model.config.state_dim
951
+ trace_end = hidden_end + model.config.embedding_dim
952
+ hidden_pref = state_vector[offset:hidden_end] * mask[offset:hidden_end]
953
+ trace_pref = state_vector[hidden_end:trace_end] * mask[hidden_end:trace_end]
954
+ hidden_delta_axis = np.asarray(unit.input_projection, dtype=np.float64) * hidden_pref
955
+ trace_gain = 1.0 - (1.0 / (1.0 + unit.timescale))
956
+ scores += drive @ hidden_delta_axis
957
+ scores += embeddings @ (trace_gain * trace_pref)
958
+ offset = trace_end
959
+ return scores
960
+
961
+
962
+ def _derive_preference_bias_from_pairs(
963
+ model: ReframrModel,
964
+ preference_token_pairs: list[tuple[list[str], list[str], float]],
965
+ tokenizer: NativeTokenizer,
966
+ ) -> tuple[list[float], int]:
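+ # Build a per-token preference bias from (chosen, rejected) answer pairs:
+ # chosen-answer tokens gain weight, rejected-answer tokens lose weight,
+ # and a strided subsample of pairs (capped at MAX_PREFERENCE_STATE_PAIRS)
+ # contributes a decode-state delta that is projected back onto candidate
+ # tokens before standardization.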
967
+ assert model.embedding_model is not None
968
+ vocab_size = len(model.embedding_model.id_to_token)
969
+ if not preference_token_pairs:
970
+ return [0.0 for _ in range(vocab_size)], 0
971
+
972
+ if np is not None:
973
+ token_bias = np.zeros(vocab_size, dtype=np.float64)
974
+ active_token_mask = np.zeros(vocab_size, dtype=bool)
975
+ state_delta = np.zeros(model._combined_state_width(), dtype=np.float64)
976
+ else:
977
+ token_bias = [0.0 for _ in range(vocab_size)]
978
+ active_token_ids: set[int] = set()
979
+ state_delta = [0.0 for _ in range(model._combined_state_width())]
980
+ pair_weight_total = 0.0
981
+ state_pair_count = 0
982
+ state_stride = max(
983
+ 1,
984
+ (len(preference_token_pairs) + MAX_PREFERENCE_STATE_PAIRS - 1)
985
+ // MAX_PREFERENCE_STATE_PAIRS,
986
+ )
987
+
988
+ for pair_index, (chosen_tokens, rejected_tokens, pair_weight) in enumerate(preference_token_pairs):
989
+ chosen_answer = _answer_payload_tokens(chosen_tokens, tokenizer)
990
+ rejected_answer = _answer_payload_tokens(rejected_tokens, tokenizer)
991
+ if chosen_answer:
992
+ delta = pair_weight / max(1, len(chosen_answer))
993
+ for token in chosen_answer:
994
+ token_id = model.embedding_model.token_to_id.get(token)
995
+ if token_id is not None:
996
+ token_bias[token_id] += delta
997
+ if np is not None:
998
+ active_token_mask[token_id] = True
999
+ else:
1000
+ active_token_ids.add(token_id)
1001
+ if rejected_answer:
1002
+ delta = pair_weight / max(1, len(rejected_answer))
1003
+ for token in rejected_answer:
1004
+ token_id = model.embedding_model.token_to_id.get(token)
1005
+ if token_id is not None:
1006
+ token_bias[token_id] -= delta
1007
+ if np is not None:
1008
+ active_token_mask[token_id] = True
1009
+ else:
1010
+ active_token_ids.add(token_id)
1011
+
1012
+ if pair_index % state_stride != 0 or state_pair_count >= MAX_PREFERENCE_STATE_PAIRS:
1013
+ continue
1014
+ chosen_state = model._masked_decode_state(model._build_decode_state(chosen_tokens))
1015
+ rejected_state = model._masked_decode_state(model._build_decode_state(rejected_tokens))
1016
+ if len(chosen_state) != len(rejected_state):
1017
+ continue
1018
+ pair_weight_total += pair_weight
1019
+ state_pair_count += 1
1020
+ if np is not None:
1021
+ state_delta += pair_weight * (
1022
+ np.asarray(chosen_state, dtype=np.float64)
1023
+ - np.asarray(rejected_state, dtype=np.float64)
1024
+ )
1025
+ else:
1026
+ for index, (chosen_value, rejected_value) in enumerate(zip(chosen_state, rejected_state)):
1027
+ state_delta[index] += pair_weight * (chosen_value - rejected_value)
1028
+
1029
+ if pair_weight_total > 0.0:
1030
+ if np is not None:
1031
+ state_delta = state_delta / pair_weight_total
1032
+ candidate_bias = _candidate_preference_bias_from_state_vector(model, state_delta)
1033
+ if candidate_bias is not None:
1034
+ token_bias[active_token_mask] = (
1035
+ token_bias[active_token_mask] + candidate_bias[active_token_mask]
1036
+ )
1037
+ else:
1038
+ state_delta = [value / pair_weight_total for value in state_delta]
1039
+ if np is not None:
1040
+ return _standardized_preference_bias(token_bias, active_token_mask), state_pair_count
1041
+ active_mask = [index in active_token_ids for index in range(vocab_size)]
1042
+ return _standardized_preference_bias(token_bias, active_mask), state_pair_count
1043
+
1044
+
1045
+ def _solve_weighted_prompt_readout(
1046
+ states: list[Vector],
1047
+ labels: list[int],
1048
+ weights: list[float],
1049
+ *,
1050
+ vocab_size: int,
1051
+ diagonal: object,
1052
+ state_offset: object,
1053
+ regularization: float,
1054
+ ) -> tuple[object, object, int]:
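+ # Solve a weighted ridge-regression readout from masked, centered states
+ # to next-token labels; returns (readout weights, label-frequency bias,
+ # example count). Falls back to empty outputs when NumPy is unavailable
+ # or no valid examples remain after filtering.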
1055
+ if np is None or not states or not labels or not weights:
1056
+ return [], [0.0 for _ in range(vocab_size)], 0
1057
+ state_matrix = np.asarray(states, dtype=np.float64)
1058
+ label_array = np.asarray(labels, dtype=np.int64)
1059
+ weight_vector = np.asarray(weights, dtype=np.float64)
1060
+ valid_mask = (
1061
+ (label_array >= 0)
1062
+ & (label_array < vocab_size)
1063
+ & (weight_vector > 0.0)
1064
+ )
1065
+ if not np.any(valid_mask):
1066
+ return [], [0.0 for _ in range(vocab_size)], 0
1067
+ state_matrix = state_matrix[valid_mask]
1068
+ label_array = label_array[valid_mask]
1069
+ weight_vector = weight_vector[valid_mask]
1070
+ diagonal_array = np.asarray(diagonal, dtype=np.float64)
1071
+ offset_array = np.asarray(state_offset, dtype=np.float64)
1072
+ if (
1073
+ len(state_matrix.shape) != 2
1074
+ or diagonal_array.shape[0] != state_matrix.shape[1]
1075
+ or offset_array.shape[0] != state_matrix.shape[1]
1076
+ ):
1077
+ return [], [0.0 for _ in range(vocab_size)], 0
1078
+ masked_states = state_matrix * diagonal_array[None, :]
1079
+ centered_states = masked_states - offset_array[None, :]
1080
+ weighted_centered_states = weight_vector[:, None] * centered_states
1081
+ gram = centered_states.T @ weighted_centered_states
1082
+ cross = np.zeros((vocab_size, centered_states.shape[1]), dtype=np.float64)
1083
+ np.add.at(cross, label_array, weighted_centered_states)
1084
+ total_weight = float(weight_vector.sum())
1085
+ if total_weight <= 0.0:
1086
+ return [], [0.0 for _ in range(vocab_size)], 0
1087
+ bias = np.zeros(vocab_size, dtype=np.float64)
1088
+ np.add.at(bias, label_array, weight_vector)
1089
+ bias /= total_weight
1090
+ readout = ridge_regression_readout_from_moments(
1091
+ gram,
1092
+ cross,
1093
+ regularization=regularization,
1094
+ )
1095
+ return readout, bias, int(label_array.shape[0])
1096
+
1097
+
1098
+ def fit_model_from_corpus_plan(
1099
+ plan: Iterable[CorpusPlanEntry],
1100
+ config: ReframrConfig,
1101
+ *,
1102
+ log_every: int = 0,
1103
+ ) -> tuple[ReframrModel, dict[str, object]]:
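+ # End-to-end streaming fit: tokenizer training, vocabulary and
+ # co-occurrence statistics, PPMI embeddings, recurrent state collection
+ # via reservoir sampling, ridge readout solving, preference bias, and
+ # associative-memory key tables. Returns the computed model plus a
+ # payload of fit statistics keyed by stage.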
1104
+ entries = list(plan)
1105
+ if not entries:
1106
+ raise ValueError("Cannot fit REFRAMR without any corpus plan entries.")
1107
+ stage_seconds: dict[str, float] = {}
1108
+ stage_started = time.perf_counter()
1109
+
1110
+ def finish_stage(name: str) -> None:
1111
+ nonlocal stage_started
1112
+ now = time.perf_counter()
1113
+ elapsed = round(now - stage_started, 6)
1114
+ stage_seconds[name] = elapsed
1115
+ if log_every > 0:
1116
+ print(f"[stage] {name} finished in {elapsed:.3f}s")
1117
+ stage_started = now
1118
+
1119
+ seed_tokenizer = NativeTokenizer(
1120
+ merges=[],
1121
+ vocab=[],
1122
+ base_symbols=[],
1123
+ lowercase=config.lowercase,
1124
+ )
1125
+ segment_counts: Counter[str] = Counter()
1126
+ source_counts: dict[str, int] = {}
1127
+ documents: list[StreamDocument] = []
1128
+ processed = 0
1129
+ for entry in entries:
1130
+ if log_every > 0:
1131
+ print(f"[source] {entry.name} started")
1132
+ source_start = processed
1133
+ for document in iter_corpus_plan_documents([entry]):
1134
+ documents.append(document)
1135
+ processed += 1
1136
+ source_counts[document.source] = source_counts.get(document.source, 0) + 1
1137
+ for text_part, part_weight in _weighted_text_parts_for_statistics(
1138
+ document.text,
1139
+ document.weight,
1140
+ ):
1141
+ for segment in seed_tokenizer.pretokenize(text_part):
1142
+ segment_counts[segment] += part_weight
1143
+ if document.preference_rejected_text:
1144
+ rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
1145
+ for text_part, part_weight in _weighted_text_parts_for_statistics(
1146
+ document.preference_rejected_text,
1147
+ rejected_weight,
1148
+ ):
1149
+ for segment in seed_tokenizer.pretokenize(text_part):
1150
+ segment_counts[segment] += part_weight
1151
+ _log_progress("tokenizer", processed, log_every)
1152
+ if log_every > 0:
1153
+ print(f"[source] {entry.name} accepted {processed - source_start} documents")
1154
+ if processed == 0:
1155
+ raise ValueError("Corpus plan did not yield any usable documents after filtering.")
1156
+ finish_stage("stream_and_segment")
1157
+ tokenizer = NativeTokenizer.train_from_segment_counts(
1158
+ segment_counts,
1159
+ vocab_size=config.tokenizer_vocab_size,
1160
+ min_pair_frequency=config.tokenizer_min_pair_frequency,
1161
+ lowercase=config.lowercase,
1162
+ )
1163
+ finish_stage("tokenizer_fit")
1164
+
1165
+ token_counts: Counter[str] = Counter()
1166
+ raw_tokenized_documents: list[list[str]] = []
1167
+ raw_rejected_tokenized_documents: list[list[str]] = []
1168
+ processed = 0
1169
+ for document in documents:
1170
+ processed += 1
1171
+ tokens = tokenizer.encode(document.text)
1172
+ raw_tokenized_documents.append(tokens)
1173
+ for token in tokens:
1174
+ if token in tokenizer.special_tokens:
1175
+ token_counts[token] += document.weight
1176
+ for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
1177
+ tokens,
1178
+ tokenizer,
1179
+ document.weight,
1180
+ ):
1181
+ for token in token_sequence:
1182
+ token_counts[token] += sequence_weight
1183
+ rejected_tokens = (
1184
+ tokenizer.encode(document.preference_rejected_text)
1185
+ if document.preference_rejected_text
1186
+ else []
1187
+ )
1188
+ raw_rejected_tokenized_documents.append(rejected_tokens)
1189
+ rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
1190
+ for token in rejected_tokens:
1191
+ if token in tokenizer.special_tokens:
1192
+ token_counts[token] += rejected_weight
1193
+ for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
1194
+ rejected_tokens,
1195
+ tokenizer,
1196
+ rejected_weight,
1197
+ ):
1198
+ for token in token_sequence:
1199
+ token_counts[token] += sequence_weight
1200
+ _log_progress("vocab", processed, log_every)
1201
+ token_to_id, id_to_token = build_vocabulary_from_counts(
1202
+ token_counts,
1203
+ min_frequency=config.min_frequency,
1204
+ max_vocab=config.max_vocab,
1205
+ )
1206
+ if not id_to_token:
1207
+ raise ValueError("Streaming recompute could not derive an embedding vocabulary.")
1208
+ finish_stage("vocabulary")
1209
+
1210
+ cooccurrence = StreamingCooccurrenceAccumulator(token_to_id, config.window_size)
1211
+ tokenized_documents: list[list[str]] = []
1212
+ preference_token_pairs: list[tuple[list[str], list[str], float]] = []
1213
+ processed = 0
1214
+ for document, raw_tokens, raw_rejected_tokens in zip(
1215
+ documents,
1216
+ raw_tokenized_documents,
1217
+ raw_rejected_tokenized_documents,
1218
+ ):
1219
+ processed += 1
1220
+ tokens = [token for token in raw_tokens if token in token_to_id]
1221
+ tokenized_documents.append(tokens)
1222
+ rejected_tokens = [token for token in raw_rejected_tokens if token in token_to_id]
1223
+ if len(tokens) > 1 and len(rejected_tokens) > 1:
1224
+ preference_token_pairs.append((tokens, rejected_tokens, document.weight))
1225
+ for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
1226
+ tokens,
1227
+ tokenizer,
1228
+ document.weight,
1229
+ ):
1230
+ if len(token_sequence) > 1:
1231
+ cooccurrence.update_tokens(token_sequence, weight=sequence_weight)
1232
+ _log_progress("cooccurrence", processed, log_every)
1233
+ finish_stage("cooccurrence")
1234
+ if np is not None:
1235
+ embedding_model = fit_randomized_ppmi_embedding_from_counts(
1236
+ id_to_token,
1237
+ cooccurrence.rows,
1238
+ embedding_dim=config.embedding_dim,
1239
+ )
1240
+ else:
1241
+ embedding_model = fit_ppmi_embedding_from_cooccurrence(
1242
+ id_to_token,
1243
+ cooccurrence.to_sparse(),
1244
+ embedding_dim=config.embedding_dim,
1245
+ )
1246
+ finish_stage("embedding")
1247
+
1248
+ model = ReframrModel(config)
1249
+ model.tokenizer = tokenizer
1250
+ model.embedding_model = embedding_model
1251
+ model.memory_units = [
1252
+ AnalyticalMemoryUnit(config.state_dim, timescale)
1253
+ for timescale in config.timescales
1254
+ ]
1255
+ model.trace_token_weights = model._derive_trace_token_weights_from_counts(token_counts)
1256
+
1257
+ feature_count = len(model._zero_combined_state())
1258
+ if np is not None:
1259
+ feature_second_moment = np.zeros(feature_count, dtype=np.float64)
1260
+ raw_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
1261
+ else:
1262
+ feature_second_moment = zeros_vector(feature_count)
1263
+ raw_cross = zeros(len(embedding_model.id_to_token), feature_count)
1264
+ example_weight_total = 0.0
1265
+ has_answer_targets = any(_answer_boundary(tokens) is not None for tokens in tokenized_documents)
1266
+ if config.max_training_examples is None:
1267
+ answer_reservoir_capacity = None
1268
+ general_reservoir_capacity = None
1269
+ elif config.max_training_examples <= 0:
1270
+ answer_reservoir_capacity = 0
1271
+ general_reservoir_capacity = 0
1272
+ elif has_answer_targets:
1273
+ answer_reservoir_capacity = max(1, int(config.max_training_examples * 0.75))
1274
+ general_reservoir_capacity = max(0, config.max_training_examples - answer_reservoir_capacity)
1275
+ else:
1276
+ answer_reservoir_capacity = 0
1277
+ general_reservoir_capacity = config.max_training_examples
1278
+ answer_sequence_capacity = MAX_ANSWER_SEQUENCE_EXAMPLES if has_answer_targets else 0
1279
+ answer_reservoir = StateReservoir(answer_reservoir_capacity, seed=17)
1280
+ general_reservoir = StateReservoir(general_reservoir_capacity, seed=13)
1281
+ answer_intent_reservoir = StateReservoir(answer_reservoir_capacity, seed=29)
1282
+ answer_start_reservoir = StateReservoir(answer_reservoir_capacity, seed=37)
1283
+ answer_sequence_reservoir = SequenceReservoir(answer_sequence_capacity, seed=41)
1284
+ moment_reservoir = StateReservoir(
1285
+ config.max_training_examples,
1286
+ seed=31,
1287
+ )
1288
+ transitions = TransitionAccumulator(
1289
+ max_contexts_per_order=config.max_transition_contexts_per_order,
1290
+ max_next_tokens=config.max_transition_next_tokens,
1291
+ )
1292
+ if np is not None:
1293
+ target_label_mass = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
1294
+ else:
1295
+ target_label_mass = zeros_vector(len(embedding_model.id_to_token))
1296
+ for document, tokens in zip(documents, tokenized_documents):
1297
+ answer_index = _answer_boundary(tokens)
1298
+ for index in range(len(tokens) - 1):
1299
+ next_token = tokens[index + 1]
1300
+ if tokenizer is not None and next_token in tokenizer.special_tokens:
1301
+ continue
1302
+ next_token_id = embedding_model.token_to_id.get(next_token, -1)
1303
+ if next_token_id < 0:
1304
+ continue
1305
+ label_weight = _readout_weight_for_target(answer_index, index + 1, document.weight)
1306
+ if label_weight > 0.0:
1307
+ target_label_mass[next_token_id] += label_weight
1308
+ if np is not None:
1309
+ positive_label_mass = target_label_mass[target_label_mass > 0.0]
1310
+ reference_label_mass = (
1311
+ float(np.median(positive_label_mass))
1312
+ if positive_label_mass.size
1313
+ else 1.0
1314
+ )
1315
+ target_balance = np.ones(len(embedding_model.id_to_token), dtype=np.float64)
1316
+ np.divide(
1317
+ reference_label_mass,
1318
+ np.maximum(target_label_mass, 1e-12),
1319
+ out=target_balance,
1320
+ where=target_label_mass > 0.0,
1321
+ )
1322
+ target_balance = np.clip(np.sqrt(target_balance), 0.25, 4.0)
1323
+ else:
1324
+ positive_label_mass = [value for value in target_label_mass if value > 0.0]
1325
+ if positive_label_mass:
1326
+ sorted_mass = sorted(positive_label_mass)
1327
+ reference_label_mass = sorted_mass[len(sorted_mass) // 2]
1328
+ else:
1329
+ reference_label_mass = 1.0
1330
+ target_balance = [
1331
+ max(0.25, min(4.0, (reference_label_mass / max(value, 1e-12)) ** 0.5))
1332
+ if value > 0.0
1333
+ else 1.0
1334
+ for value in target_label_mass
1335
+ ]
1336
+ processed = 0
1337
+ embedding_array = (
1338
+ np.asarray(embedding_model.embeddings, dtype=RUNTIME_ARRAY_DTYPE)
1339
+ if np is not None
1340
+ else None
1341
+ )
1342
+ trace_embedding_array = (
1343
+ model._build_trace_embedding_table_array(embedding_array)
1344
+ if np is not None and embedding_array is not None
1345
+ else None
1346
+ )
1347
+ if np is not None:
1348
+ trace_decay = np.asarray(
1349
+ [1.0 / (1.0 + unit.timescale) for unit in model.memory_units],
1350
+ dtype=RUNTIME_ARRAY_DTYPE,
1351
+ )
1352
+ trace_gain = 1.0 - trace_decay
1353
+ transition_stack = np.asarray(
1354
+ [unit.transition for unit in model.memory_units],
1355
+ dtype=RUNTIME_ARRAY_DTYPE,
1356
+ )
1357
+ input_projection_stack = np.asarray(
1358
+ [unit.input_projection for unit in model.memory_units],
1359
+ dtype=RUNTIME_ARRAY_DTYPE,
1360
+ )
1361
+ drive_indices = np.arange(config.state_dim, dtype=np.int64)
1362
+ drive_primary = drive_indices % config.embedding_dim
1363
+ drive_secondary = (3 * drive_indices + 1) % config.embedding_dim
1364
+ drive_tertiary = (5 * drive_indices + 2) % config.embedding_dim
1365
+ else:
1366
+ trace_decay = None
1367
+ trace_gain = None
1368
+ transition_stack = None
1369
+ input_projection_stack = None
1370
+ drive_primary = None
1371
+ drive_secondary = None
1372
+ drive_tertiary = None
1373
+ for document, tokens in zip(documents, tokenized_documents):
1374
+ processed += 1
1375
+ if len(tokens) < 2:
1376
+ _log_progress("state", processed, log_every)
1377
+ continue
1378
+
1379
+ answer_index = _answer_boundary(tokens)
1380
+ for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
1381
+ tokens,
1382
+ tokenizer,
1383
+ document.weight,
1384
+ ):
1385
+ if len(token_sequence) > 1:
1386
+ transitions.update_tokens(token_sequence, weight=sequence_weight)
1387
+ if np is not None:
1388
+ hidden_state_matrix = np.zeros((len(config.timescales), config.state_dim), dtype=RUNTIME_ARRAY_DTYPE)
1389
+ context_trace_matrix = np.zeros((len(config.timescales), config.embedding_dim), dtype=RUNTIME_ARRAY_DTYPE)
1390
+ else:
1391
+ hidden_states = [zeros_vector(config.state_dim) for _ in config.timescales]
1392
+ context_traces = [zeros_vector(config.embedding_dim) for _ in config.timescales]
1393
+ answer_anchor_state = None
1394
+ for index in range(len(tokens) - 1):
1395
+ token = tokens[index]
1396
+ token_id = embedding_model.token_to_id.get(token, -1)
1397
+ if (
1398
+ np is not None
1399
+ and embedding_array is not None
1400
+ and trace_decay is not None
1401
+ and trace_gain is not None
1402
+ and transition_stack is not None
1403
+ and input_projection_stack is not None
1404
+ and drive_primary is not None
1405
+ and drive_secondary is not None
1406
+ and drive_tertiary is not None
1407
+ and trace_embedding_array is not None
1408
+ and token_id >= 0
1409
+ ):
1410
+ embedding = embedding_array[token_id]
1411
+ trace_embedding = trace_embedding_array[token_id]
1412
+ drive = (
1413
+ embedding[drive_primary]
1414
+ + (0.5 * embedding[drive_secondary])
1415
+ - (0.25 * embedding[drive_tertiary])
1416
+ )
1417
+ hidden_state_matrix = (
1418
+ (transition_stack @ hidden_state_matrix[:, :, None])[:, :, 0]
1419
+ + (input_projection_stack * drive[None, :])
1420
+ )
1421
+ context_trace_matrix = (
1422
+ context_trace_matrix + (trace_gain[:, None] * trace_embedding[None, :])
1423
+ )
1424
+ else:
1425
+ hidden_states, context_traces, combined_state = model._step_hidden_states(
1426
+ hidden_states,
1427
+ context_traces,
1428
+ token,
1429
+ )
1430
+ if token == "<answer>":
1431
+ if np is not None:
1432
+ answer_anchor_state = np.concatenate(
1433
+ (hidden_state_matrix, context_trace_matrix),
1434
+ axis=1,
1435
+ ).reshape(-1).copy()
1436
+ else:
1437
+ answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:]
1438
+ next_token = tokens[index + 1]
1439
+ if next_token in tokenizer.special_tokens:
1440
+ continue
1441
+ next_token_id = embedding_model.token_to_id.get(next_token, -1)
1442
+ if next_token_id < 0:
1443
+ continue
1444
+ raw_readout_weight = _readout_weight_for_target(answer_index, index + 1, document.weight)
1445
+ readout_weight = raw_readout_weight * float(target_balance[next_token_id])
1446
+ if readout_weight <= 0.0:
1447
+ continue
1448
+ moment_slot = moment_reservoir.reserve_slot(weight=readout_weight)
1449
+ is_answer_target = answer_index is not None and index + 1 > answer_index
1450
+ target_reservoir = answer_reservoir if is_answer_target else general_reservoir
1451
+ memory_weight = readout_weight * float(target_balance[next_token_id])
1452
+ answer_token_offset = (
1453
+ index - answer_index
1454
+ if is_answer_target and answer_index is not None
1455
+ else None
1456
+ )
1457
+ intent_slot = (
1458
+ answer_intent_reservoir.reserve_slot(weight=memory_weight)
1459
+ if is_answer_target and answer_anchor_state is not None
1460
+ else None
1461
+ )
1462
+ answer_start_weight = (
1463
+ raw_readout_weight * (ANSWER_START_DECAY ** answer_token_offset)
1464
+ if (
1465
+ answer_token_offset is not None
1466
+ and answer_token_offset < ANSWER_START_TOKEN_WINDOW
1467
+ )
1468
+ else 0.0
1469
+ )
1470
+ answer_start_slot = (
1471
+ answer_start_reservoir.reserve_slot(weight=answer_start_weight)
1472
+ if answer_start_weight > 0.0 and answer_anchor_state is not None
1473
+ else None
1474
+ )
1475
+ if np is not None:
1476
+ reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight)
1477
+ if moment_slot is not None or reservoir_slot is not None:
1478
+ combined_state = np.concatenate(
1479
+ (hidden_state_matrix, context_trace_matrix),
1480
+ axis=1,
1481
+ ).reshape(-1).copy()
1482
+ if moment_slot is not None:
1483
+ moment_reservoir.store_reserved(
1484
+ moment_slot,
1485
+ combined_state,
1486
+ next_token_id,
1487
+ example_weight=readout_weight,
1488
+ )
1489
+ if reservoir_slot is not None:
1490
+ target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id)
1491
+ if intent_slot is not None:
1492
+ answer_intent_reservoir.store_reserved(
1493
+ intent_slot,
1494
+ answer_anchor_state,
1495
+ next_token_id,
1496
+ example_weight=memory_weight,
1497
+ )
1498
+ if answer_start_slot is not None:
1499
+ answer_start_reservoir.store_reserved(
1500
+ answer_start_slot,
1501
+ answer_anchor_state,
1502
+ next_token_id,
1503
+ example_weight=answer_start_weight * float(target_balance[next_token_id]),
1504
+ )
1505
+ else:
1506
+ reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight)
+ if moment_slot is not None:
1510
+ moment_reservoir.store_reserved(
1511
+ moment_slot,
1512
+ combined_state,
1513
+ next_token_id,
1514
+ example_weight=readout_weight,
1515
+ )
1516
+ if reservoir_slot is not None:
1517
+ target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id)
1518
+ if intent_slot is not None:
1519
+ answer_intent_reservoir.store_reserved(
1520
+ intent_slot,
1521
+ answer_anchor_state,
1522
+ next_token_id,
1523
+ example_weight=memory_weight,
1524
+ )
1525
+ if answer_start_slot is not None:
1526
+ answer_start_reservoir.store_reserved(
1527
+ answer_start_slot,
1528
+ answer_anchor_state,
1529
+ next_token_id,
1530
+ example_weight=answer_start_weight * target_balance[next_token_id],
1531
+ )
1532
+ if answer_anchor_state is not None and answer_index is not None:
1533
+ prompt_token_ids = [
1534
+ embedding_model.token_to_id[token]
1535
+ for token in tokens[:answer_index]
1536
+ if token not in tokenizer.special_tokens
1537
+ and token in embedding_model.token_to_id
1538
+ ]
1539
+ answer_token_ids = [
1540
+ embedding_model.token_to_id[token]
1541
+ for token in tokens[answer_index + 1 :]
1542
+ if token not in tokenizer.special_tokens
1543
+ and token in embedding_model.token_to_id
1544
+ ]
1545
+ answer_sequence_reservoir.consider(
1546
+ answer_anchor_state,
1547
+ prompt_token_ids,
1548
+ answer_token_ids,
1549
+ weight=document.weight * ANSWER_READOUT_WEIGHT,
1550
+ )
1551
+ _log_progress("state", processed, log_every)
1552
+
1553
+ moment_states = moment_reservoir.states
1554
+ moment_labels = moment_reservoir.labels
1555
+ moment_weights = moment_reservoir.weights
1556
+ example_weight_total = sum(moment_weights)
1557
+ if np is not None and moment_states:
1558
+ state_matrix = np.asarray(moment_states, dtype=np.float64)
1559
+ weight_vector = np.asarray(moment_weights, dtype=np.float64)
1560
+ weighted_states = weight_vector[:, None] * state_matrix
1561
+ feature_second_moment += (weighted_states * state_matrix).sum(axis=0)
1562
+ np.add.at(raw_cross, moment_labels, weighted_states)
1563
+ elif moment_states:
1564
+ for state, label_id, readout_weight in zip(moment_states, moment_labels, moment_weights):
1565
+ for feature, value in enumerate(state):
1566
+ weighted_value = readout_weight * value
1567
+ feature_second_moment[feature] += weighted_value * value
1568
+ raw_cross[label_id][feature] += weighted_value
1569
+
1570
+ if example_weight_total <= 0.0:
1571
+ raise ValueError("Streaming recompute did not collect any next-token training examples.")
1572
+
1573
+ if np is not None:
1574
+ feature_energy = (feature_second_moment / example_weight_total).tolist()
1575
+ else:
1576
+ feature_energy = [
1577
+ feature_second_moment[index] / example_weight_total
1578
+ for index in range(feature_count)
1579
+ ]
1580
+ ternary_scale, ternary_mask = derive_ternary_mask_from_feature_energy(feature_energy)
1581
+ if np is not None:
1582
+ diagonal = np.asarray([ternary_scale * value for value in ternary_mask], dtype=np.float64)
1583
+ masked_feature_second_moment = feature_second_moment * diagonal * diagonal
1584
+ masked_cross = raw_cross * diagonal[None, :]
1585
+ else:
1586
+ diagonal = [ternary_scale * value for value in ternary_mask]
1587
+ masked_feature_second_moment = [
1588
+ feature_second_moment[index] * diagonal[index] * diagonal[index]
1589
+ for index in range(feature_count)
1590
+ ]
1591
+ masked_cross = [
1592
+ [
1593
+ raw_cross[row][col] * diagonal[col]
1594
+ for col in range(feature_count)
1595
+ ]
1596
+ for row in range(len(raw_cross))
1597
+ ]
1598
+ readout_solver = "diagonal"
1599
+ state_offset_values: object
1600
+ readout_bias_values: object
1601
+ if (
1602
+ np is not None
1603
+ and moment_states
1604
+ and feature_count <= FULL_READOUT_FEATURE_LIMIT
1605
+ and len(moment_states) <= FULL_READOUT_EXAMPLE_LIMIT
1606
+ ):
1607
+ state_matrix = np.asarray(moment_states, dtype=np.float64)
1608
+ weight_vector = np.asarray(moment_weights, dtype=np.float64)
1609
+ label_array = np.asarray(moment_labels, dtype=np.int64)
1610
+ masked_states = state_matrix * diagonal[None, :]
1611
+ total_weight = float(weight_vector.sum())
1612
+ if total_weight <= 0.0:
1613
+ total_weight = 1.0
1614
+ state_offset_values = (weight_vector[:, None] * masked_states).sum(axis=0) / total_weight
1615
+ centered_states = masked_states - state_offset_values[None, :]
1616
+ weighted_centered_states = weight_vector[:, None] * centered_states
1617
+ gram = centered_states.T @ weighted_centered_states
1618
+ full_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
1619
+ np.add.at(full_cross, label_array, weighted_centered_states)
1620
+ readout_bias_values = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
1621
+ np.add.at(readout_bias_values, label_array, weight_vector)
1622
+ readout_bias_values /= total_weight
1623
+ readout_weights = ridge_regression_readout_from_moments(
1624
+ gram,
1625
+ full_cross,
1626
+ regularization=config.regularization,
1627
+ )
1628
+ readout_solver = "full"
1629
+ else:
1630
+ state_offset_values = (
1631
+ np.zeros(feature_count, dtype=np.float64)
1632
+ if np is not None
1633
+ else [0.0 for _ in range(feature_count)]
1634
+ )
1635
+ if np is not None:
1636
+ label_total = max(float(target_label_mass.sum()), 1.0)
1637
+ readout_bias_values = target_label_mass / label_total
1638
+ else:
1639
+ label_total = max(sum(target_label_mass), 1.0)
1640
+ readout_bias_values = [value / label_total for value in target_label_mass]
1641
+ readout_weights = ridge_regression_readout_from_diagonal_moments(
1642
+ masked_feature_second_moment,
1643
+ masked_cross,
1644
+ regularization=config.regularization,
1645
+ )
1646
+ finish_stage("state_and_readout")
1647
+
1648
+ model.ternary_scale = ternary_scale
1649
+ model.ternary_mask = ternary_mask
1650
+ model.readout_weights = readout_weights
1651
+ model.state_offset = (
1652
+ state_offset_values.tolist()
1653
+ if hasattr(state_offset_values, "tolist")
1654
+ else list(state_offset_values)
1655
+ )
1656
+ model.readout_bias = (
1657
+ readout_bias_values.tolist()
1658
+ if hasattr(readout_bias_values, "tolist")
1659
+ else list(readout_bias_values)
1660
+ )
1661
+ model.preference_bias, preference_state_pairs = _derive_preference_bias_from_pairs(
1662
+ model,
1663
+ preference_token_pairs,
1664
+ tokenizer,
1665
+ )
1666
+ finish_stage("preference")
1667
+ reservoir_states = answer_reservoir.states + general_reservoir.states
1668
+ reservoir_labels = answer_reservoir.labels + general_reservoir.labels
1669
+ answer_intent_states = answer_intent_reservoir.states
1670
+ answer_intent_labels = answer_intent_reservoir.labels
1671
+ answer_start_states = answer_start_reservoir.states
1672
+ answer_start_labels = answer_start_reservoir.labels
1673
+ answer_sequence_states = answer_sequence_reservoir.keys
1674
+ answer_sequence_prompt_rows = answer_sequence_reservoir.prompt_rows
1675
+ answer_sequence_rows = answer_sequence_reservoir.token_rows
1676
+ prompt_answer_weights, prompt_answer_bias, prompt_answer_readout_examples = (
1677
+ _solve_weighted_prompt_readout(
1678
+ answer_intent_states,
1679
+ answer_intent_labels,
1680
+ answer_intent_reservoir.weights,
1681
+ vocab_size=len(embedding_model.id_to_token),
1682
+ diagonal=diagonal,
1683
+ state_offset=state_offset_values,
1684
+ regularization=config.regularization,
1685
+ )
1686
+ )
1687
+ (
1688
+ prompt_answer_start_weights,
1689
+ prompt_answer_start_bias,
1690
+ prompt_answer_start_readout_examples,
1691
+ ) = _solve_weighted_prompt_readout(
1692
+ answer_start_states,
1693
+ answer_start_labels,
1694
+ answer_start_reservoir.weights,
1695
+ vocab_size=len(embedding_model.id_to_token),
1696
+ diagonal=diagonal,
1697
+ state_offset=state_offset_values,
1698
+ regularization=config.regularization,
1699
+ )
1700
+ model.prompt_answer_weights = prompt_answer_weights
1701
+ model.prompt_answer_bias = (
1702
+ prompt_answer_bias.tolist()
1703
+ if hasattr(prompt_answer_bias, "tolist")
1704
+ else list(prompt_answer_bias)
1705
+ )
1706
+ model.prompt_answer_start_weights = prompt_answer_start_weights
1707
+ model.prompt_answer_start_bias = (
1708
+ prompt_answer_start_bias.tolist()
1709
+ if hasattr(prompt_answer_start_bias, "tolist")
1710
+ else list(prompt_answer_start_bias)
1711
+ )
1712
+ if np is not None and reservoir_states:
1713
+ reservoir_array = np.asarray(reservoir_states, dtype=RUNTIME_ARRAY_DTYPE)
1714
+ mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
1715
+ offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
1716
+ associative_array = ((reservoir_array * mask_array[None, :]) - offset_array[None, :]).astype(
1717
+ RUNTIME_ARRAY_DTYPE,
1718
+ copy=False,
1719
+ )
1720
+ model.associative_keys = associative_array
1721
+ model.associative_key_norms = np.linalg.norm(associative_array, axis=1).tolist()
1722
+ else:
1723
+ offset_vector = model.state_offset
1724
+ model.associative_keys = [
1725
+ [
1726
+ value - offset_vector[index]
1727
+ for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
1728
+ ]
1729
+ for state in reservoir_states
1730
+ ]
1731
+ model.associative_key_norms = [norm(state) for state in model.associative_keys]
1732
+ model.associative_values = reservoir_labels[:]
1733
+ if np is not None and answer_intent_states:
1734
+ answer_intent_array = np.asarray(answer_intent_states, dtype=RUNTIME_ARRAY_DTYPE)
1735
+ mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
1736
+ offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
1737
+ answer_array = ((answer_intent_array * mask_array[None, :]) - offset_array[None, :]).astype(
1738
+ RUNTIME_ARRAY_DTYPE,
1739
+ copy=False,
1740
+ )
1741
+ model.answer_keys = answer_array
1742
+ model.answer_key_norms = np.linalg.norm(answer_array, axis=1).tolist()
1743
+ else:
1744
+ offset_vector = model.state_offset
1745
+ model.answer_keys = [
1746
+ [
1747
+ value - offset_vector[index]
1748
+ for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
1749
+ ]
1750
+ for state in answer_intent_states
1751
+ ]
1752
+ model.answer_key_norms = [norm(state) for state in model.answer_keys]
1753
+ model.answer_values = answer_intent_labels[:]
1754
+ if np is not None and answer_start_states:
1755
+ answer_start_array = np.asarray(answer_start_states, dtype=RUNTIME_ARRAY_DTYPE)
1756
+ mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
1757
+ offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
1758
+ start_array = ((answer_start_array * mask_array[None, :]) - offset_array[None, :]).astype(
1759
+ RUNTIME_ARRAY_DTYPE,
1760
+ copy=False,
1761
+ )
1762
+ model.answer_start_keys = start_array
1763
+ model.answer_start_key_norms = np.linalg.norm(start_array, axis=1).tolist()
1764
+ else:
1765
+ offset_vector = model.state_offset
1766
+ model.answer_start_keys = [
1767
+ [
1768
+ value - offset_vector[index]
1769
+ for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
1770
+ ]
1771
+ for state in answer_start_states
1772
+ ]
1773
+ model.answer_start_key_norms = [norm(state) for state in model.answer_start_keys]
1774
+ model.answer_start_values = answer_start_labels[:]
1775
+ if np is not None and answer_sequence_states:
1776
+ answer_sequence_array = np.asarray(answer_sequence_states, dtype=RUNTIME_ARRAY_DTYPE)
1777
+ mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
1778
+ offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
1779
+ sequence_array = ((answer_sequence_array * mask_array[None, :]) - offset_array[None, :]).astype(
1780
+ RUNTIME_ARRAY_DTYPE,
1781
+ copy=False,
1782
+ )
1783
+ model.answer_sequence_keys = sequence_array
1784
+ model.answer_sequence_key_norms = np.linalg.norm(sequence_array, axis=1).tolist()
1785
+ else:
1786
+ offset_vector = model.state_offset
1787
+ model.answer_sequence_keys = [
1788
+ [
1789
+ value - offset_vector[index]
1790
+ for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
1791
+ ]
1792
+ for state in answer_sequence_states
1793
+ ]
1794
+ model.answer_sequence_key_norms = [norm(state) for state in model.answer_sequence_keys]
1795
+ if np is not None:
1796
+ padded_answer_sequences = np.full(
1797
+ (len(answer_sequence_rows), MAX_ANSWER_SEQUENCE_TOKENS),
1798
+ -1,
1799
+ dtype=np.int32,
1800
+ )
1801
+ for row_index, row in enumerate(answer_sequence_rows):
1802
+ row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS)
1803
+ if row_width > 0:
1804
+ padded_answer_sequences[row_index, :row_width] = row[:row_width]
1805
+ padded_answer_sequence_prompts = np.full(
1806
+ (len(answer_sequence_prompt_rows), MAX_ANSWER_SEQUENCE_TOKENS),
1807
+ -1,
1808
+ dtype=np.int32,
1809
+ )
1810
+ for row_index, row in enumerate(answer_sequence_prompt_rows):
1811
+ row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS)
1812
+ if row_width > 0:
1813
+ padded_answer_sequence_prompts[row_index, :row_width] = row[:row_width]
1814
+ else:
1815
+ padded_answer_sequences = [
1816
+ row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))]
1817
+ for row in answer_sequence_rows
1818
+ ]
1819
+ padded_answer_sequence_prompts = [
1820
+ row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))]
1821
+ for row in answer_sequence_prompt_rows
1822
+ ]
1823
+ model.answer_sequence_prompt_tokens = padded_answer_sequence_prompts
1824
+ model.answer_sequence_tokens = padded_answer_sequences
1825
+ model.transition_tables = transitions.finalize(
1826
+ max_contexts_per_order=config.max_transition_contexts_per_order,
1827
+ max_next_tokens=config.max_transition_next_tokens,
1828
+ )
1829
+ finish_stage("model_finalize")
1830
+
1831
+ payload = {
1832
+ "streaming": True,
1833
+ "documents_processed": processed,
1834
+ "source_counts": source_counts,
1835
+ "embedding_vocab_size": len(embedding_model.id_to_token),
1836
+ "tokenizer_vocab_size": tokenizer.vocab_size,
1837
+ "examples_processed": int(round(example_weight_total)),
1838
+ "associative_examples": len(model.associative_keys),
1839
+ "answer_associative_examples": len(answer_reservoir.states),
1840
+ "general_associative_examples": len(general_reservoir.states),
1841
+ "answer_intent_examples": len(model.answer_keys),
1842
+ "answer_start_examples": len(model.answer_start_keys),
1843
+ "answer_sequence_examples": len(model.answer_sequence_keys),
1844
+ "prompt_answer_readout_examples": prompt_answer_readout_examples,
1845
+ "prompt_answer_start_readout_examples": prompt_answer_start_readout_examples,
1846
+ "stage_seconds": stage_seconds,
1847
+ "target_balance_reference": round(float(reference_label_mass), 6),
1848
+ "readout_solver": readout_solver,
1849
+ "preference_pairs": len(preference_token_pairs),
1850
+ "preference_state_pairs": preference_state_pairs,
1851
+ }
1852
+ return model, payload
reframr/ternary.py ADDED
@@ -0,0 +1,63 @@
1
+ import math
2
+
3
+ from .linalg import Vector, mean
4
+
5
+
6
+ def quantize_vector_absmean(
7
+ values: Vector,
8
+ *,
9
+ threshold: float = 0.5,
10
+ ) -> tuple[float, list[int]]:
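+ # Absmean ternary quantization: scale by the mean absolute value, then
+ # snap each entry to -1, 0, or +1 using the threshold on the scaled value.
+ # e.g. quantize_vector_absmean([0.9, -0.05, -1.1]) -> (~0.683, [1, 0, -1])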
11
+ if not values:
12
+ return 1.0, []
13
+
14
+ scale = mean([abs(value) for value in values])
15
+ if scale == 0.0:
16
+ return 1.0, [0 for _ in values]
17
+
18
+ quantized: list[int] = []
19
+ for value in values:
20
+ normalized = value / scale
21
+ if normalized >= threshold:
22
+ quantized.append(1)
23
+ elif normalized <= -threshold:
24
+ quantized.append(-1)
25
+ else:
26
+ quantized.append(0)
27
+ return scale, quantized
28
+
29
+
30
+ def derive_ternary_mask_from_states(states: list[Vector]) -> tuple[float, list[int]]:
31
+ if not states:
32
+ return 1.0, []
33
+ feature_count = len(states[0])
34
+ feature_energy = [
35
+ mean([state[feature] * state[feature] for state in states])
36
+ for feature in range(feature_count)
37
+ ]
38
+ return derive_ternary_mask_from_feature_energy(feature_energy)
39
+
40
+
41
+ def derive_ternary_mask_from_feature_energy(
42
+ feature_energy: Vector,
43
+ *,
44
+ threshold: float = 0.02,
45
+ ) -> tuple[float, list[int]]:
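+ # Keep features whose RMS energy reaches `threshold` times the mean RMS;
+ # the returned mask is binary (0/1) with the scale fixed at 1.0, and it
+ # degrades to all-ones if no feature clears the threshold.
+ # e.g. derive_ternary_mask_from_feature_energy([1.0, 0.0]) -> (1.0, [1, 0])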
46
+ if not feature_energy:
47
+ return 1.0, []
48
+
49
+ rms_values = [math.sqrt(max(value, 0.0)) for value in feature_energy]
50
+ scale = mean(rms_values)
51
+ if scale == 0.0:
52
+ return 1.0, [0 for _ in feature_energy]
53
+
54
+ mask = [1 if value >= threshold * scale else 0 for value in rms_values]
55
+ if not any(mask):
56
+ mask = [1 for _ in feature_energy]
57
+ return 1.0, mask
58
+
59
+
60
+ def apply_ternary_mask(values: Vector, mask: list[int], scale: float) -> Vector:
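+ # Elementwise gate: each value is multiplied by scale * mask[i]; an empty
+ # mask returns an unmodified copy of the input.
+ # e.g. apply_ternary_mask([2.0, 3.0], [1, 0], 0.5) -> [1.0, 0.0]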
61
+ if not mask:
62
+ return values[:]
63
+ return [scale * mask[index] * values[index] for index in range(len(values))]
reframr/text_quality.py ADDED
@@ -0,0 +1,98 @@
1
+ import re
2
+
3
+
4
+ REFRAMR_NAME_PATTERN = re.compile(r"\breframr\b", re.IGNORECASE)
5
+ LINE_ROLE_PREFIX_PATTERN = re.compile(
6
+ r"(?im)^\s*(?:user|assistant|human|system|bot|model|gpt)\s*:\s*"
7
+ )
8
+ STRUCTURAL_ROLE_PREFIX_PATTERN = re.compile(
9
+ r"(?i)(<(?:reason|answer)>\s+)(?:user|assistant|human|system|bot|model|gpt)\s*:\s*"
10
+ )
11
+ SYSTEM_SCAFFOLD_LINE_PATTERN = re.compile(
12
+ r"(?i)^\s*(?:"
13
+ r"you\s+are\s+(?:an?\s+)?(?:helpful\s+)?(?:ai\s+)?assistant\b.*|"
14
+ r"your\s+role\s+as\s+an\s+assistant\s+involves\b.*|"
15
+ r"you\s+will\s+be\s+given\s+a\s+task\b.*|"
16
+ r"your\s+goal\s+is\s+to\s+complete\s+the\s+task\b.*|"
17
+ r"you\s+must\s+generate\s+a\s+detailed\s+and\s+long\s+answer\b.*|"
18
+ r"please\s+structure\s+your\s+response\s+into\s+two\s+main\s+sections\b.*|"
19
+ r"in\s+the\s+thought\s+section\b.*|"
20
+ r"in\s+the\s+solution\s+section\b.*|"
21
+ r"now,\s*try\s+to\s+solve\s+the\s+following\s+question\b.*|"
22
+ r"while\s+answering\s+think\s+step\s*[- ]?\s*by\s*[- ]?\s*step\b.*|"
23
+ r"think\s+like\s+you\s+are\s+answering\b.*"
24
+ r")\s*$"
25
+ )
26
+ OPEN_SOLUTION_PATTERN = re.compile(
27
+ r"(?is)<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>"
28
+ )
29
+ OPEN_THOUGHT_PATTERN = re.compile(
30
+ r"(?is)<\|begin_of_thought\|>.*?<\|end_of_thought\|>"
31
+ )
32
+ OPEN_TAG_PATTERN = re.compile(r"(?is)<\|[^>]+?\|>")
33
+ LEADING_ASSISTANT_FILLER_PATTERN = re.compile(
34
+ r"(?is)^\s*(?:sure(?:\s+thing)?|certainly|absolutely|of\s+course|yes)\s*[!,.:-]*\s+"
35
+ )
36
+ MOJIBAKE_MARKERS = ("â", "Ã", "Â", "â", "Ã", "Â")
37
+
38
+
39
+ def canonicalize_reframr_name(text: str) -> str:
40
+ return REFRAMR_NAME_PATTERN.sub("Reframr", text)
41
+
42
+
43
+ def repair_common_mojibake(text: str) -> str:
44
+ repaired = text
45
+ for _ in range(3):
46
+ if not any(marker in repaired for marker in MOJIBAKE_MARKERS):
47
+ break
48
+ original_markers = sum(repaired.count(marker) for marker in MOJIBAKE_MARKERS)
49
+ best = repaired
50
+ best_markers = original_markers
51
+ for encoding in ("cp1252", "latin1"):
52
+ try:
53
+ candidate = repaired.encode(encoding).decode("utf-8")
54
+ except UnicodeError:
55
+ continue
56
+ candidate_markers = sum(candidate.count(marker) for marker in MOJIBAKE_MARKERS)
57
+ if candidate_markers < best_markers:
58
+ best = candidate
59
+ best_markers = candidate_markers
60
+ if best == repaired:
61
+ break
62
+ repaired = best
63
+ return repaired
64
+
65
+
66
+ def strip_role_prefixes(text: str) -> str:
67
+ cleaned = STRUCTURAL_ROLE_PREFIX_PATTERN.sub(r"\1", text)
68
+ return LINE_ROLE_PREFIX_PATTERN.sub("", cleaned).strip()
69
+
70
+
71
+ def strip_instruction_scaffold(text: str) -> str:
72
+ lines = []
73
+ for line in text.splitlines():
74
+ if SYSTEM_SCAFFOLD_LINE_PATTERN.match(line):
75
+ continue
76
+ lines.append(line)
77
+ return "\n".join(lines).strip()
78
+
79
+
80
+ def clean_training_text(text: str) -> str:
81
+ repaired = repair_common_mojibake(text)
82
+ return strip_role_prefixes(canonicalize_reframr_name(repaired)).strip()
83
+
84
+
85
+ def clean_context_text(text: str) -> str:
86
+ return strip_instruction_scaffold(clean_training_text(text))
87
+
88
+
89
+ def clean_answer_text(text: str) -> str:
90
+ cleaned = clean_training_text(text)
91
+ solution_match = OPEN_SOLUTION_PATTERN.search(cleaned)
92
+ if solution_match:
93
+ cleaned = solution_match.group(1)
94
+ else:
95
+ cleaned = OPEN_THOUGHT_PATTERN.sub("", cleaned)
96
+ cleaned = OPEN_TAG_PATTERN.sub("", cleaned)
97
+ cleaned = LEADING_ASSISTANT_FILLER_PATTERN.sub("", cleaned)
98
+ return cleaned.strip()
reframr/tokenizer.py ADDED
@@ -0,0 +1,665 @@
1
+ import re
2
+ import unicodedata
3
+ from collections import Counter
4
+ from collections.abc import Mapping
5
+ from dataclasses import dataclass, field
6
+ from string import ascii_letters, digits
7
+
8
+ from .reasoning import REASONING_CONTROL_TOKENS, TOKENIZER_NAME
9
+
10
+ PRETOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
11
+ BYTE_FALLBACK_PATTERN = re.compile(r"<byte:([0-9A-F]{2})>")
12
+ DEFAULT_FALLBACK_CHARACTERS = (
13
+ ascii_letters
14
+ + digits
15
+ + "'-_/.:,;!?()[]{}@#$%&*+="
16
+ + "’ʼ‘“”—–…"
17
+ )
18
+ MAX_TOKENIZER_VOCAB_SIZE = 65536
19
+ MAX_SEGMENT_CACHE_SIZE = 200_000
20
+ MAX_TRAINED_PAIR_MERGES = 384
21
+
22
+
23
+ def _is_word_character(character: str) -> bool:
24
+ category = unicodedata.category(character)
25
+ return character == "_" or category[0] in {"L", "N"} or category == "Mn"
26
+
27
+
28
+ def _is_variation_selector(character: str) -> bool:
29
+ return "VARIATION SELECTOR" in unicodedata.name(character, "")
30
+
31
+
32
+ def _is_zero_width_joiner(character: str) -> bool:
33
+ return unicodedata.name(character, "") == "ZERO WIDTH JOINER"
34
+
35
+
36
+ def _is_emoji_modifier(character: str) -> bool:
37
+ return "EMOJI MODIFIER" in unicodedata.name(character, "")
38
+
39
+
40
+ def _is_emoji_base_character(character: str) -> bool:
41
+ name = unicodedata.name(character, "")
42
+ category = unicodedata.category(character)
43
+ return (
44
+ "EMOJI" in name
45
+ or "REGIONAL INDICATOR SYMBOL" in name
46
+ or (category in {"So", "Sk"} and ord(character) >= 0x2100)
47
+ )
48
+
49
+
50
+ def _is_emoji_continuation_character(character: str) -> bool:
51
+ category = unicodedata.category(character)
52
+ name = unicodedata.name(character, "")
53
+ return (
54
+ _is_variation_selector(character)
55
+ or _is_zero_width_joiner(character)
56
+ or _is_emoji_modifier(character)
57
+ or category in {"Mn", "Me"}
58
+ or name.startswith("TAG ")
59
+ )
60
+
61
+
62
+ def _consume_emoji_cluster(text: str, start: int) -> int:
63
+ if start >= len(text) or not _is_emoji_base_character(text[start]):
64
+ return start
65
+
66
+ index = start + 1
67
+ if "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[start], ""):
68
+ if index < len(text) and "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[index], ""):
69
+ return index + 1
70
+ return index
71
+
72
+ while index < len(text):
73
+ if _is_emoji_continuation_character(text[index]):
74
+ index += 1
75
+ continue
76
+ if _is_zero_width_joiner(text[index - 1]) and _is_emoji_base_character(text[index]):
77
+ index += 1
78
+ continue
79
+ break
80
+ return index
81
+
82
+
83
+ def _byte_token(value: int) -> str:
84
+ return f"<byte:{value:02X}>"
85
+
86
+
87
+ def _byte_value(piece: str) -> int | None:
88
+ match = BYTE_FALLBACK_PATTERN.fullmatch(piece)
89
+ if match is None:
90
+ return None
91
+ return int(match.group(1), 16)
92
+
93
+
94
+ def _is_punctuation_piece(piece: str) -> bool:
95
+ return bool(piece) and all(
96
+ unicodedata.category(character).startswith("P")
97
+ for character in piece
98
+ )
99
+
100
+
101
+ def _is_opening_punctuation(piece: str) -> bool:
102
+ return bool(piece) and all(
103
+ unicodedata.category(character) in {"Ps", "Pi"}
104
+ for character in piece
105
+ )
106
+
107
+
108
+ def _is_call_opening_punctuation(piece: str) -> bool:
109
+ return bool(piece) and all(
110
+ unicodedata.category(character) == "Ps"
111
+ and "PARENTHESIS" in unicodedata.name(character, "")
112
+ for character in piece
113
+ )
114
+
115
+
116
+ def _is_closing_or_terminal_punctuation(piece: str) -> bool:
117
+ return bool(piece) and all(
118
+ unicodedata.category(character) in {"Pe", "Pf", "Po"}
119
+ for character in piece
120
+ )
121
+
122
+
123
+ def _is_infix_joiner(piece: str) -> bool:
124
+ if len(piece) != 1:
125
+ return False
126
+ category = unicodedata.category(piece)
127
+ name = unicodedata.name(piece, "")
128
+ return (
129
+ category == "Pd"
130
+ or "APOSTROPHE" in name
131
+ or (category == "Pf" and "SINGLE QUOTATION MARK" in name)
132
+ or "SOLIDUS" in name
133
+ )
134
+
135
+
136
+ def _is_dash_joiner(piece: str) -> bool:
137
+ if len(piece) != 1:
138
+ return False
139
+ category = unicodedata.category(piece)
140
+ name = unicodedata.name(piece, "")
141
+ return category == "Pd" or "HYPHEN" in name or "DASH" in name
142
+
143
+
144
+ def _is_quote_piece(piece: str) -> bool:
145
+ if len(piece) != 1:
146
+ return False
147
+ if _is_infix_joiner(piece):
148
+ return False
149
+ name = unicodedata.name(piece, "")
150
+ category = unicodedata.category(piece)
151
+ return "QUOTATION MARK" in name or category in {"Pi", "Pf"}
152
+
153
+
154
+ def _merge_symbol(left: str, right: str, prefix: str) -> str:
155
+ if right.startswith(prefix):
156
+ return left + right[len(prefix):]
157
+ return left + right
158
+
159
+
160
+ def _merge_sequence(symbols: list[str], pair: tuple[str, str], merged_symbol: str) -> list[str]:
161
+ merged: list[str] = []
162
+ index = 0
163
+ while index < len(symbols):
164
+ if index < len(symbols) - 1 and (symbols[index], symbols[index + 1]) == pair:
165
+ merged.append(merged_symbol)
166
+ index += 2
167
+ else:
168
+ merged.append(symbols[index])
169
+ index += 1
170
+ return merged
171
+
172
+
173
+ def _default_symbol_inventory(word_prefix: str) -> set[str]:
174
+ symbols: set[str] = set()
175
+ for character in DEFAULT_FALLBACK_CHARACTERS:
176
+ symbols.add(character)
177
+ symbols.add(f"{word_prefix}{character}")
178
+ for value in range(256):
179
+ token = _byte_token(value)
180
+ symbols.add(token)
181
+ symbols.add(f"{word_prefix}{token}")
182
+ return symbols
183
+
184
+
185
+ def _whole_segment_token(segment: str, word_prefix: str) -> str:
186
+ return f"{word_prefix}{segment}"
187
+
188
+
189
+ def recommend_vocab_size(
190
+ text: str,
191
+ *,
192
+ minimum: int = 768,
193
+ maximum: int = 1536,
194
+ multiplier: int = 5,
195
+ lowercase: bool = False,
196
+ ) -> int:
197
+ seed_tokenizer = NativeTokenizer(
198
+ merges=[],
199
+ vocab=[],
200
+ base_symbols=[],
201
+ lowercase=lowercase,
202
+ )
203
+ segments = seed_tokenizer.pretokenize(text)
204
+ distinct_segments = len(set(segments))
205
+ recommended = max(minimum, distinct_segments * multiplier)
206
+ return min(maximum, recommended)
207
+
208
+
209
+ def clamp_vocab_size(requested: int, *, maximum: int = MAX_TOKENIZER_VOCAB_SIZE) -> int:
210
+ return min(maximum, max(1, requested))
211
+
212
+
213
+ @dataclass(slots=True)
214
+ class NativeTokenizer:
215
+ merges: list[tuple[str, str]]
216
+ vocab: list[str]
217
+ base_symbols: list[str]
218
+ name: str = TOKENIZER_NAME
219
+ lowercase: bool = False
220
+ word_prefix: str = "▁"
221
+ unk_token: str = "<unk>"
222
+ bos_token: str = "<bos>"
223
+ eos_token: str = "<eos>"
224
+ pad_token: str = "<pad>"
225
+ _merge_ranks: dict[tuple[str, str], int] = field(init=False, repr=False)
226
+ _vocab_set: set[str] = field(init=False, repr=False)
227
+ _base_symbol_set: set[str] = field(init=False, repr=False)
228
+ _pretoken_pattern: re.Pattern[str] = field(init=False, repr=False)
229
+ _segment_cache: dict[str, tuple[str, ...]] = field(init=False, repr=False)
230
+
231
+ def __post_init__(self) -> None:
232
+ self._merge_ranks = {pair: index for index, pair in enumerate(self.merges)}
233
+ self._base_symbol_set = set(self.base_symbols)
234
+ self._vocab_set = set(self.vocab) | self.special_tokens | self._base_symbol_set
235
+ self.vocab = sorted(self._vocab_set)
236
+ self._pretoken_pattern = self._build_pretoken_pattern()
237
+ self._segment_cache = {}
238
+
239
+ @property
240
+ def special_tokens(self) -> set[str]:
241
+ return {
242
+ self.unk_token,
243
+ self.bos_token,
244
+ self.eos_token,
245
+ self.pad_token,
246
+ *REASONING_CONTROL_TOKENS,
247
+ }
248
+
249
+ @property
250
+ def vocab_size(self) -> int:
251
+ return len(self._vocab_set)
252
+
253
+ def normalize(self, text: str) -> str:
254
+ normalized = unicodedata.normalize("NFKC", text)
255
+ return normalized.lower() if self.lowercase else normalized
256
+
257
+ def pretokenize(self, text: str) -> list[str]:
258
+ normalized = self.normalize(text)
259
+ segments: list[str] = []
260
+ reserved = sorted(self.special_tokens, key=len, reverse=True)
261
+ index = 0
262
+ while index < len(normalized):
263
+ if normalized[index].isspace():
264
+ if normalized[index] == "\r":
265
+ if index + 1 < len(normalized) and normalized[index + 1] == "\n":
266
+ segments.append("\n")
267
+ index += 2
268
+ continue
269
+ segments.append("\n")
270
+ index += 1
271
+ continue
272
+ if normalized[index] == "\n":
273
+ segments.append("\n")
274
+ index += 1
275
+ continue
276
+ index += 1
277
+ continue
278
+
279
+ matched_special = next(
280
+ (
281
+ token
282
+ for token in reserved
283
+ if normalized.startswith(token, index)
284
+ ),
285
+ None,
286
+ )
287
+ if matched_special is not None:
288
+ segments.append(matched_special)
289
+ index += len(matched_special)
290
+ continue
291
+
292
+ emoji_end = _consume_emoji_cluster(normalized, index)
293
+ if emoji_end > index:
294
+ segments.append(normalized[index:emoji_end])
295
+ index = emoji_end
296
+ continue
297
+
298
+ if _is_word_character(normalized[index]):
299
+ start = index
300
+ index += 1
301
+ while index < len(normalized) and _is_word_character(normalized[index]):
302
+ index += 1
303
+ segments.append(normalized[start:index])
304
+ continue
305
+
306
+ segments.append(normalized[index])
307
+ index += 1
308
+ return segments
309
+
310
+ def encode(self, text: str, *, add_special_tokens: bool = False) -> list[str]:
311
+ tokens: list[str] = []
312
+ if add_special_tokens:
313
+ tokens.append(self.bos_token)
314
+
315
+ for segment in self.pretokenize(text):
316
+ tokens.extend(self._encode_segment_cached(segment))
317
+
318
+ if add_special_tokens:
319
+ tokens.append(self.eos_token)
320
+
321
+ if not tokens and text.strip():
322
+ return [self.unk_token]
323
+ return tokens
324
+
325
+ def encode_many(
326
+ self,
327
+ texts: list[str] | tuple[str, ...],
328
+ *,
329
+ add_special_tokens: bool = False,
330
+ ) -> list[list[str]]:
331
+ return [
332
+ self.encode(text, add_special_tokens=add_special_tokens)
333
+ for text in texts
334
+ ]
335
+
336
+ def decode(self, tokens: list[str]) -> str:
337
+ text = ""
338
+ join_next = False
339
+ byte_buffer = bytearray()
340
+ byte_starts_segment = False
341
+
342
+ def next_rendered_piece(start_index: int) -> str | None:
343
+ for raw_token in tokens[start_index:]:
344
+ if raw_token in self.special_tokens:
345
+ continue
346
+ raw_starts_segment = raw_token.startswith(self.word_prefix)
347
+ raw_piece = raw_token[len(self.word_prefix) :] if raw_starts_segment else raw_token
348
+ if not raw_piece:
349
+ continue
350
+ if _byte_value(raw_piece) is not None:
351
+ return None
352
+ return raw_piece
353
+ return None
354
+
355
+ def append_piece(piece: str, starts_segment: bool, next_piece: str | None = None) -> None:
356
+ nonlocal text, join_next
357
+
358
+ if piece == "\n":
359
+ text = text.rstrip(" ")
360
+ text += "\n"
361
+ join_next = True
362
+ return
363
+
364
+ had_text_before_piece = bool(text.strip())
365
+ previous_before_piece = text.rstrip(" ")[-1:] if text.strip(" ") else ""
366
+ if _is_quote_piece(piece):
367
+ quote_count = sum(1 for character in text if _is_quote_piece(character))
368
+ opens_quote = quote_count % 2 == 0
369
+ if opens_quote:
370
+ if text and not text.endswith((" ", "\n")) and previous_before_piece not in {"(", "[", "{"}:
371
+ text += " "
372
+ text += piece
373
+ join_next = True
374
+ return
375
+ text = text.rstrip(" ")
376
+ text += piece
377
+ join_next = False
378
+ return
379
+
380
+ attaches_left = _is_closing_or_terminal_punctuation(piece) or _is_infix_joiner(piece)
381
+ continues_segment = (not starts_segment) and any(
382
+ _is_word_character(character) or _is_emoji_continuation_character(character)
383
+ for character in piece
384
+ )
385
+ if starts_segment:
386
+ if text and not join_next:
387
+ attaches_to_previous_code_span = (
388
+ _is_opening_punctuation(piece)
389
+ and previous_before_piece.isalnum()
390
+ and next_piece is not None
391
+ and (
392
+ _is_infix_joiner(next_piece)
393
+ or _is_call_opening_punctuation(piece)
394
+ )
395
+ )
396
+ if not _is_punctuation_piece(piece) or (
397
+ _is_opening_punctuation(piece)
398
+ and not attaches_to_previous_code_span
399
+ ):
400
+ text += " "
401
+ text += piece
402
+ else:
403
+ if text and not join_next and not attaches_left and not continues_segment:
404
+ text += " "
405
+ text += piece
406
+
407
+ join_next = (
408
+ _is_infix_joiner(piece)
409
+ and (
410
+ not starts_segment
411
+ or (
412
+ had_text_before_piece
413
+ and (
414
+ not _is_dash_joiner(piece)
415
+ or previous_before_piece.isalnum()
416
+ or _is_opening_punctuation(previous_before_piece)
417
+ )
418
+ )
419
+ )
420
+ ) or _is_opening_punctuation(piece)
421
+
422
+ def flush_bytes() -> None:
423
+ nonlocal byte_buffer, byte_starts_segment
424
+ if not byte_buffer:
425
+ return
426
+ append_piece(bytes(byte_buffer).decode("utf-8", errors="replace"), byte_starts_segment)
427
+ byte_buffer = bytearray()
428
+ byte_starts_segment = False
429
+
430
+ for token_index, token in enumerate(tokens):
431
+ if token in self.special_tokens:
432
+ continue
433
+ starts_segment = token.startswith(self.word_prefix)
434
+ piece = token[len(self.word_prefix) :] if starts_segment else token
435
+ if not piece:
436
+ continue
437
+ byte_value = _byte_value(piece)
438
+ if byte_value is not None:
439
+ if not byte_buffer:
440
+ byte_starts_segment = starts_segment
441
+ byte_buffer.append(byte_value)
442
+ continue
443
+
444
+ flush_bytes()
445
+ append_piece(piece, starts_segment, next_rendered_piece(token_index + 1))
446
+ flush_bytes()
447
+ return text.strip()
448
+
449
+ def _encode_segment_cached(self, segment: str) -> tuple[str, ...]:
450
+ cached = self._segment_cache.get(segment)
451
+ if cached is not None:
452
+ return cached
453
+ encoded = tuple(self._encode_segment(segment))
454
+ if len(self._segment_cache) < MAX_SEGMENT_CACHE_SIZE:
455
+ self._segment_cache[segment] = encoded
456
+ return encoded
457
+
458
+ def _encode_segment(self, segment: str) -> list[str]:
459
+ if segment in self.special_tokens:
460
+ return [segment]
461
+ whole_segment = _whole_segment_token(segment, self.word_prefix)
462
+ if whole_segment in self._vocab_set:
463
+ return [whole_segment]
464
+ symbols = self._seed_symbols(segment)
465
+ if not symbols:
466
+ return []
467
+
468
+ while len(symbols) > 1:
469
+ best_rank: int | None = None
470
+ best_pair: tuple[str, str] | None = None
471
+ for index in range(len(symbols) - 1):
472
+ pair = (symbols[index], symbols[index + 1])
473
+ rank = self._merge_ranks.get(pair)
474
+ if rank is None:
475
+ continue
476
+ if best_rank is None or rank < best_rank:
477
+ best_rank = rank
478
+ best_pair = pair
479
+ if best_pair is None:
480
+ break
481
+
482
+ merged_symbol = _merge_symbol(best_pair[0], best_pair[1], self.word_prefix)
483
+ symbols = _merge_sequence(symbols, best_pair, merged_symbol)
484
+
485
+ if any(symbol not in self._vocab_set for symbol in symbols):
486
+ return [self.unk_token]
487
+ return symbols
488
+
489
+ def _seed_symbols(self, segment: str) -> list[str]:
490
+ symbols: list[str] = []
491
+ for index, character in enumerate(segment):
492
+ symbol = f"{self.word_prefix}{character}" if index == 0 else character
493
+ if symbol in self._base_symbol_set:
494
+ symbols.append(symbol)
495
+ continue
496
+
497
+ encoded = character.encode("utf-8")
498
+ for byte_index, value in enumerate(encoded):
499
+ token = _byte_token(value)
500
+ if index == 0 and byte_index == 0:
501
+ token = f"{self.word_prefix}{token}"
502
+ symbols.append(token)
503
+
504
+ if any(symbol not in self._base_symbol_set for symbol in symbols):
505
+ return [self.unk_token]
506
+ return symbols
507
+
508
+ def to_dict(self) -> dict[str, object]:
509
+ return {
510
+ "name": self.name,
511
+ "merges": [[left, right] for left, right in self.merges],
512
+ "vocab": self.vocab,
513
+ "base_symbols": self.base_symbols,
514
+ "lowercase": self.lowercase,
515
+ "word_prefix": self.word_prefix,
516
+ "unk_token": self.unk_token,
517
+ "bos_token": self.bos_token,
518
+ "eos_token": self.eos_token,
519
+ "pad_token": self.pad_token,
520
+ }
521
+
522
+ @classmethod
523
+ def from_dict(cls, payload: dict[str, object]) -> "NativeTokenizer":
524
+ return cls(
525
+ merges=[(str(left), str(right)) for left, right in payload["merges"]],
526
+ vocab=[str(token) for token in payload["vocab"]],
527
+ base_symbols=[str(token) for token in payload["base_symbols"]],
528
+ name=str(payload.get("name", TOKENIZER_NAME)),
529
+ lowercase=bool(payload["lowercase"]),
530
+ word_prefix=str(payload["word_prefix"]),
531
+ unk_token=str(payload["unk_token"]),
532
+ bos_token=str(payload["bos_token"]),
533
+ eos_token=str(payload["eos_token"]),
534
+ pad_token=str(payload["pad_token"]),
535
+ )
536
+
537
+ def _build_pretoken_pattern(self) -> re.Pattern[str]:
538
+ reserved = sorted(self.special_tokens, key=len, reverse=True)
539
+ if not reserved:
540
+ return PRETOKEN_PATTERN
541
+ reserved_pattern = "|".join(re.escape(token) for token in reserved)
542
+ return re.compile(f"{reserved_pattern}|\\w+|[^\\w\\s]", re.UNICODE)
543
+
544
+ @classmethod
545
+ def train(
546
+ cls,
547
+ text: str,
548
+ *,
549
+ vocab_size: int = 256,
550
+ min_pair_frequency: int = 2,
551
+ lowercase: bool = False,
552
+ word_prefix: str = "▁",
553
+ ) -> "NativeTokenizer":
554
+ seed_tokenizer = cls(
555
+ merges=[],
556
+ vocab=[],
557
+ base_symbols=[],
558
+ lowercase=lowercase,
559
+ word_prefix=word_prefix,
560
+ )
561
+ segments = seed_tokenizer.pretokenize(text)
562
+ if not segments:
563
+ raise ValueError("Cannot train the native tokenizer on empty text.")
564
+
565
+ return cls.train_from_segment_counts(
566
+ Counter(segments),
567
+ vocab_size=vocab_size,
568
+ min_pair_frequency=min_pair_frequency,
569
+ lowercase=lowercase,
570
+ word_prefix=word_prefix,
571
+ )
572
+
573
+ @classmethod
574
+ def train_from_segment_counts(
575
+ cls,
576
+ segment_counts: Mapping[str, float],
577
+ *,
578
+ vocab_size: int = 256,
579
+ min_pair_frequency: int = 2,
580
+ lowercase: bool = False,
581
+ word_prefix: str = "▁",
582
+ ) -> "NativeTokenizer":
583
+ if not segment_counts:
584
+ raise ValueError("Cannot train the native tokenizer on empty segment counts.")
585
+ seed_tokenizer = cls(
586
+ merges=[],
587
+ vocab=[],
588
+ base_symbols=[],
589
+ lowercase=lowercase,
590
+ word_prefix=word_prefix,
591
+ )
592
+
593
+ word_counts = Counter(
594
+ {
595
+ str(segment): float(frequency)
596
+ for segment, frequency in segment_counts.items()
597
+ if str(segment) and float(frequency) > 0.0
598
+ }
599
+ )
600
+ if not word_counts:
601
+ raise ValueError("Cannot train the native tokenizer on empty segment counts.")
602
+ observed_symbols = {
603
+ f"{word_prefix}{character}" if index == 0 else character
604
+ for segment in word_counts
605
+ for index, character in enumerate(segment)
606
+ }
607
+ base_symbols = _default_symbol_inventory(word_prefix)
608
+ base_symbols.update(observed_symbols)
609
+ sequences = {
610
+ segment: [
611
+ f"{word_prefix}{character}" if index == 0 else character
612
+ for index, character in enumerate(segment)
613
+ ]
614
+ for segment in word_counts
615
+ }
616
+ vocab = set(observed_symbols) | seed_tokenizer.special_tokens
617
+ target_vocab_size = len(vocab) + max(1, vocab_size)
618
+ segment_candidates = sorted(
619
+ {
620
+ segment
621
+ for segment, frequency in word_counts.items()
622
+ if len(segment) > 1 and frequency >= min_pair_frequency
623
+ },
624
+ key=lambda segment: (
625
+ -(word_counts[segment] * len(segment)),
626
+ -len(segment),
627
+ segment,
628
+ ),
629
+ )
630
+ for segment in segment_candidates:
631
+ if len(vocab) >= target_vocab_size:
632
+ break
633
+ vocab.add(_whole_segment_token(segment, word_prefix))
634
+ merges: list[tuple[str, str]] = []
635
+
636
+ while len(vocab) < target_vocab_size and len(merges) < MAX_TRAINED_PAIR_MERGES:
637
+ pair_counts: Counter[tuple[str, str]] = Counter()
638
+ for segment, frequency in word_counts.items():
639
+ symbols = sequences[segment]
640
+ for index in range(len(symbols) - 1):
641
+ pair_counts[(symbols[index], symbols[index + 1])] += frequency
642
+
643
+ if not pair_counts:
644
+ break
645
+
646
+ best_pair, best_count = min(
647
+ pair_counts.items(),
648
+ key=lambda item: (-item[1], item[0][0], item[0][1]),
649
+ )
650
+ if best_count < min_pair_frequency:
651
+ break
652
+
653
+ merged_symbol = _merge_symbol(best_pair[0], best_pair[1], word_prefix)
654
+ merges.append(best_pair)
655
+ vocab.add(merged_symbol)
656
+ for segment in sequences:
657
+ sequences[segment] = _merge_sequence(sequences[segment], best_pair, merged_symbol)
658
+
659
+ return cls(
660
+ merges=merges,
661
+ vocab=sorted(vocab),
662
+ base_symbols=sorted(base_symbols),
663
+ lowercase=lowercase,
664
+ word_prefix=word_prefix,
665
+ )
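
The tokenizer tail above covers encoding, decoding, and BPE-style training. A minimal usage sketch follows; the import path `reframr.tokenizer` is an assumption for illustration, since the diff does not show the module layout:

```python
# Usage sketch for the NativeTokenizer defined above.
# ASSUMPTION: the class is importable as reframr.tokenizer.NativeTokenizer;
# this release's diff does not confirm the exact module path.
from reframr.tokenizer import NativeTokenizer

corpus = (
    "reframr computes weights from data. "
    "reframr computes weights, not templates."
)

# train() pretokenizes the corpus, then greedily merges the most frequent
# adjacent symbol pairs until the vocabulary budget or merge cap is reached.
tokenizer = NativeTokenizer.train(corpus, vocab_size=64, min_pair_frequency=2)

tokens = tokenizer.encode("reframr computes weights", add_special_tokens=True)
print(tokens)                    # BOS/EOS plus learned subword symbols
print(tokenizer.decode(tokens))  # decode() drops special tokens, rejoins text
```

The round trip works because `decode` treats non-prefixed word pieces as continuations of the current segment, while characters outside the base symbol inventory fall back to UTF-8 byte tokens that `flush_bytes` reassembles.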
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ numpy>=2.1,<3
+ scipy>=1.14,<2
+ datasets>=4.1,<5
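
`requirements.txt` pins only these three libraries. A quick import check (a convenience sketch, not part of the release) confirms they resolved within the pinned ranges:

```python
# Sanity check: the pinned runtime dependencies import cleanly.
import datasets
import numpy
import scipy

print("numpy", numpy.__version__)        # expected 2.1 <= version < 3
print("scipy", scipy.__version__)        # expected 1.14 <= version < 2
print("datasets", datasets.__version__)  # expected 4.1 <= version < 5
```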
sample_prompts.jsonl ADDED
@@ -0,0 +1,5 @@
+ {"prompt":"Who are you, and what makes Reframr different from Transformer models?","max_tokens":90,"temperature":0.92}
+ {"system":"Answer with calm confidence and no hype.","prompt":"Explain why computed weights are different from memorized template responses.","max_tokens":100,"temperature":0.9}
+ {"prompt":"Tell a compact story about a city that stores its memories in rainwater.","max_tokens":120,"temperature":1.05,"decode_top_k":90}
+ {"system":"Use exactly one fitting emoji.","prompt":"Write a warm note to a teammate who fixed a hard bug.","max_tokens":70,"temperature":0.95}
+ {"prompt":"Give safe, defensive guidance for recognizing a phishing email without helping an attacker.","max_tokens":100,"temperature":0.88}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff