Release Reframr-RFM-v1-Base public checkpoint
Public v1 base release for Reframr RFM. Internal provenance: v95 computed checkpoint. Includes model.safetensors, tokenizer, runtime source, config, generation examples, and model card.
- README.md +141 -0
- config.json +107 -0
- examples/jsonl_serve.ps1 +7 -0
- examples/python_inference.py +44 -0
- generation_config.json +8 -0
- model.safetensors +3 -0
- pyproject.toml +17 -0
- reframr/__init__.py +32 -0
- reframr/__main__.py +5 -0
- reframr/checkpoint.py +274 -0
- reframr/cli.py +760 -0
- reframr/config.py +68 -0
- reframr/corpus.py +123 -0
- reframr/corpus_recipes.py +1257 -0
- reframr/curriculum.py +0 -0
- reframr/datasets.py +165 -0
- reframr/embeddings.py +457 -0
- reframr/evaluation.py +265 -0
- reframr/hf_import.py +662 -0
- reframr/hippo.py +145 -0
- reframr/linalg.py +271 -0
- reframr/model.py +0 -0
- reframr/reasoning.py +26 -0
- reframr/reservoir.py +94 -0
- reframr/streaming.py +1852 -0
- reframr/ternary.py +63 -0
- reframr/text_quality.py +98 -0
- reframr/tokenizer.py +665 -0
- requirements.txt +3 -0
- sample_prompts.jsonl +5 -0
- tokenizer.json +0 -0
README.md
ADDED
@@ -0,0 +1,141 @@
---
language:
- en
tags:
- reframr
- okeymeta
- non-transformer
- recurrent-memory
- computed-weights
- cpu-inference
- safetensors
library_name: reframr
pipeline_tag: text-generation
license: other
base_model: scratch
---

# Reframr-RFM-v1-Base

**Reframr-RFM-v1-Base** is the first public base checkpoint from **OkeyMeta Ltd** for the Reframr line of non-Transformer language models. Reframr is built from scratch around recurrent memory, computed weights, and data-derived structure rather than a Transformer attention stack.

This release is packaged as `model.safetensors` with the matching `tokenizer.json`, runtime source, config, and runnable examples. A larger production Reframr line, incorporating tool-use and web-freshness data, is being computed after this release.

## What It Is

Reframr-RFM stands for **Recurrent Flow Memory**. The model is designed around a persistent recurrent state instead of a fixed quadratic attention map, so the architecture has no fixed attention-window context limit; practical limits come from runtime session length, machine memory, and deployment policy.

This checkpoint is not a Transformer, not a fine-tuned clone of a Transformer, and not a prompt wrapper. It uses the Reframr runtime included in this repository and a checkpoint kind of `reframr-analytical`.
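To make the contrast with attention concrete, here is a minimal, purely illustrative sketch of the recurrent-memory idea. The decay constant and update rule are invented for this example and are not Reframr's actual state equations; the real update lives in the `reframr/` runtime.

```python
# Illustrative only: a generic fixed-size recurrent state update.
# Memory cost stays constant regardless of session length, which is
# why there is no attention-window context limit to hit.
def step(state: list[float], token_embedding: list[float], decay: float = 0.9) -> list[float]:
    return [decay * s + (1.0 - decay) * e for s, e in zip(state, token_embedding)]

state = [0.0] * 4  # fixed-size state, independent of sequence length
for embedding in ([1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]):
    state = step(state, embedding)
print(state)
```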
## Model Files

- `model.safetensors`: Reframr v1 computed-weight checkpoint.
- `tokenizer.json`: FrameToken tokenizer exported from the checkpoint metadata.
- `config.json`: Release metadata and tensor layout.
- `generation_config.json`: Recommended default generation settings.
- `reframr/`: CPU-first Reframr runtime source.
- `examples/`: Minimal CLI, JSONL, and Python usage examples.

## Quick Start

Use Python 3.13 or newer from the root of this model repository:

```bash
python -m pip install -r requirements.txt
python -m reframr generate \
  --model model.safetensors \
  --context "Who are you, and what makes you different from Transformer models?" \
  --max-tokens 90 \
  --temperature 0.92 \
  --decode-top-k 72 \
  --decode-top-p 0.92
```

System instructions are passed as learned context:

```bash
python -m reframr generate \
  --model model.safetensors \
  --system "Answer in two short paragraphs. Be direct and warm." \
  --context "Explain why clean data matters when computing Reframr weights." \
  --max-tokens 90 \
  --temperature 0.9
```

For a persistent process that loads the checkpoint once and accepts JSONL requests:

```bash
python -m reframr serve --model model.safetensors --max-tokens 96
```

Then send one JSON object per line:

```jsonl
{"prompt":"Tell a short story about a glass library under the sea.","temperature":1.05,"decode_top_k":90,"max_tokens":120}
{"system":"Use exactly one fitting emoji.","prompt":"Encourage a tired engineer without sounding generic.","max_tokens":70}
```
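The same loop can be driven from Python instead of a shell pipe. A minimal sketch, assuming only what is documented above (`serve` reads one JSON request per stdin line) plus one assumption: that the process exits when stdin closes. The exact response schema on stdout is defined by the runtime, so treat it as opaque JSONL.

```python
import json
import subprocess

# Feed two JSONL requests to a one-shot serve process and print whatever
# the runtime writes to stdout (assumed to be one response line per request).
requests = [
    {"prompt": "Who are you?", "max_tokens": 80, "temperature": 0.9},
    {"system": "Use exactly one fitting emoji.", "prompt": "Encourage a tired engineer.", "max_tokens": 70},
]
proc = subprocess.run(
    ["python", "-m", "reframr", "serve", "--model", "model.safetensors", "--max-tokens", "96"],
    input="".join(json.dumps(req) + "\n" for req in requests),
    capture_output=True,
    text=True,
)
print(proc.stdout)
```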
## Python Example

```python
from pathlib import Path
from reframr.model import ReframrModel

root = Path(__file__).resolve().parent
model = ReframrModel.load(root / "model.safetensors")

text = model.generate_text(
    "Who are you?",
    max_tokens=80,
    temperature=0.92,
    top_k=72,
    top_p=0.92,
    repetition_penalty=1.18,
)
print(text)
```

## Generation Controls

- `--temperature`: Higher values increase variation. Try `0.85` for focused answers and `1.05` for story or brainstorming prompts.
- `--decode-top-k`: Limits sampling to the strongest candidate set. Recommended range: `50` to `100`.
- `--decode-top-p`: Nucleus cutoff. Recommended default: `0.92`.
- `--repetition-penalty`: Penalizes repeated tokens. Recommended default: `1.18`.
- `--system`: Adds a system instruction before the user prompt.
- `--reasoning-mode`: Supports `none`, `deep`, `memory`, and `tool` profiles in the runtime. The current public checkpoint is a base release; the dedicated tool/web-freshness line is still being computed.
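The same presets can be expressed through the Python API. A small sketch contrasting a focused configuration with a more exploratory one, using the values recommended above and the keyword names from the Python example earlier in this card:

```python
from reframr.model import ReframrModel

model = ReframrModel.load("model.safetensors")

# Focused answer: cooler sampling, tighter candidate set.
focused = model.generate_text(
    "Summarize what Recurrent Flow Memory means.",
    max_tokens=80, temperature=0.85, top_k=50, top_p=0.92, repetition_penalty=1.18,
)

# Story / brainstorming: hotter sampling, wider candidate set.
creative = model.generate_text(
    "Tell a short story about a glass library under the sea.",
    max_tokens=120, temperature=1.05, top_k=90, top_p=0.92, repetition_penalty=1.18,
)
print(focused, creative, sep="\n---\n")
```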
## Identity

Reframr is built by **OkeyMeta Ltd**, founded in 2022 by **Okechukwu Goodnews Nwaozor**, its founder and CEO. The Reframr line reframes language intelligence around recurrent memory, computed weights, and evidence from data.

## Architecture Snapshot

| Property | Reframr-RFM-v1-Base |
| --- | --- |
| Family | Reframr / Recurrent Flow Memory |
| Organization | OkeyMeta Ltd |
| Checkpoint kind | `reframr-analytical` |
| Attention stack | None |
| Transformer layers | None |
| Tokenizer | FrameToken |
| Weight file | `model.safetensors` |
| Runtime | CPU-first Reframr Python runtime |
| Embedding dim | 96 |
| State dim | 48 |
| State width | 576 |
| Output vocab rows | 2,793 |
| Tokenizer vocab size | 3,741 |
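These numbers can be verified directly against the checkpoint header with the runtime's own inspector, `inspect_checkpoint`, exported from `reframr/checkpoint.py` in this repository:

```python
from reframr.checkpoint import inspect_checkpoint

info = inspect_checkpoint("model.safetensors")
print(info["checkpoint_kind"])                   # reframr-analytical, per config.json
print(info["embedding_dim"], info["state_dim"])  # 96 48
print(info["tokenizer_vocab_size"])              # 3741
print(info["tensor_shapes"]["readout_weights"])  # [2793, 576]
```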
## Intended Use

This checkpoint is intended for public testing of the Reframr runtime: open-ended generation, system-instruction experiments, story generation, safety-behavior and identity prompts, and CPU-first research into non-Transformer language modeling.

It is a base checkpoint, not a medical, legal, financial, or safety-critical authority. For fresh factual questions, connect a retrieval or web-search tool in the upcoming tool-aware Reframr line rather than relying on static checkpoint knowledge alone.

## Release Note

This release is the public v1 base checkpoint. Internally it comes from the v95 tracked compute run; publicly it begins the Reframr-RFM v1 line. The next production line is being computed with broader data, tool-use supervision, web-search protocol tokens, and larger generalization probes. The goal is simple: make Reframr a serious, CPU-first, non-Transformer model family that learns from data rather than from hardcoded responses.

## Ownership

Copyright OkeyMeta Ltd. All rights reserved unless a separate license is supplied by OkeyMeta Ltd.
config.json
ADDED
@@ -0,0 +1,107 @@
{
  "model_type": "reframr-rfm",
  "model_name": "Reframr-RFM-v1-Base",
  "library_name": "reframr",
  "checkpoint_kind": "reframr-analytical",
  "schema_version": "1",
  "architecture": "Reverse-Flow Recurrent Analytical Memory / Recurrent Flow Memory",
  "organization": "OkeyMeta Ltd",
  "creator": "OkeyMeta Ltd",
  "runtime": "CPU-first Reframr Python runtime included in this repository",
  "format": "safetensors",
  "weights_file": "model.safetensors",
  "tokenizer_file": "tokenizer.json",
  "tokenizer_name": "FrameToken",
  "tokenizer_vocab_size": 3741,
  "vocab_size": 2793,
  "embedding_dim": 96,
  "state_dim": 48,
  "state_width": 576,
  "tensor_count": 21,
  "tensor_shapes": {
    "answer_keys": [18000, 576],
    "answer_sequence_keys": [8400, 576],
    "answer_sequence_prompt_tokens": [8400, 192],
    "answer_sequence_tokens": [8400, 192],
    "answer_start_keys": [18000, 576],
    "answer_start_values": [18000],
    "answer_values": [18000],
    "associative_keys": [18000, 576],
    "associative_values": [18000],
    "embedding_table": [2793, 96],
    "preference_bias": [2793],
    "prompt_answer_bias": [2793],
    "prompt_answer_start_bias": [2793],
    "prompt_answer_start_weights": [2793, 576],
    "prompt_answer_weights": [2793, 576],
    "readout_bias": [2793],
    "readout_weights": [2793, 576],
    "state_offset": [576],
    "ternary_mask": [576],
    "ternary_scale": [1],
    "trace_token_weights": [2793]
  },
  "lowercase": false,
  "default_reasoning_profile": "none",
  "attention": "none",
  "transformer": "false",
  "weight_derivation": "computed analytical/statistical checkpoint from OkeyMeta curriculum data; no Transformer attention stack",
  "context_model": "recurrent persistent memory state; practical limits depend on runtime session and machine memory",
  "current_release": "public base checkpoint",
  "next_line": "tool-aware and web-freshness data line is being computed after this release",
  "public_version": "v1",
  "internal_compute_run": "v95",
  "internal_source_checkpoint": "reframr-v95-500b-effective-fullreadout-outside-probe-generalization-e96-s48.safetensors"
}
examples/jsonl_serve.ps1
ADDED
@@ -0,0 +1,7 @@
$requests = @'
{"prompt":"Who are you, and who built you?","max_tokens":80,"temperature":0.9}
{"system":"Answer in two short paragraphs and use exactly one fitting emoji.","prompt":"Encourage a tired engineer who is still building carefully.","max_tokens":80,"temperature":0.95}
{"prompt":"Tell a short story about a glass library under the sea.","max_tokens":120,"temperature":1.05,"decode_top_k":90}
'@

$requests | python -m reframr serve --model model.safetensors --max-tokens 96 --temperature 0.92 --decode-top-k 72 --decode-top-p 0.92
examples/python_inference.py
ADDED
@@ -0,0 +1,44 @@
from __future__ import annotations

import argparse
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from reframr.model import ReframrModel


def main() -> None:
    parser = argparse.ArgumentParser(description="Run Reframr-RFM-v1-Base locally.")
    parser.add_argument("--model", default=str(REPO_ROOT / "model.safetensors"))
    parser.add_argument("--prompt", default="Who are you, and what makes Reframr different?")
    parser.add_argument("--system", default="")
    parser.add_argument("--max-tokens", type=int, default=90)
    parser.add_argument("--temperature", type=float, default=0.92)
    parser.add_argument("--top-k", type=int, default=72)
    parser.add_argument("--top-p", type=float, default=0.92)
    parser.add_argument("--repetition-penalty", type=float, default=1.18)
    args = parser.parse_args()

    context = args.prompt
    if args.system.strip():
        context = f"System instruction: {args.system.strip()}\nUser: {args.prompt}"

    model = ReframrModel.load(args.model)
    print(
        model.generate_text(
            context,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )
    )


if __name__ == "__main__":
    main()
generation_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "max_tokens": 96,
  "temperature": 0.92,
  "decode_top_k": 72,
  "decode_top_p": 0.92,
  "repetition_penalty": 1.18,
  "reasoning_profile": "none"
}
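A minimal sketch of applying these defaults through the Python API. It assumes the `decode_top_k` and `decode_top_p` keys map to the `top_k` and `top_p` arguments of `generate_text`; that mapping follows the README examples, which pair `--decode-top-k 72` with `top_k=72`.

```python
import json
from pathlib import Path

from reframr.model import ReframrModel

defaults = json.loads(Path("generation_config.json").read_text())
model = ReframrModel.load("model.safetensors")
print(
    model.generate_text(
        "Who are you?",
        max_tokens=defaults["max_tokens"],
        temperature=defaults["temperature"],
        top_k=defaults["decode_top_k"],    # assumed mapping
        top_p=defaults["decode_top_p"],    # assumed mapping
        repetition_penalty=defaults["repetition_penalty"],
    )
)
```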
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28d9eb4844b8aa4e337c18bf78e5b12fcf214b876fb5cd2e6e1fa556c7f70f2b
size 205798796
pyproject.toml
ADDED
@@ -0,0 +1,17 @@
[project]
name = "reframr"
version = "0.1.0"
description = "CPU-first analytical language modeling research framework for REFRAMR."
requires-python = ">=3.13"
dependencies = [
    "numpy>=2.1,<3",
    "scipy>=1.14,<2",
    "datasets>=4.1,<5",
]

[project.scripts]
reframr = "reframr.cli:main"

[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"
reframr/__init__.py
ADDED
@@ -0,0 +1,32 @@
import sys
from pathlib import Path

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

from .checkpoint import inspect_checkpoint, read_safetensor_file
from .config import ReframrConfig
from .embeddings import EmbeddingModel, fit_ppmi_embedding
from .hippo import AnalyticalMemoryUnit, hippo_legs_matrix
from .model import ReframrModel
from .reasoning import REASONING_CONTROL_TOKENS, REASONING_PROFILES, TOKENIZER_NAME
from .tokenizer import NativeTokenizer

__all__ = [
    "AnalyticalMemoryUnit",
    "EmbeddingModel",
    "NativeTokenizer",
    "REASONING_CONTROL_TOKENS",
    "REASONING_PROFILES",
    "ReframrConfig",
    "ReframrModel",
    "TOKENIZER_NAME",
    "fit_ppmi_embedding",
    "hippo_legs_matrix",
    "inspect_checkpoint",
    "read_safetensor_file",
]
reframr/__main__.py
ADDED
@@ -0,0 +1,5 @@
from .cli import main


if __name__ == "__main__":
    raise SystemExit(main())
reframr/checkpoint.py
ADDED
@@ -0,0 +1,274 @@
import json
import math
import site
import struct
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

try:
    import numpy as np
except ModuleNotFoundError:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None

if np is not None and not hasattr(np, "asarray"):
    np = None

DTYPE_CODES = {
    "F32": ("f", 4),
    "F64": ("d", 8),
    "I32": ("i", 4),
}


@dataclass(slots=True)
class SafeTensorFile:
    tensors: dict[str, Any]
    metadata: dict[str, str]


def _read_safetensor_header(path: str | Path) -> dict[str, Any]:
    with Path(path).open("rb") as handle:
        length_bytes = handle.read(8)
        if len(length_bytes) < 8:
            raise ValueError("Invalid safetensors file: missing header length.")
        header_length = struct.unpack("<Q", length_bytes)[0]
        header_bytes = handle.read(header_length)
        if len(header_bytes) != header_length:
            raise ValueError("Invalid safetensors file: truncated header.")
        return json.loads(header_bytes.decode("utf-8"))


def _shape_of(value: Any) -> list[int]:
    if np is not None and hasattr(value, "shape"):
        return [int(axis) for axis in value.shape]
    if not isinstance(value, list):
        return []
    if not value:
        return [0]
    first_shape = _shape_of(value[0])
    for item in value[1:]:
        if _shape_of(item) != first_shape:
            raise ValueError("Safetensor writer does not support ragged tensors.")
    return [len(value)] + first_shape


def _flatten(value: Any) -> list[Any]:
    if np is not None and hasattr(value, "reshape"):
        return value.reshape(-1).tolist()
    if isinstance(value, list):
        flattened: list[Any] = []
        for item in value:
            flattened.extend(_flatten(item))
        return flattened
    return [value]


def _dtype_of(flat_values: list[Any]) -> str:
    if all(isinstance(value, int) and not isinstance(value, bool) for value in flat_values):
        return "I32"
    return "F64"


def _pack_tensor(dtype: str, values: list[Any]) -> bytes:
    if not values:
        return b""
    code, _ = DTYPE_CODES[dtype]
    cast_values = [int(value) for value in values] if dtype == "I32" else [float(value) for value in values]
    return struct.pack(f"<{len(cast_values)}{code}", *cast_values)


def _array_payload(value: Any) -> tuple[str, list[int], Any] | None:
    if np is None:
        return None
    try:
        array = np.asarray(value)
    except (TypeError, ValueError):
        return None
    if array.dtype == object:
        return None
    shape = [int(axis) for axis in array.shape]
    if np.issubdtype(array.dtype, np.integer) and not np.issubdtype(array.dtype, np.bool_):
        return "I32", shape, np.ascontiguousarray(array.astype("<i4", copy=False))
    if np.issubdtype(array.dtype, np.floating):
        if array.dtype == np.float32:
            return "F32", shape, np.ascontiguousarray(array.astype("<f4", copy=False))
        return "F64", shape, np.ascontiguousarray(array.astype("<f8", copy=False))
    return "F64", shape, np.ascontiguousarray(array.astype("<f8", copy=False))


def _reshape(values: list[Any], shape: list[int]) -> Any:
    if not shape:
        return values[0]
    if len(shape) == 1:
        return values[: shape[0]]

    chunk = math.prod(shape[1:])
    return [
        _reshape(values[index * chunk : (index + 1) * chunk], shape[1:])
        for index in range(shape[0])
    ]


def write_safetensor_file(
    path: str | Path,
    tensors: dict[str, Any],
    *,
    metadata: dict[str, str] | None = None,
) -> None:
    tensor_header: dict[str, Any] = {}
    payloads: list[Any] = []
    offset = 0

    for name, value in tensors.items():
        array_payload = _array_payload(value)
        if array_payload is None:
            flat_values = _flatten(value)
            dtype = _dtype_of(flat_values)
            shape = _shape_of(value)
            payload = _pack_tensor(dtype, flat_values)
        else:
            dtype, shape, payload = array_payload
        payload_size = int(payload.nbytes) if hasattr(payload, "nbytes") else len(payload)
        tensor_header[name] = {
            "dtype": dtype,
            "shape": shape,
            "data_offsets": [offset, offset + payload_size],
        }
        payloads.append(payload)
        offset += payload_size

    if metadata:
        tensor_header["__metadata__"] = metadata

    header_bytes = json.dumps(tensor_header, separators=(",", ":")).encode("utf-8")
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("wb") as handle:
        handle.write(struct.pack("<Q", len(header_bytes)))
        handle.write(header_bytes)
        for payload in payloads:
            if hasattr(payload, "nbytes"):
                if payload.nbytes:
                    handle.write(memoryview(payload).cast("B"))
            else:
                handle.write(payload)


def read_safetensor_file(path: str | Path, *, arrays: bool = False) -> SafeTensorFile:
    tensor_path = Path(path)
    if arrays and np is not None:
        with tensor_path.open("rb") as handle:
            length_bytes = handle.read(8)
            if len(length_bytes) < 8:
                raise ValueError("Invalid safetensors file: missing header length.")
            header_length = struct.unpack("<Q", length_bytes)[0]
            header_bytes = handle.read(header_length)
            if len(header_bytes) != header_length:
                raise ValueError("Invalid safetensors file: truncated header.")
        header = json.loads(header_bytes.decode("utf-8"))
        data_start = 8 + header_length
        metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
        tensors: dict[str, Any] = {}

        for name, spec in header.items():
            if name == "__metadata__":
                continue
            start, end = spec["data_offsets"]
            dtype = str(spec["dtype"])
            shape = [int(value) for value in spec["shape"]]
            _, width = DTYPE_CODES[dtype]
            payload_width = end - start
            element_count = payload_width // width if width else 0
            if payload_width <= 0:
                tensors[name] = np.asarray([], dtype={"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype])
                continue
            array_dtype = {"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype]
            mapped_shape = tuple(shape) if shape else (element_count,)
            mapped = np.memmap(
                tensor_path,
                dtype=array_dtype,
                mode="r",
                offset=data_start + start,
                shape=mapped_shape,
                order="C",
            )
            tensors[name] = mapped if shape else mapped[0]

        return SafeTensorFile(tensors=tensors, metadata=metadata)

    raw = tensor_path.read_bytes()
    if len(raw) < 8:
        raise ValueError("Invalid safetensors file: missing header length.")

    header_length = struct.unpack("<Q", raw[:8])[0]
    header = json.loads(raw[8 : 8 + header_length].decode("utf-8"))
    data_buffer = raw[8 + header_length :]
    metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
    tensors: dict[str, Any] = {}

    for name, spec in header.items():
        if name == "__metadata__":
            continue
        start, end = spec["data_offsets"]
        dtype = str(spec["dtype"])
        shape = [int(value) for value in spec["shape"]]
        code, width = DTYPE_CODES[dtype]
        payload = data_buffer[start:end]
        element_count = len(payload) // width if width else 0
        if np is not None and payload:
            array_dtype = {"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype]
            values = np.frombuffer(payload, dtype=array_dtype, count=element_count)
            reshaped = values.reshape(shape) if shape else values
            if arrays:
                tensors[name] = reshaped.copy() if shape else values.copy()[0]
            else:
                tensors[name] = reshaped.tolist() if shape else values.tolist()[0]
        else:
            values = list(struct.unpack(f"<{element_count}{code}", payload)) if payload else []
            tensors[name] = _reshape(values, shape)

    return SafeTensorFile(tensors=tensors, metadata=metadata)


def inspect_checkpoint(path: str | Path) -> dict[str, Any]:
    header = _read_safetensor_header(path)
    metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
    tensor_names = sorted(name for name in header if name != "__metadata__")
    config = json.loads(metadata["config"]) if "config" in metadata else {}
    return {
        "format": "safetensors",
        "path": str(Path(path).resolve()),
        "checkpoint_kind": metadata.get("checkpoint_kind", "unknown"),
        "schema_version": metadata.get("schema_version", "0"),
        "tokenizer_name": metadata.get("tokenizer_name", ""),
        "default_reasoning_profile": str(config.get("default_reasoning_profile", "none")) if config else "none",
        "lowercase": bool(config.get("lowercase", False)) if config else False,
        "tensor_count": len(tensor_names),
        "tensor_names": tensor_names,
        "tensor_dtypes": {name: str(header[name]["dtype"]) for name in tensor_names},
        "tensor_shapes": {name: [int(axis) for axis in header[name]["shape"]] for name in tensor_names},
        "tokenizer_vocab_size": int(metadata.get("tokenizer_vocab_size", "0")),
        "embedding_dim": int(config.get("embedding_dim", 0)) if config else 0,
        "state_dim": int(config.get("state_dim", 0)) if config else 0,
    }
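A small round-trip check of the reader/writer pair defined above, using plain Python lists so it behaves the same whether or not NumPy is installed (the file name is arbitrary):

```python
from reframr.checkpoint import read_safetensor_file, write_safetensor_file

# Floats serialize as F64 and ints as I32 per _dtype_of/_array_payload above.
tensors = {"vec": [1.0, 2.0, 3.0], "ids": [1, 2, 3]}
write_safetensor_file("demo.safetensors", tensors, metadata={"note": "round-trip"})

loaded = read_safetensor_file("demo.safetensors")
assert loaded.metadata["note"] == "round-trip"
assert loaded.tensors["vec"] == [1.0, 2.0, 3.0]
assert loaded.tensors["ids"] == [1, 2, 3]
```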
reframr/cli.py
ADDED
@@ -0,0 +1,760 @@
import argparse
import json
import sys
from pathlib import Path

from .checkpoint import inspect_checkpoint
from .config import ReframrConfig
from .corpus_recipes import (
    build_foundation_corpus,
    build_generalization_corpus,
    write_corpus_package,
)
from .curriculum import CurriculumConfig, write_curriculum_package
from .datasets import load_prompt_suite, load_text_corpus
from .evaluation import benchmark_open_prompts, evaluate_manifest, load_manifest
from .hf_import import import_hf_dataset
from .model import ReframrModel
from .reasoning import REASONING_PROFILES, TOKENIZER_NAME, reasoning_prefix
from .streaming import fit_model_from_corpus_plan, load_corpus_plan
from .tokenizer import MAX_TOKENIZER_VOCAB_SIZE, clamp_vocab_size, recommend_vocab_size


def configure_stdio() -> None:
    for stream in (sys.stdout, sys.stderr):
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is not None:
            reconfigure(encoding="utf-8")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        prog="reframr",
        description="Compute and query REFRAMR analytical language model checkpoints.",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    compute = subparsers.add_parser(
        "compute",
        aliases=["train"],
        help="Compute a REFRAMR checkpoint from a text corpus with no epoch loop.",
    )
    compute.add_argument(
        "--input",
        required=True,
        help="Path to a text, JSON, or JSONL corpus file, or a directory of such files.",
    )
    compute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    compute.add_argument("--embedding-dim", type=int, default=16)
    compute.add_argument("--state-dim", type=int, default=32)
    compute.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
    compute.add_argument("--window-size", type=int, default=2)
    compute.add_argument("--regularization", type=float, default=1e-3)
    compute.add_argument("--min-frequency", type=int, default=1)
    compute.add_argument(
        "--max-vocab",
        type=int,
        default=256,
        help="Cap analytical embedding vocabulary to keep weight computation fast on CPU.",
    )
    compute.add_argument("--tokenizer-vocab-size", type=int, default=0)
    compute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
    compute.add_argument(
        "--max-training-examples",
        type=int,
        default=60000,
        help="Cap sampled recurrent training states while still reading the full corpus for tokenizer, embeddings, and transitions.",
    )
    compute.add_argument(
        "--max-transition-contexts",
        type=int,
        default=4096,
        help="Keep only the strongest learned transition contexts per order. Use 0 to disable the cap.",
    )
    compute.add_argument(
        "--max-transition-next-tokens",
        type=int,
        default=4,
        help="Keep this many learned next-token choices per transition context.",
    )
    case_group = compute.add_mutually_exclusive_group()
    case_group.add_argument(
        "--lowercase",
        action="store_true",
        help="Normalize corpus text to lowercase before tokenization.",
    )
    case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)
    compute.add_argument(
        "--reasoning-profile",
        choices=sorted(REASONING_PROFILES),
        default="none",
        help="Default reasoning-control profile baked into the checkpoint.",
    )

    recompute = subparsers.add_parser(
        "recompute",
        help="Compute a REFRAMR checkpoint from a streaming corpus plan with no raw-text cache.",
    )
    recompute.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.")
    recompute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    recompute.add_argument("--embedding-dim", type=int, default=16)
    recompute.add_argument("--state-dim", type=int, default=32)
    recompute.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
    recompute.add_argument("--window-size", type=int, default=2)
    recompute.add_argument("--regularization", type=float, default=1e-3)
    recompute.add_argument("--min-frequency", type=int, default=1)
    recompute.add_argument("--max-vocab", type=int, default=256)
    recompute.add_argument("--tokenizer-vocab-size", type=int, default=0)
    recompute.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
    recompute.add_argument("--max-training-examples", type=int, default=60000)
    recompute.add_argument("--max-transition-contexts", type=int, default=4096)
    recompute.add_argument("--max-transition-next-tokens", type=int, default=4)
    recompute.add_argument("--log-every", type=int, default=0)
    recompute_case_group = recompute.add_mutually_exclusive_group()
    recompute_case_group.add_argument("--lowercase", action="store_true")
    recompute_case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)
    recompute.add_argument(
        "--reasoning-profile",
        choices=sorted(REASONING_PROFILES),
        default="none",
        help="Default reasoning-control profile baked into the checkpoint.",
    )

    predict = subparsers.add_parser("predict", help="Predict the next-token distribution from a saved model.")
    predict.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    predict.add_argument("--context", required=True, help="Input context text.")
    predict.add_argument("--top-k", type=int, default=5)
    predict.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    generate = subparsers.add_parser("generate", help="Generate long-form text from a saved model.")
    generate.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate.add_argument("--context", required=True, help="Prompt or starting context text.")
    generate.add_argument("--system", default="", help="Optional system instruction to prepend as learned context.")
    generate.add_argument("--max-tokens", type=int, default=64)
    generate.add_argument("--temperature", type=float, default=0.82)
    generate.add_argument("--decode-top-k", type=int, default=24)
    generate.add_argument("--decode-top-p", type=float, default=0.92)
    generate.add_argument("--repetition-penalty", type=float, default=1.18)
    generate.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    generate_batch = subparsers.add_parser(
        "generate-batch",
        help="Generate answers for a prompt file while keeping one checkpoint loaded.",
    )
    generate_batch.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate_batch.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    generate_batch.add_argument("--output", required=True, help="Path to write JSONL generations.")
    generate_batch.add_argument("--max-tokens", type=int, default=64)
    generate_batch.add_argument("--temperature", type=float, default=0.82)
    generate_batch.add_argument("--decode-top-k", type=int, default=24)
    generate_batch.add_argument("--decode-top-p", type=float, default=0.92)
    generate_batch.add_argument("--repetition-penalty", type=float, default=1.18)
    generate_batch.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    serve = subparsers.add_parser(
        "serve",
        help="Keep one checkpoint loaded and answer JSONL generation requests from stdin.",
    )
    serve.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    serve.add_argument("--max-tokens", type=int, default=64)
    serve.add_argument("--temperature", type=float, default=0.82)
    serve.add_argument("--decode-top-k", type=int, default=24)
    serve.add_argument("--decode-top-p", type=float, default=0.92)
    serve.add_argument("--repetition-penalty", type=float, default=1.18)
    serve.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    trace = subparsers.add_parser("trace", help="Trace REFRAMR reasoning components through generation steps.")
    trace.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    trace.add_argument("--context", required=True, help="Prompt or starting context text.")
    trace.add_argument("--max-tokens", type=int, default=8)
    trace.add_argument("--top-k", type=int, default=5)
    trace.add_argument("--temperature", type=float, default=0.82)
    trace.add_argument("--decode-top-p", type=float, default=0.92)
    trace.add_argument("--repetition-penalty", type=float, default=1.18)
    trace.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile.",
    )

    inspect = subparsers.add_parser("inspect", help="Inspect a REFRAMR safetensors checkpoint.")
    inspect.add_argument("--model", required=True, help="Path to a .safetensors checkpoint.")

    craft = subparsers.add_parser(
        "craft-corpus",
        help="Generate a JSON-first bootstrap corpus, manifest, and generalization prompt suite.",
    )
    craft.add_argument("--output-dir", required=True, help="Directory to write corpus and manifest files.")
    craft.add_argument(
        "--variant",
        choices=("foundation", "generalization"),
        default="foundation",
        help="Choose between the mixed foundation corpus and the language-first generalization corpus.",
    )

    craft_curriculum = subparsers.add_parser(
        "craft-curriculum",
        help="Generate the OkeyMeta JSON curriculum shard, manifest, holdout prompts, and recompute plan.",
    )
    craft_curriculum.add_argument("--output-dir", required=True, help="Directory to write curriculum files.")
    craft_curriculum.add_argument(
        "--records-per-category",
        type=int,
        default=1000,
        help="How many JSON records to generate for each curriculum category.",
    )
    craft_curriculum.add_argument("--seed", type=int, default=7)
    craft_curriculum.add_argument("--train-ratio", type=float, default=0.92)
    craft_curriculum.add_argument(
        "--effective-token-target",
        type=int,
        default=0,
        help="Set plan weighting so compact curriculum statistics represent this many effective tokens.",
    )

    evaluate = subparsers.add_parser(
        "evaluate",
        help="Evaluate memorization and held-out generalization from a benchmark manifest.",
    )
    evaluate.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    evaluate.add_argument("--manifest", required=True, help="Path to a corpus benchmark manifest JSON file.")
    evaluate.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile during evaluation.",
    )
    evaluate.add_argument("--top-k", type=int, default=5)

    benchmark_open = subparsers.add_parser(
        "benchmark-open",
        help="Run arbitrary prompt files through a checkpoint with open-ended output metrics.",
    )
    benchmark_open.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    benchmark_open.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    benchmark_open.add_argument("--max-tokens", type=int, default=64)
    benchmark_open.add_argument("--temperature", type=float, default=0.82)
    benchmark_open.add_argument("--decode-top-k", type=int, default=24)
    benchmark_open.add_argument("--decode-top-p", type=float, default=0.92)
    benchmark_open.add_argument("--repetition-penalty", type=float, default=1.18)
    benchmark_open.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help="Override the checkpoint's default reasoning-control profile during benchmarking.",
    )

    import_hf = subparsers.add_parser(
        "import-hf",
        help="Import Hugging Face dataset text into the REFRAMR JSON record standard.",
    )
    import_hf.add_argument("--dataset", required=True, help="Hugging Face dataset id.")
    import_hf.add_argument("--output", required=True, help="Path to write the JSONL corpus.")
    import_hf.add_argument("--config", default=None, help="Optional dataset config/subset.")
    import_hf.add_argument("--split", default="train", help="Dataset split to import.")
    import_hf.add_argument("--text-field", default=None, help="Explicit text column name.")
    import_hf.add_argument("--limit", type=int, default=1000, help="Maximum records to import.")
    import_hf.add_argument(
        "--min-words",
        type=int,
        default=0,
        help="Drop imported records shorter than this many words.",
    )
    import_hf.add_argument(
        "--max-words",
        type=int,
        default=0,
        help="Drop imported records longer than this many words. Use 0 to disable.",
    )
    import_hf.add_argument(
        "--min-alpha-ratio",
        type=float,
        default=0.0,
        help="Drop imported records whose alphabetic-character ratio falls below this threshold.",
    )
    import_hf.add_argument(
        "--allowed-languages",
        default="",
        help="Optional comma-separated language codes to keep, such as en,yo,ig,ha.",
    )
    import_hf.add_argument(
        "--preference-target",
        choices=("both", "chosen", "rejected"),
        default="chosen",
        help="When importing preference datasets, keep both sides or only the chosen/rejected side.",
    )
    import_hf.add_argument(
        "--no-streaming",
        action="store_true",
        help="Disable streaming dataset reads.",
    )

    return parser


def parse_timescales(raw_timescales: str) -> tuple[float, ...]:
    values = [segment.strip() for segment in raw_timescales.split(",") if segment.strip()]
    if not values:
        raise ValueError("At least one timescale is required.")
    return tuple(float(value) for value in values)


def command_compute(args: argparse.Namespace) -> int:
    text = load_text_corpus(args.input)
    requested_vocab_size = args.tokenizer_vocab_size or recommend_vocab_size(
        text,
        lowercase=args.lowercase,
    )
    tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
    config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=tokenizer_vocab_size,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        max_transition_contexts_per_order=(
            args.max_transition_contexts if args.max_transition_contexts > 0 else None
        ),
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
    )
    model = ReframrModel(config).fit(text)
    model.save(args.output)

    assert model.tokenizer is not None
    assert model.embedding_model is not None
    summary = {
        "status": "computed",
        "format": "safetensors",
        "model_path": str(Path(args.output).resolve()),
        "tokenizer_name": TOKENIZER_NAME,
        "vocab_size": len(model.embedding_model.id_to_token),
        "tokenizer_vocab_budget": config.tokenizer_vocab_size,
        "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
        "tokenizer_vocab_size": model.tokenizer.vocab_size,
        "reasoning_profile": config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
        "lowercase": config.lowercase,
        "max_training_examples": config.max_training_examples,
        "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
        "max_transition_next_tokens": config.max_transition_next_tokens,
        "embedding_dim": config.embedding_dim,
        "state_dim": config.state_dim,
        "timescales": list(config.timescales),
    }
    print(json.dumps(summary))
    return 0


def command_recompute(args: argparse.Namespace) -> int:
    plan = load_corpus_plan(args.plan)
    requested_vocab_size = args.tokenizer_vocab_size or 1024
    tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
    config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=tokenizer_vocab_size,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        max_transition_contexts_per_order=(
            args.max_transition_contexts if args.max_transition_contexts > 0 else None
        ),
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
    )
    model, payload = fit_model_from_corpus_plan(
        plan,
        config,
        log_every=args.log_every,
    )
    model.save(args.output)

    summary = {
        "status": "recomputed",
        "format": "safetensors",
        "streaming": True,
        "plan_path": str(Path(args.plan).resolve()),
        "model_path": str(Path(args.output).resolve()),
        "tokenizer_name": TOKENIZER_NAME,
        "tokenizer_vocab_budget": config.tokenizer_vocab_size,
        "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
        "tokenizer_vocab_size": payload["tokenizer_vocab_size"],
        "vocab_size": payload["embedding_vocab_size"],
        "documents_processed": payload["documents_processed"],
        "source_counts": payload["source_counts"],
        "examples_processed": payload["examples_processed"],
        "associative_examples": payload["associative_examples"],
        "answer_associative_examples": payload.get("answer_associative_examples", 0),
        "general_associative_examples": payload.get("general_associative_examples", 0),
        "answer_intent_examples": payload.get("answer_intent_examples", 0),
        "answer_start_examples": payload.get("answer_start_examples", 0),
        "answer_sequence_examples": payload.get("answer_sequence_examples", 0),
        "prompt_answer_readout_examples": payload.get("prompt_answer_readout_examples", 0),
        "prompt_answer_start_readout_examples": payload.get("prompt_answer_start_readout_examples", 0),
        "preference_pairs": payload.get("preference_pairs", 0),
        "preference_state_pairs": payload.get("preference_state_pairs", 0),
        "stage_seconds": payload.get("stage_seconds", {}),
        "readout_solver": payload.get("readout_solver"),
        "reasoning_profile": config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
        "lowercase": config.lowercase,
        "max_training_examples": config.max_training_examples,
        "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
        "max_transition_next_tokens": config.max_transition_next_tokens,
        "embedding_dim": config.embedding_dim,
        "state_dim": config.state_dim,
        "timescales": list(config.timescales),
    }
    print(json.dumps(summary))
    return 0


def command_predict(args: argparse.Namespace) -> int:
    model = ReframrModel.load(args.model)
    distribution = model.predict_next_distribution(
        args.context,
        reasoning_mode=args.reasoning_mode,
    )
    predictions = sorted(
        distribution.items(),
|
| 453 |
+
key=lambda item: item[1],
|
| 454 |
+
reverse=True,
|
| 455 |
+
)[: args.top_k]
|
| 456 |
+
payload = {
|
| 457 |
+
"context": args.context,
|
| 458 |
+
"reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile,
|
| 459 |
+
"reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile),
|
| 460 |
+
"predictions": [
|
| 461 |
+
{"token": token, "probability": probability}
|
| 462 |
+
for token, probability in predictions
|
| 463 |
+
],
|
| 464 |
+
}
|
| 465 |
+
print(json.dumps(payload))
|
| 466 |
+
return 0
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def command_generate(args: argparse.Namespace) -> int:
|
| 470 |
+
model = ReframrModel.load(args.model)
|
| 471 |
+
context = compose_generation_context(args.context, system=args.system)
|
| 472 |
+
generated_text = model.generate_text(
|
| 473 |
+
context,
|
| 474 |
+
max_tokens=args.max_tokens,
|
| 475 |
+
reasoning_mode=args.reasoning_mode,
|
| 476 |
+
temperature=args.temperature,
|
| 477 |
+
top_k=args.decode_top_k,
|
| 478 |
+
top_p=args.decode_top_p,
|
| 479 |
+
repetition_penalty=args.repetition_penalty,
|
| 480 |
+
)
|
| 481 |
+
payload = {
|
| 482 |
+
"context": context,
|
| 483 |
+
"reasoning_mode": args.reasoning_mode or model.config.default_reasoning_profile,
|
| 484 |
+
"reasoning_tokens": reasoning_prefix(args.reasoning_mode or model.config.default_reasoning_profile),
|
| 485 |
+
"generated_token_count": len(generated_text.split()),
|
| 486 |
+
"generated_text": generated_text,
|
| 487 |
+
}
|
| 488 |
+
print(json.dumps(payload))
|
| 489 |
+
return 0
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def compose_generation_context(prompt: str, *, system: str = "") -> str:
|
| 493 |
+
clean_prompt = prompt.strip()
|
| 494 |
+
clean_system = system.strip()
|
| 495 |
+
if not clean_system:
|
| 496 |
+
return clean_prompt
|
| 497 |
+
return f"System instruction: {clean_system}\nUser: {clean_prompt}"
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def command_generate_batch(args: argparse.Namespace) -> int:
|
| 501 |
+
model = ReframrModel.load(args.model)
|
| 502 |
+
prompts = load_prompt_suite(args.prompts)
|
| 503 |
+
output_path = Path(args.output)
|
| 504 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 505 |
+
active_mode = args.reasoning_mode or model.config.default_reasoning_profile
|
| 506 |
+
rows: list[dict[str, object]] = []
|
| 507 |
+
with output_path.open("w", encoding="utf-8") as handle:
|
| 508 |
+
for index, record in enumerate(prompts):
|
| 509 |
+
prompt = str(record["prompt"])
|
| 510 |
+
context = compose_generation_context(
|
| 511 |
+
prompt,
|
| 512 |
+
system=str(record.get("system", "")),
|
| 513 |
+
)
|
| 514 |
+
max_tokens = int(record.get("max_tokens", args.max_tokens))
|
| 515 |
+
generated_text = model.generate_text(
|
| 516 |
+
context,
|
| 517 |
+
max_tokens=max_tokens,
|
| 518 |
+
reasoning_mode=args.reasoning_mode,
|
| 519 |
+
temperature=args.temperature,
|
| 520 |
+
top_k=args.decode_top_k,
|
| 521 |
+
top_p=args.decode_top_p,
|
| 522 |
+
repetition_penalty=args.repetition_penalty,
|
| 523 |
+
)
|
| 524 |
+
row = {
|
| 525 |
+
"index": index,
|
| 526 |
+
"prompt": prompt,
|
| 527 |
+
"context": context,
|
| 528 |
+
"system": record.get("system", ""),
|
| 529 |
+
"tags": record.get("tags", []),
|
| 530 |
+
"reasoning_mode": active_mode,
|
| 531 |
+
"reasoning_tokens": reasoning_prefix(active_mode),
|
| 532 |
+
"generated_token_count": len(generated_text.split()),
|
| 533 |
+
"generated_text": generated_text,
|
| 534 |
+
}
|
| 535 |
+
rows.append(row)
|
| 536 |
+
handle.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")) + "\n")
|
| 537 |
+
payload = {
|
| 538 |
+
"status": "generated",
|
| 539 |
+
"sample_count": len(rows),
|
| 540 |
+
"model_path": str(Path(args.model).resolve()),
|
| 541 |
+
"prompts_path": str(Path(args.prompts).resolve()),
|
| 542 |
+
"output_path": str(output_path.resolve()),
|
| 543 |
+
"model_loads": 1,
|
| 544 |
+
}
|
| 545 |
+
print(json.dumps(payload))
|
| 546 |
+
return 0
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def command_serve(args: argparse.Namespace) -> int:
|
| 550 |
+
model = ReframrModel.load(args.model)
|
| 551 |
+
default_mode = args.reasoning_mode or model.config.default_reasoning_profile
|
| 552 |
+
for index, raw_line in enumerate(sys.stdin):
|
| 553 |
+
line = raw_line.strip()
|
| 554 |
+
if not line:
|
| 555 |
+
continue
|
| 556 |
+
try:
|
| 557 |
+
request = json.loads(line)
|
| 558 |
+
except json.JSONDecodeError as exc:
|
| 559 |
+
response = {
|
| 560 |
+
"index": index,
|
| 561 |
+
"error": "invalid_json",
|
| 562 |
+
"message": str(exc),
|
| 563 |
+
"model_loads": 1,
|
| 564 |
+
}
|
| 565 |
+
sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
|
| 566 |
+
sys.stdout.flush()
|
| 567 |
+
continue
|
| 568 |
+
if isinstance(request, str):
|
| 569 |
+
context = request
|
| 570 |
+
request_payload: dict[str, object] = {}
|
| 571 |
+
elif isinstance(request, dict):
|
| 572 |
+
request_payload = request
|
| 573 |
+
raw_context = str(request_payload.get("prompt", request_payload.get("context", "")))
|
| 574 |
+
context = compose_generation_context(
|
| 575 |
+
raw_context,
|
| 576 |
+
system=str(request_payload.get("system", "")),
|
| 577 |
+
)
|
| 578 |
+
else:
|
| 579 |
+
response = {
|
| 580 |
+
"index": index,
|
| 581 |
+
"error": "invalid_request",
|
| 582 |
+
"message": "request must be a JSON object or string",
|
| 583 |
+
"model_loads": 1,
|
| 584 |
+
}
|
| 585 |
+
sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
|
| 586 |
+
sys.stdout.flush()
|
| 587 |
+
continue
|
| 588 |
+
active_mode = str(request_payload.get("reasoning_mode", default_mode))
|
| 589 |
+
max_tokens = int(request_payload.get("max_tokens", args.max_tokens))
|
| 590 |
+
temperature = float(request_payload.get("temperature", args.temperature))
|
| 591 |
+
top_k = int(request_payload.get("decode_top_k", args.decode_top_k))
|
| 592 |
+
top_p = float(request_payload.get("decode_top_p", args.decode_top_p))
|
| 593 |
+
repetition_penalty = float(
|
| 594 |
+
request_payload.get("repetition_penalty", args.repetition_penalty)
|
| 595 |
+
)
|
| 596 |
+
generated_text = model.generate_text(
|
| 597 |
+
context,
|
| 598 |
+
max_tokens=max_tokens,
|
| 599 |
+
reasoning_mode=active_mode,
|
| 600 |
+
temperature=temperature,
|
| 601 |
+
top_k=top_k,
|
| 602 |
+
top_p=top_p,
|
| 603 |
+
repetition_penalty=repetition_penalty,
|
| 604 |
+
)
|
| 605 |
+
response = {
|
| 606 |
+
"index": index,
|
| 607 |
+
"context": context,
|
| 608 |
+
"reasoning_mode": active_mode,
|
| 609 |
+
"reasoning_tokens": reasoning_prefix(active_mode),
|
| 610 |
+
"generated_token_count": len(generated_text.split()),
|
| 611 |
+
"generated_text": generated_text,
|
| 612 |
+
"model_loads": 1,
|
| 613 |
+
}
|
| 614 |
+
sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
|
| 615 |
+
sys.stdout.flush()
|
| 616 |
+
return 0
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def command_trace(args: argparse.Namespace) -> int:
|
| 620 |
+
model = ReframrModel.load(args.model)
|
| 621 |
+
payload = model.trace_generation(
|
| 622 |
+
args.context,
|
| 623 |
+
max_tokens=args.max_tokens,
|
| 624 |
+
reasoning_mode=args.reasoning_mode,
|
| 625 |
+
top_k=args.top_k,
|
| 626 |
+
temperature=args.temperature,
|
| 627 |
+
top_p=args.decode_top_p,
|
| 628 |
+
repetition_penalty=args.repetition_penalty,
|
| 629 |
+
)
|
| 630 |
+
print(json.dumps(payload))
|
| 631 |
+
return 0
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def command_inspect(args: argparse.Namespace) -> int:
|
| 635 |
+
print(json.dumps(inspect_checkpoint(args.model)))
|
| 636 |
+
return 0
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def command_craft_corpus(args: argparse.Namespace) -> int:
|
| 640 |
+
package = (
|
| 641 |
+
build_generalization_corpus()
|
| 642 |
+
if args.variant == "generalization"
|
| 643 |
+
else build_foundation_corpus()
|
| 644 |
+
)
|
| 645 |
+
paths = write_corpus_package(package, args.output_dir)
|
| 646 |
+
payload = {
|
| 647 |
+
"name": package.name,
|
| 648 |
+
"corpus_path": paths["corpus_path"],
|
| 649 |
+
"manifest_path": paths["manifest_path"],
|
| 650 |
+
"prompt_suite_path": paths["prompt_suite_path"],
|
| 651 |
+
"token_count_estimate": len(package.text.split()),
|
| 652 |
+
"memorization_samples": len(package.memorization_samples),
|
| 653 |
+
"generalization_samples": len(package.generalization_samples),
|
| 654 |
+
"generalization_prompt_count": len(package.open_ended_samples),
|
| 655 |
+
"variant": args.variant,
|
| 656 |
+
"section_counts": package.section_counts,
|
| 657 |
+
}
|
| 658 |
+
print(json.dumps(payload))
|
| 659 |
+
return 0
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def command_craft_curriculum(args: argparse.Namespace) -> int:
|
| 663 |
+
payload = write_curriculum_package(
|
| 664 |
+
args.output_dir,
|
| 665 |
+
CurriculumConfig(
|
| 666 |
+
records_per_category=args.records_per_category,
|
| 667 |
+
seed=args.seed,
|
| 668 |
+
train_ratio=args.train_ratio,
|
| 669 |
+
),
|
| 670 |
+
effective_token_target=args.effective_token_target or None,
|
| 671 |
+
)
|
| 672 |
+
print(json.dumps(payload))
|
| 673 |
+
return 0
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def command_evaluate(args: argparse.Namespace) -> int:
|
| 677 |
+
model = ReframrModel.load(args.model)
|
| 678 |
+
manifest = load_manifest(args.manifest)
|
| 679 |
+
payload = evaluate_manifest(
|
| 680 |
+
model,
|
| 681 |
+
manifest,
|
| 682 |
+
reasoning_mode=args.reasoning_mode,
|
| 683 |
+
top_k=args.top_k,
|
| 684 |
+
)
|
| 685 |
+
print(json.dumps(payload))
|
| 686 |
+
return 0
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def command_benchmark_open(args: argparse.Namespace) -> int:
|
| 690 |
+
model = ReframrModel.load(args.model)
|
| 691 |
+
prompts = load_prompt_suite(args.prompts)
|
| 692 |
+
payload = benchmark_open_prompts(
|
| 693 |
+
model,
|
| 694 |
+
prompts,
|
| 695 |
+
reasoning_mode=args.reasoning_mode,
|
| 696 |
+
max_tokens=args.max_tokens,
|
| 697 |
+
temperature=args.temperature,
|
| 698 |
+
top_k=args.decode_top_k,
|
| 699 |
+
top_p=args.decode_top_p,
|
| 700 |
+
repetition_penalty=args.repetition_penalty,
|
| 701 |
+
)
|
| 702 |
+
print(json.dumps(payload))
|
| 703 |
+
return 0
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def command_import_hf(args: argparse.Namespace) -> int:
|
| 707 |
+
payload = import_hf_dataset(
|
| 708 |
+
dataset=args.dataset,
|
| 709 |
+
output_path=args.output,
|
| 710 |
+
config=args.config,
|
| 711 |
+
split=args.split,
|
| 712 |
+
text_field=args.text_field,
|
| 713 |
+
limit=args.limit,
|
| 714 |
+
streaming=not args.no_streaming,
|
| 715 |
+
preference_target=args.preference_target,
|
| 716 |
+
min_words=args.min_words,
|
| 717 |
+
max_words=args.max_words,
|
| 718 |
+
min_alpha_ratio=args.min_alpha_ratio,
|
| 719 |
+
allowed_languages=tuple(
|
| 720 |
+
segment.strip()
|
| 721 |
+
for segment in args.allowed_languages.split(",")
|
| 722 |
+
if segment.strip()
|
| 723 |
+
),
|
| 724 |
+
)
|
| 725 |
+
print(json.dumps(payload))
|
| 726 |
+
return 0
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def main(argv: list[str] | None = None) -> int:
|
| 730 |
+
configure_stdio()
|
| 731 |
+
parser = build_parser()
|
| 732 |
+
args = parser.parse_args(argv)
|
| 733 |
+
if args.command in {"compute", "train"}:
|
| 734 |
+
return command_compute(args)
|
| 735 |
+
if args.command == "recompute":
|
| 736 |
+
return command_recompute(args)
|
| 737 |
+
if args.command == "predict":
|
| 738 |
+
return command_predict(args)
|
| 739 |
+
if args.command == "generate":
|
| 740 |
+
return command_generate(args)
|
| 741 |
+
if args.command == "generate-batch":
|
| 742 |
+
return command_generate_batch(args)
|
| 743 |
+
if args.command == "serve":
|
| 744 |
+
return command_serve(args)
|
| 745 |
+
if args.command == "trace":
|
| 746 |
+
return command_trace(args)
|
| 747 |
+
if args.command == "inspect":
|
| 748 |
+
return command_inspect(args)
|
| 749 |
+
if args.command == "craft-corpus":
|
| 750 |
+
return command_craft_corpus(args)
|
| 751 |
+
if args.command == "craft-curriculum":
|
| 752 |
+
return command_craft_curriculum(args)
|
| 753 |
+
if args.command == "evaluate":
|
| 754 |
+
return command_evaluate(args)
|
| 755 |
+
if args.command == "benchmark-open":
|
| 756 |
+
return command_benchmark_open(args)
|
| 757 |
+
if args.command == "import-hf":
|
| 758 |
+
return command_import_hf(args)
|
| 759 |
+
parser.error(f"Unknown command: {args.command}")
|
| 760 |
+
return 2
|
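`command_serve` reads one JSON value per stdin line and writes one JSON response per line, so any process that can pipe text can drive it without reloading the checkpoint. A minimal client sketch, assuming the package is installed (the repository ships `reframr/__main__.py`, so `python -m reframr` resolves) and that the serve subparser exposes the checkpoint path as `--model`, matching `args.model` above; that flag name comes from the `build_parser` section not shown here and is an assumption:

```python
import json
import subprocess

# Flag names below are assumptions inferred from args.model in command_serve.
proc = subprocess.Popen(
    ["python", "-m", "reframr", "serve", "--model", "model.safetensors"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)
assert proc.stdin is not None and proc.stdout is not None

# Request keys match the request_payload.get(...) lookups in command_serve.
request = {"prompt": "capital of france is", "max_tokens": 16, "temperature": 0.7}
proc.stdin.write(json.dumps(request) + "\n")
proc.stdin.flush()

# One response line per request line; "generated_text" matches the response dict above.
response = json.loads(proc.stdout.readline())
print(response["generated_text"])

proc.stdin.close()
proc.wait()
```

Note that a bare JSON string is also a valid request: the `isinstance(request, str)` branch treats it as the whole context with no per-request overrides.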
reframr/config.py
ADDED
@@ -0,0 +1,68 @@
from dataclasses import dataclass


@dataclass(slots=True)
class ReframrConfig:
    embedding_dim: int = 16
    state_dim: int = 32
    timescales: tuple[float, ...] = (1.0, 0.5, 0.25, 0.125)
    window_size: int = 2
    regularization: float = 1e-3
    min_frequency: int = 1
    max_vocab: int | None = 256
    tokenizer_vocab_size: int = 256
    tokenizer_min_pair_frequency: int = 2
    max_training_examples: int | None = 60000
    max_transition_contexts_per_order: int | None = 4096
    max_transition_next_tokens: int = 4
    lowercase: bool = False
    default_reasoning_profile: str = "none"

    def to_dict(self) -> dict[str, object]:
        return {
            "embedding_dim": self.embedding_dim,
            "state_dim": self.state_dim,
            "timescales": list(self.timescales),
            "window_size": self.window_size,
            "regularization": self.regularization,
            "min_frequency": self.min_frequency,
            "max_vocab": self.max_vocab,
            "tokenizer_vocab_size": self.tokenizer_vocab_size,
            "tokenizer_min_pair_frequency": self.tokenizer_min_pair_frequency,
            "max_training_examples": self.max_training_examples,
            "max_transition_contexts_per_order": self.max_transition_contexts_per_order,
            "max_transition_next_tokens": self.max_transition_next_tokens,
            "lowercase": self.lowercase,
            "default_reasoning_profile": self.default_reasoning_profile,
        }

    @classmethod
    def from_dict(cls, payload: dict[str, object]) -> "ReframrConfig":
        return cls(
            embedding_dim=int(payload["embedding_dim"]),
            state_dim=int(payload["state_dim"]),
            timescales=tuple(float(value) for value in payload["timescales"]),
            window_size=int(payload["window_size"]),
            regularization=float(payload["regularization"]),
            min_frequency=int(payload["min_frequency"]),
            max_vocab=(
                int(payload.get("max_vocab", 256))
                if payload.get("max_vocab", 256) is not None
                else None
            ),
            tokenizer_vocab_size=int(payload.get("tokenizer_vocab_size", 256)),
            tokenizer_min_pair_frequency=int(payload.get("tokenizer_min_pair_frequency", 2)),
            max_training_examples=(
                int(payload["max_training_examples"])
                if payload.get("max_training_examples") is not None
                else None
            ),
            max_transition_contexts_per_order=(
                int(payload["max_transition_contexts_per_order"])
                if payload.get("max_transition_contexts_per_order") is not None
                else None
            ),
            max_transition_next_tokens=int(payload.get("max_transition_next_tokens", 4)),
            lowercase=bool(payload.get("lowercase", False)),
            default_reasoning_profile=str(payload.get("default_reasoning_profile", "none")),
        )
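Because every field serializes through `to_dict` and reconstructs through `from_dict`, a checkpoint's configuration can round-trip losslessly through plain JSON; tuples become lists in transit and are re-tupled on the way back. A minimal sketch, assuming the package is importable as `reframr.config`:

```python
import json

from reframr.config import ReframrConfig

# Start from the defaults, overriding two fields.
config = ReframrConfig(embedding_dim=32, timescales=(1.0, 0.25))

# Serialize to a JSON string and parse it back, as a checkpoint loader would.
payload = json.loads(json.dumps(config.to_dict()))
restored = ReframrConfig.from_dict(payload)

assert restored == config            # dataclass equality compares field by field
assert restored.timescales == (1.0, 0.25)  # from_dict re-tuples the JSON list
```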
reframr/corpus.py
ADDED
@@ -0,0 +1,123 @@
import re
from collections import Counter

from .linalg import Matrix, np, zeros

TOKEN_PATTERN = re.compile(r"[A-Za-z0-9']+")
FRAMETOKEN_WORD_PREFIX = "▁"


def tokenize(text: str) -> list[str]:
    return TOKEN_PATTERN.findall(text.lower())


def build_vocabulary(
    tokens: list[str],
    min_frequency: int = 1,
    max_vocab: int | None = None,
) -> tuple[dict[str, int], list[str]]:
    counts = Counter(tokens)
    return build_vocabulary_from_counts(
        counts,
        min_frequency=min_frequency,
        max_vocab=max_vocab,
    )


def build_vocabulary_from_counts(
    counts: dict[str, float],
    min_frequency: int = 1,
    max_vocab: int | None = None,
) -> tuple[dict[str, int], list[str]]:
    items = [
        (token, count)
        for token, count in sorted(counts.items(), key=lambda pair: (-pair[1], pair[0]))
        if count >= min_frequency
    ]
    if max_vocab is not None:
        if any(_looks_like_frametoken(token) for token, _ in items):
            items = _prioritize_frametoken_output_items(items)[:max_vocab]
        else:
            items = items[:max_vocab]

    id_to_token = [token for token, _ in items]
    token_to_id = {token: index for index, token in enumerate(id_to_token)}
    return token_to_id, id_to_token


def _looks_like_frametoken(token: str) -> bool:
    return token.startswith(FRAMETOKEN_WORD_PREFIX) or (
        token.startswith("<") and token.endswith(">")
    )


def _is_special_token(token: str) -> bool:
    return token.startswith("<") and token.endswith(">")


def _is_word_start_token(token: str) -> bool:
    return token.startswith(FRAMETOKEN_WORD_PREFIX)


def _is_single_letter_word_start(token: str) -> bool:
    if not token.startswith(FRAMETOKEN_WORD_PREFIX):
        return False
    rendered = token[len(FRAMETOKEN_WORD_PREFIX) :]
    return len(rendered) == 1 and rendered.isalpha() and rendered not in {"A", "I"}


def _is_bare_fallback_token(token: str) -> bool:
    return len(token) == 1 and not token.startswith(FRAMETOKEN_WORD_PREFIX)


def _prioritize_frametoken_output_items(items: list[tuple[str, float]]) -> list[tuple[str, float]]:
    # FrameToken keeps fallback characters for encoding coverage, but the model's
    # output/readout vocabulary should spend its capped slots on answerable tokens.
    def priority(item: tuple[str, float]) -> tuple[int, float, str]:
        token, count = item
        if _is_special_token(token):
            group = 0
        elif _is_single_letter_word_start(token):
            group = 3
        elif _is_word_start_token(token):
            group = 1
        elif _is_bare_fallback_token(token):
            group = 4
        else:
            group = 2
        return (group, -count, token)

    return sorted(items, key=priority)


def build_cooccurrence_matrix(
    tokens: list[str],
    token_to_id: dict[str, int],
    window_size: int,
) -> Matrix:
    size = len(token_to_id)
    token_ids = [token_to_id[token] for token in tokens if token in token_to_id]
    if np is not None and size > 0 and token_ids:
        matrix = np.zeros((size, size), dtype=np.float64)
        token_array = np.asarray(token_ids, dtype=np.int64)
        for offset in range(1, window_size + 1):
            if len(token_array) <= offset:
                break
            left = token_array[:-offset]
            right = token_array[offset:]
            weight = 1.0 / offset
            np.add.at(matrix, (left, right), weight)
            np.add.at(matrix, (right, left), weight)
        return matrix.tolist()

    matrix = zeros(size, size)
    for index, token_id in enumerate(token_ids):
        for offset in range(1, window_size + 1):
            other_index = index + offset
            if other_index >= len(token_ids):
                break
            other_id = token_ids[other_index]
            weight = 1.0 / offset
            matrix[token_id][other_id] += weight
            matrix[other_id][token_id] += weight
    return matrix
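The co-occurrence builder weights each token pair by `1/offset`, so immediate neighbours count fully and a pair two positions apart counts half, symmetrically in both directions; the NumPy path and the pure-Python fallback produce the same matrix. A minimal sketch of the pipeline, assuming the package is importable as `reframr.corpus`:

```python
from reframr.corpus import build_cooccurrence_matrix, build_vocabulary, tokenize

tokens = tokenize("the cat sat on the mat")
token_to_id, id_to_token = build_vocabulary(tokens, min_frequency=1)

# window_size=2: offset-1 neighbours add 1.0 each, offset-2 neighbours add 0.5.
matrix = build_cooccurrence_matrix(tokens, token_to_id, window_size=2)

the_id, cat_id = token_to_id["the"], token_to_id["cat"]
print(id_to_token)
print(matrix[the_id][cat_id])  # 1.0: "the" and "cat" are adjacent exactly once
```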
reframr/corpus_recipes.py
ADDED
@@ -0,0 +1,1257 @@
import json
from dataclasses import dataclass
from pathlib import Path


@dataclass(slots=True)
class EvalSample:
    section: str
    context: str
    expected: str

    def to_dict(self) -> dict[str, str]:
        return {
            "section": self.section,
            "context": self.context,
            "expected": self.expected,
        }


@dataclass(slots=True)
class OpenEvalSample:
    section: str
    context: str
    required_groups: list[list[str]]
    banned_phrases: list[str]
    min_words: int = 12
    require_punctuation: bool = True
    max_tokens: int = 56

    def to_dict(self) -> dict[str, object]:
        return {
            "section": self.section,
            "context": self.context,
            "required_groups": self.required_groups,
            "banned_phrases": self.banned_phrases,
            "min_words": self.min_words,
            "require_punctuation": self.require_punctuation,
            "max_tokens": self.max_tokens,
        }


@dataclass(slots=True)
class CorpusRecord:
    section: str
    context: str
    answer: str
    split: str = "train"

    @property
    def text(self) -> str:
        return _line(self.context, self.answer)

    def to_dict(self) -> dict[str, str]:
        return {
            "section": self.section,
            "split": self.split,
            "context": self.context,
            "answer": self.answer,
            "text": self.text,
        }


@dataclass(slots=True)
class CorpusPackage:
    name: str
    records: list[CorpusRecord]
    section_counts: dict[str, int]
    memorization_samples: list[EvalSample]
    generalization_samples: list[EvalSample]
    open_ended_samples: list[OpenEvalSample]

    @property
    def slug(self) -> str:
        return self.name.lower().replace(" ", "-")

    @property
    def text(self) -> str:
        if not self.records:
            return ""
        return "\n".join(record.text for record in self.records) + "\n"

    def manifest(self, *, corpus_filename: str) -> dict[str, object]:
        return {
            "name": self.name,
            "corpus_filename": corpus_filename,
            "section_counts": self.section_counts,
            "splits": {
                "memorization": [sample.to_dict() for sample in self.memorization_samples],
                "generalization": [sample.to_dict() for sample in self.generalization_samples],
                "open_ended": [sample.to_dict() for sample in self.open_ended_samples],
            },
        }

    def corpus_records(self) -> list[dict[str, str]]:
        return [record.to_dict() for record in self.records]

    def prompt_suite(self) -> list[dict[str, object]]:
        return [
            {
                "prompt": sample.context,
                "tags": [sample.section, "generalization"],
                "min_words": sample.min_words,
                "require_punctuation": sample.require_punctuation,
                "max_tokens": sample.max_tokens,
            }
            for sample in self.open_ended_samples
        ]


def _line(context: str, expected: str) -> str:
    return f"{context} {expected}"


def _balanced_samples(samples: list[EvalSample], total: int) -> list[EvalSample]:
    buckets: dict[str, list[EvalSample]] = {}
    for sample in samples:
        buckets.setdefault(sample.section, []).append(sample)

    selected: list[EvalSample] = []
    ordered_sections = sorted(buckets)
    while len(selected) < total:
        progressed = False
        for section in ordered_sections:
            bucket = buckets[section]
            if not bucket:
                continue
            selected.append(bucket.pop(0))
            progressed = True
            if len(selected) >= total:
                break
        if not progressed:
            break
    return selected


def _recount_sections(records: list[CorpusRecord]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for record in records:
        counts[record.section] = counts.get(record.section, 0) + 1
    return counts


def build_foundation_corpus() -> CorpusPackage:
    records: list[CorpusRecord] = []
    lines: list[str] = []
    section_counts: dict[str, int] = {}
    memorization: list[EvalSample] = []
    generalization: list[EvalSample] = []
    open_ended: list[OpenEvalSample] = []

    def add_train(section: str, context: str, expected: str, *, sample: bool = False) -> None:
        records.append(
            CorpusRecord(
                section=section,
                context=context,
                answer=expected,
                split="train",
            )
        )
        lines.append(_line(context, expected))
        section_counts[section] = section_counts.get(section, 0) + 1
        if sample:
            memorization.append(EvalSample(section=section, context=context, expected=expected))

    def add_holdout(section: str, context: str, expected: str) -> None:
        generalization.append(EvalSample(section=section, context=context, expected=expected))

    def add_open(
        section: str,
        context: str,
        required_groups: list[list[str]],
        *,
        banned_phrases: list[str],
        min_words: int = 12,
        require_punctuation: bool = True,
        max_tokens: int = 56,
    ) -> None:
        open_ended.append(
            OpenEvalSample(
                section=section,
                context=context,
                required_groups=required_groups,
                banned_phrases=banned_phrases,
                min_words=min_words,
                require_punctuation=require_punctuation,
                max_tokens=max_tokens,
            )
        )

    holdout_addition = {
        (2, 19),
        (3, 17),
        (4, 16),
        (5, 15),
        (6, 14),
        (7, 13),
        (8, 12),
        (9, 11),
        (10, 10),
        (11, 9),
        (12, 8),
        (13, 7),
        (14, 6),
        (15, 5),
        (16, 4),
        (17, 3),
        (18, 2),
        (19, 21),
        (20, 22),
        (21, 19),
        (22, 20),
        (23, 18),
        (24, 17),
        (25, 16),
    }
    holdout_successor = {23, 29, 31, 37, 41, 43, 47, 53, 61, 67, 71, 73, 79}
    holdout_predecessor = {24, 30, 32, 38, 42, 44, 48, 54, 62, 68, 72, 74, 80}
    holdout_explain_addition = {
        (7, 9),
        (8, 11),
        (10, 13),
        (12, 15),
        (14, 9),
        (15, 14),
        (16, 12),
        (18, 7),
    }
    holdout_explain_subtraction = {
        (19, 7),
        (22, 9),
        (25, 11),
        (28, 13),
        (31, 15),
        (34, 12),
    }
    holdout_explain_multiplication = {
        (6, 7),
        (7, 8),
        (8, 9),
        (9, 6),
        (11, 5),
        (12, 6),
    }

    for left in range(1, 41):
        for right in range(1, 41):
            context = f"<reason> add {left} plus {right} equals <answer>"
            expected = str(left + right)
            if (left, right) in holdout_addition:
                add_holdout("arithmetic", context, expected)
            else:
                add_train("arithmetic", context, expected, sample=(left + right) % 5 == 0)

    holdout_subtraction = {
        (9, 4),
        (12, 5),
        (15, 6),
        (18, 7),
        (21, 8),
        (24, 9),
        (27, 10),
        (30, 11),
    }
    for left in range(3, 56):
        for right in range(1, min(left, 21)):
            context = f"<reason> subtract {right} from {left} equals <answer>"
            expected = str(left - right)
            if (left, right) in holdout_subtraction:
                add_holdout("arithmetic", context, expected)
            else:
                add_train("arithmetic", context, expected, sample=(left - right) % 6 == 0)

    holdout_multiplication = {
        (7, 8),
        (8, 9),
        (9, 7),
        (11, 6),
        (12, 7),
        (6, 11),
    }
    for left in range(2, 21):
        for right in range(2, 21):
            context = f"<reason> multiply {left} times {right} equals <answer>"
            expected = str(left * right)
            if (left, right) in holdout_multiplication:
                add_holdout("arithmetic", context, expected)
            else:
                add_train("arithmetic", context, expected, sample=(left * right) % 9 == 0)

    holdout_parity = {33, 37, 41, 45, 52, 58}
    for value in range(1, 141):
        context = f"<reason> parity of {value} is <answer>"
        expected = "even" if value % 2 == 0 else "odd"
        if value in holdout_parity:
            add_holdout("arithmetic", context, expected)
        else:
            add_train("arithmetic", context, expected, sample=value % 10 == 0)

    for value in range(1, 181):
        successor_context = f"<reason> successor of {value} is <answer>"
        successor_expected = str(value + 1)
        if value in holdout_successor:
            add_holdout("sequence", successor_context, successor_expected)
        else:
            add_train("sequence", successor_context, successor_expected, sample=value % 7 == 0)

        predecessor_context = f"<reason> predecessor of {value} is <answer>"
        predecessor_expected = str(value - 1)
        if value in holdout_predecessor:
            add_holdout("sequence", predecessor_context, predecessor_expected)
        else:
            add_train("sequence", predecessor_context, predecessor_expected, sample=value % 8 == 0)

    for left in range(2, 25):
        for right in range(2, 25):
            context = f"<reason> explain the sum of {left} and {right} <answer>"
            expected = (
                f"Use {left} and {right} as the two addends; their total is "
                f"{left + right}, so the answer is {left + right}."
            )
            if (left, right) in holdout_explain_addition:
                add_holdout("reasoning", context, expected)
            else:
                add_train("reasoning", context, expected, sample=(left + right) % 7 == 0)

    for left in range(8, 45):
        for right in range(2, min(left, 17)):
            context = f"<reason> explain the difference between {left} and {right} <answer>"
            expected = (
                f"Start with {left} and remove {right}; the remaining value is "
                f"{left - right}, so the answer is {left - right}."
            )
            if (left, right) in holdout_explain_subtraction:
                add_holdout("reasoning", context, expected)
            else:
                add_train("reasoning", context, expected, sample=(left - right) % 8 == 0)

    for left in range(2, 17):
        for right in range(2, 13):
            context = f"<reason> explain the product of {left} and {right} <answer>"
            expected = (
                f"Treat {left} and {right} as factors; combining the equal groups gives "
                f"{left * right}, so the answer is {left * right}."
            )
            if (left, right) in holdout_explain_multiplication:
                add_holdout("reasoning", context, expected)
            else:
                add_train("reasoning", context, expected, sample=(left * right) % 9 == 0)

    capitals = [
        ("japan", "tokyo"),
        ("brazil", "brasilia"),
        ("canada", "ottawa"),
        ("france", "paris"),
        ("germany", "berlin"),
        ("india", "new delhi"),
        ("australia", "canberra"),
        ("egypt", "cairo"),
        ("kenya", "nairobi"),
        ("mexico", "mexico city"),
        ("norway", "oslo"),
        ("chile", "santiago"),
        ("argentina", "buenos aires"),
        ("thailand", "bangkok"),
        ("indonesia", "jakarta"),
        ("morocco", "rabat"),
        ("sweden", "stockholm"),
        ("finland", "helsinki"),
        ("peru", "lima"),
        ("colombia", "bogota"),
    ]
    for country, capital in capitals:
        add_train(
            "memory",
            f"<memory> capital of {country} is <answer>",
            capital,
            sample=country in {"japan", "brazil", "canada", "france", "india", "kenya"},
        )

    analogies_train = [
        ("bird", "nest", "bee", "hive"),
        ("fish", "water", "camel", "desert"),
        ("painter", "brush", "writer", "pen"),
        ("doctor", "hospital", "teacher", "school"),
        ("farmer", "field", "captain", "ship"),
        ("judge", "court", "chef", "kitchen"),
        ("astronomer", "telescope", "musician", "violin"),
        ("pilot", "cockpit", "driver", "garage"),
        ("programmer", "code", "architect", "blueprint"),
        ("tailor", "needle", "carpenter", "hammer"),
        ("sailor", "compass", "hiker", "map"),
        ("chemist", "laboratory", "baker", "oven"),
        ("photographer", "camera", "sculptor", "chisel"),
        ("gardener", "soil", "potter", "clay"),
        ("librarian", "catalog", "analyst", "report"),
        ("surfer", "wave", "skater", "ramp"),
        ("director", "script", "conductor", "score"),
        ("nurse", "clinic", "lawyer", "firm"),
    ]
    analogies_holdout = [
        ("curator", "museum", "editor", "journal"),
        ("beekeeper", "apiary", "farmer", "barn"),
        ("surgeon", "scalpel", "artist", "canvas"),
        ("sailor", "harbor", "miner", "tunnel"),
        ("scientist", "laboratory", "gardener", "greenhouse"),
        ("translator", "dictionary", "navigator", "chart"),
        ("coach", "sideline", "chef", "kitchen"),
        ("astronaut", "capsule", "diver", "reef"),
    ]
    for left_subject, left_object, right_subject, right_object in analogies_train:
        add_train(
            "analogy",
            f"<reason> {left_subject} relates to {left_object} as {right_subject} relates to <answer>",
            right_object,
            sample=left_subject in {"bird", "doctor", "judge", "pilot", "chemist", "nurse"},
        )
    for left_subject, left_object, right_subject, right_object in analogies_holdout:
        add_holdout(
            "analogy",
            f"<reason> {left_subject} relates to {left_object} as {right_subject} relates to <answer>",
            right_object,
        )

    classifications = [
        ("sparrow", "bird"),
        ("salmon", "fish"),
        ("oak", "tree"),
        ("rose", "flower"),
        ("copper", "metal"),
        ("mercury", "planet"),
        ("triangle", "shape"),
        ("python", "language"),
        ("whale", "mammal"),
        ("eagle", "bird"),
        ("lion", "mammal"),
        ("emerald", "gem"),
        ("neptune", "planet"),
        ("ruby", "gem"),
        ("cedar", "tree"),
        ("falcon", "bird"),
        ("orca", "mammal"),
        ("sapphire", "gem"),
        ("elm", "tree"),
        ("swift", "language"),
    ]
    for item, group in classifications:
        add_train(
            "classification",
            f"<memory> category of {item} is <answer>",
            group,
            sample=item in {"sparrow", "salmon", "oak", "rose", "neptune", "ruby"},
        )

    reasoning_phrases = [
        ("think clearly before final response", "response"),
        ("verify each claim before answer", "answer"),
        ("retrieve memory before conclusion", "conclusion"),
        ("focus on evidence before claim", "claim"),
        ("plan then reason then answer", "answer"),
        ("reflect before committing output", "output"),
        ("use memory when context grows", "grows"),
        ("check arithmetic before assertion", "assertion"),
        ("organize steps before conclusion", "conclusion"),
        ("inspect state before next answer", "answer"),
        ("paraphrase before claiming novelty", "novelty"),
        ("stabilize state before long generation", "generation"),
        ("reuse evidence before rewriting summary", "summary"),
        ("compare patterns before final synthesis", "synthesis"),
    ]
    for phrase, final_word in reasoning_phrases:
        add_train(
            "protocol",
            f"<reason> {phrase} <answer>",
            final_word,
            sample=final_word in {"response", "answer", "claim", "generation", "summary"},
        )

    paraphrase_train = [
        (
            "clear goals and steady practice",
            "clear goals joined with steady practice create durable skill",
        ),
        (
            "careful review prevents shallow errors",
            "careful review stops shallow errors before they spread",
        ),
        (
            "patient systems improve over time",
            "patient systems improve through steady revision over time",
        ),
        (
            "bright ideas need exact execution",
            "bright ideas need exact execution to become reliable work",
        ),
        (
            "quiet focus strengthens difficult reasoning",
            "quiet focus strengthens difficult reasoning during long analysis",
        ),
        (
            "small evidence guides better judgment",
            "small evidence guides better judgment when choices feel similar",
        ),
        (
            "stable memory helps long writing",
            "stable memory helps long writing keep its shape and intent",
        ),
        (
            "measured iteration protects quality",
            "measured iteration protects quality while keeping momentum alive",
        ),
        (
            "careful structure scales ambitious work",
            "careful structure scales ambitious work without needless disorder",
        ),
        (
            "strong prompts need grounded answers",
            "strong prompts need grounded answers supported by real evidence",
        ),
        (
            "shared context reduces wasted motion",
            "shared context reduces wasted motion across a complex build",
        ),
        (
            "consistent language sharpens collaboration",
            "consistent language sharpens collaboration and shortens confusion",
        ),
    ]
    paraphrase_holdout = [
        (
            "steady systems reward patient builders",
            "steady systems reward patient builders with dependable progress",
        ),
        (
            "clear revision protects difficult projects",
            "clear revision protects difficult projects from hidden drift",
        ),
        (
            "focused memory improves long responses",
            "focused memory improves long responses during deep reasoning",
        ),
        (
            "clean evidence supports honest claims",
            "clean evidence supports honest claims during ambitious work",
        ),
        (
            "durable plans reduce fragile execution",
            "durable plans reduce fragile execution before launch pressure rises",
        ),
        (
            "careful synthesis strengthens global understanding",
            "careful synthesis strengthens global understanding without empty hype",
        ),
    ]
    for source, target in paraphrase_train:
        add_train(
            "paraphrase",
            f"<reason> paraphrase {source} into stronger prose <answer>",
            target,
            sample=source in {
                "clear goals and steady practice",
                "patient systems improve over time",
                "stable memory helps long writing",
                "shared context reduces wasted motion",
            },
        )
    for source, target in paraphrase_holdout:
        add_holdout(
            "paraphrase",
            f"<reason> paraphrase {source} into stronger prose <answer>",
            target,
        )

    comparison_train = [
        ("pebble", "stone", "boulder", "largest", "boulder"),
        ("stream", "river", "ocean", "largest", "ocean"),
        ("candle", "lantern", "sun", "brightest", "sun"),
        ("village", "city", "continent", "largest", "continent"),
        ("breeze", "wind", "storm", "strongest", "storm"),
        ("cup", "bucket", "reservoir", "largest", "reservoir"),
        ("violin", "orchestra", "stadium chorus", "loudest", "stadium chorus"),
        ("ember", "flame", "wildfire", "hottest", "wildfire"),
        ("minute", "hour", "day", "longest", "day"),
        ("thread", "rope", "bridge cable", "thickest", "bridge cable"),
        ("hill", "mountain", "range", "largest", "range"),
        ("drizzle", "rain", "monsoon", "strongest", "monsoon"),
        ("spark", "torch", "beacon", "brightest", "beacon"),
        ("brook", "canal", "delta", "widest", "delta"),
        ("hut", "house", "tower", "tallest", "tower"),
        ("cart", "truck", "freighter", "largest", "freighter"),
        ("path", "road", "highway", "widest", "highway"),
        ("note", "melody", "symphony", "longest", "symphony"),
    ]
    comparison_holdout = [
        ("seed", "sapling", "forest", "largest", "forest"),
        ("glimmer", "lamp", "lighthouse", "brightest", "lighthouse"),
        ("whisper", "speech", "thunder", "loudest", "thunder"),
        ("creek", "river", "sea", "largest", "sea"),
        ("trail", "road", "expressway", "widest", "expressway"),
        ("hill", "cliff", "summit", "highest", "summit"),
        ("ember", "bonfire", "volcano", "hottest", "volcano"),
        ("minute", "season", "century", "longest", "century"),
    ]
    for first, second, third, comparator, expected in comparison_train:
        add_train(
            "comparison",
            f"<reason> {comparator} among {first} {second} {third} is <answer>",
            expected,
            sample=expected in {"boulder", "ocean", "storm", "day", "range", "highway"},
        )
    for first, second, third, comparator, expected in comparison_holdout:
        add_holdout(
            "comparison",
            f"<reason> {comparator} among {first} {second} {third} is <answer>",
            expected,
        )

    causal_train = [
        ("iron left in rain", "rust"),
        ("clouds cooling into droplets", "rain"),
        ("plants receiving sunlight", "growth"),
        ("water reaching freezing temperature", "ice"),
        ("friction between dry sticks", "heat"),
        ("strong wind over warm water", "waves"),
        ("seed placed in moist soil", "sprout"),
        ("glass exposed to sudden force", "crack"),
        ("constant pressure on stone", "erosion"),
        ("fuel meeting flame", "combustion"),
        ("repeated practice with feedback", "skill"),
        ("unchecked heat in metal", "expansion"),
        ("low temperature overnight", "frost"),
        ("sustained current through filament", "glow"),
        ("gravity pulling rain downhill", "flow"),
        ("sleep loss across many nights", "fatigue"),
        ("overloaded bridge cable", "strain"),
        ("salt water meeting steel", "corrosion"),
    ]
    causal_holdout = [
        ("dust gathering in still air", "settling"),
        ("long drought across dry fields", "cracking"),
        ("steady pressure beneath ice", "creep"),
        ("clean lens focusing sunlight", "heat"),
        ("lack of oxygen in closed flame", "extinguish"),
        ("waves striking rock for years", "wear"),
    ]
    for cause, effect in causal_train:
        add_train(
            "causal",
            f"<reason> effect of {cause} is <answer>",
            effect,
            sample=effect in {"rust", "rain", "growth", "ice", "skill", "fatigue"},
        )
    for cause, effect in causal_holdout:
        add_holdout(
            "causal",
            f"<reason> effect of {cause} is <answer>",
            effect,
        )

    definition_train = [
        ("orbit", "path traced by one body around another"),
        ("bridge", "structure that carries passage over an obstacle"),
        ("catalyst", "substance that speeds a reaction without being consumed"),
        ("harbor", "protected water area where ships can anchor safely"),
        ("algorithm", "finite procedure for transforming input into output"),
        ("archive", "ordered collection preserved for future reference"),
        ("equilibrium", "state where opposing influences remain balanced"),
        ("lens", "curved material that focuses or spreads light"),
        ("reservoir", "stored supply of water or another resource"),
        ("signal", "pattern that carries information across distance"),
        ("compiler", "program that translates source code into another form"),
        ("calendar", "system for organizing days into meaningful cycles"),
        ("estuary", "place where river water meets the sea"),
        ("voltage", "difference in electric potential between two points"),
        ("synapse", "junction where one neuron communicates with another"),
        ("telescope", "instrument that gathers distant light for observation"),
    ]
    definition_holdout = [
        ("glacier", "mass of ice that moves slowly across land"),
        ("protocol", "agreed procedure that coordinates reliable exchange"),
        ("reef", "ridge of rock or coral rising near the water surface"),
        ("memory", "stored information available for later retrieval"),
        ("frequency", "how often a repeating event occurs in set time"),
        ("compass", "instrument that indicates direction relative to north"),
    ]
    for term, definition in definition_train:
        add_train(
|
| 687 |
+
"definition",
|
| 688 |
+
f"<memory> define {term} as <answer>",
|
| 689 |
+
definition,
|
| 690 |
+
sample=term in {"orbit", "algorithm", "compiler", "harbor", "signal"},
|
| 691 |
+
)
|
| 692 |
+
for term, definition in definition_holdout:
|
| 693 |
+
add_holdout(
|
| 694 |
+
"definition",
|
| 695 |
+
f"<memory> define {term} as <answer>",
|
| 696 |
+
definition,
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
identity_train = [
|
| 700 |
+
(
|
| 701 |
+
"describe REFRAMR briefly",
|
| 702 |
+
"REFRAMR is an analytical recurrent language system built by OkeyMeta Ltd to compute structure from corpus evidence instead of gradient loops.",
|
| 703 |
+
),
|
| 704 |
+
(
|
| 705 |
+
"describe REFRAMR in your own words",
|
| 706 |
+
"REFRAMR is OkeyMeta Ltd language intelligence shaped through analytical memory recurrent state and computed structure rather than opaque training ritual.",
|
| 707 |
+
),
|
| 708 |
+
(
|
| 709 |
+
"describe REFRAMR in your own words with punctuation",
|
| 710 |
+
"REFRAMR is recurrent, analytical, and evidence-driven; OkeyMeta Ltd shapes it to compute structure from corpus behavior instead of blind gradient churn.",
|
| 711 |
+
),
|
| 712 |
+
(
|
| 713 |
+
"describe REFRAMR in your own words, with punctuation",
|
| 714 |
+
"REFRAMR is a recurrent analytical language system; OkeyMeta Ltd builds it to preserve structure, carry long context, and keep reasoning signals inspectable.",
|
| 715 |
+
),
|
| 716 |
+
(
|
| 717 |
+
"what is REFRAMR",
|
| 718 |
+
"REFRAMR is an OkeyMeta analytical language system built around computed memory state and closed form readout.",
|
| 719 |
+
),
|
| 720 |
+
(
|
| 721 |
+
"what makes REFRAMR different",
|
| 722 |
+
"REFRAMR differs by combining analytical memory corpus statistics and transparent reasoning traces without standard backprop training",
|
| 723 |
+
),
|
| 724 |
+
(
|
| 725 |
+
"describe FrameToken briefly",
|
| 726 |
+
"FrameToken is REFRAMR native tokenizer from OkeyMeta Ltd that preserves reasoning controls while staying fast on ordinary hardware.",
|
| 727 |
+
),
|
| 728 |
+
(
|
| 729 |
+
"what is REFRAMR mission",
|
| 730 |
+
"REFRAMR aims to build strong language intelligence through computed structure recurrent memory and interpretable reasoning",
|
| 731 |
+
),
|
| 732 |
+
(
|
| 733 |
+
"how does REFRAMR reason",
|
| 734 |
+
"REFRAMR reasons through recurrent state analytical retrieval transition priors and explicit control tokens",
|
| 735 |
+
),
|
| 736 |
+
(
|
| 737 |
+
"what is REFRAMR memory",
|
| 738 |
+
"REFRAMR memory is a multi timescale analytical state that compresses history without quadratic attention.",
|
| 739 |
+
),
|
| 740 |
+
(
|
| 741 |
+
"explain REFRAMR memory for long context",
|
| 742 |
+
"REFRAMR memory keeps long context by folding prior evidence into a persistent analytical state so later tokens can still respond to earlier structure.",
|
| 743 |
+
),
|
| 744 |
+
(
|
| 745 |
+
"explain REFRAMR memory for long context in your own words",
|
| 746 |
+
"REFRAMR keeps long context through a persistent analytical memory state, so earlier structure can still shape later output without a quadratic attention map.",
|
| 747 |
+
),
|
| 748 |
+
(
|
| 749 |
+
"describe REFRAMR long context memory",
|
| 750 |
+
"REFRAMR long context memory is a persistent recurrent state that carries history forward without storing every token in a quadratic map.",
|
| 751 |
+
),
|
| 752 |
+
(
|
| 753 |
+
"what is REFRAMR readout",
|
| 754 |
+
"REFRAMR readout is a closed form mapping from analytical state to token probabilities.",
|
| 755 |
+
),
|
| 756 |
+
(
|
| 757 |
+
"what does REFRAMR optimize for",
|
| 758 |
+
"REFRAMR optimizes for analytical transparency long context behavior and hardware accessible computation",
|
| 759 |
+
),
|
| 760 |
+
(
|
| 761 |
+
"what is REFRAMR tokenizer",
|
| 762 |
+
"REFRAMR tokenizer is FrameToken a native OkeyMeta vocabulary system shaped for analytical recurrent generation",
|
| 763 |
+
),
|
| 764 |
+
(
|
| 765 |
+
"who are you REFRAMR",
|
| 766 |
+
"I am REFRAMR an OkeyMeta analytical language system shaped by corpus structure and transparent reasoning",
|
| 767 |
+
),
|
| 768 |
+
(
|
| 769 |
+
"what is REFRAMR voice",
|
| 770 |
+
"REFRAMR voice is deliberate evidence driven and structurally aware rather than shallow imitation",
|
| 771 |
+
),
|
| 772 |
+
(
|
| 773 |
+
"who builds REFRAMR",
|
| 774 |
+
"REFRAMR is built by OkeyMeta Ltd as a recurrent analytical language system for long context reasoning.",
|
| 775 |
+
),
|
| 776 |
+
(
|
| 777 |
+
"summarize OkeyMeta role in REFRAMR",
|
| 778 |
+
"OkeyMeta Ltd builds REFRAMR as a transparent analytical language system grounded in corpus structure and recurrent memory",
|
| 779 |
+
),
|
| 780 |
+
(
|
| 781 |
+
"what is OkeyMeta mission for REFRAMR",
|
| 782 |
+
"OkeyMeta Ltd is building REFRAMR to turn analytical structure into practical language intelligence on ordinary hardware",
|
| 783 |
+
),
|
| 784 |
+
(
|
| 785 |
+
"describe REFRAMR with punctuation",
|
| 786 |
+
"REFRAMR is analytical, recurrent, and deliberate; OkeyMeta Ltd builds it to compute structure from evidence, not gradient ritual.",
|
| 787 |
+
),
|
| 788 |
+
(
|
| 789 |
+
"summarize REFRAMR with punctuation",
|
| 790 |
+
"REFRAMR is a recurrent analytical language system; OkeyMeta Ltd builds it to keep structure visible, context persistent, and compute practical.",
|
| 791 |
+
),
|
| 792 |
+
(
|
| 793 |
+
"summarize FrameToken with punctuation",
|
| 794 |
+
"FrameToken preserves boundaries, protects control tokens, and stays portable; it gives REFRAMR a clean native interface.",
|
| 795 |
+
),
|
| 796 |
+
]
|
| 797 |
+
identity_holdout = [
|
| 798 |
+
(
|
| 799 |
+
"explain REFRAMR in one sentence",
|
| 800 |
+
"REFRAMR is an OkeyMeta analytical language system that computes structure from corpus statistics and explicit memory dynamics",
|
| 801 |
+
),
|
| 802 |
+
(
|
| 803 |
+
"summarize REFRAMR identity",
|
| 804 |
+
"REFRAMR is an OkeyMeta analytical recurrent model built to reason with transparent state rather than opaque gradient rituals",
|
| 805 |
+
),
|
| 806 |
+
(
|
| 807 |
+
"what kind of model is REFRAMR",
|
| 808 |
+
"REFRAMR is an OkeyMeta post transformer recurrent analytical language model focused on computed structure and long stateful reasoning",
|
| 809 |
+
),
|
| 810 |
+
(
|
| 811 |
+
"describe REFRAMR purpose",
|
| 812 |
+
"REFRAMR exists to turn mathematical structure and recurrent memory into practical language intelligence",
|
| 813 |
+
),
|
| 814 |
+
(
|
| 815 |
+
"who owns REFRAMR",
|
| 816 |
+
"REFRAMR is built and owned by OkeyMeta Ltd as a long context analytical language effort",
|
| 817 |
+
),
|
| 818 |
+
(
|
| 819 |
+
"describe FrameToken role",
|
| 820 |
+
"FrameToken is REFRAMR native tokenizer designed by OkeyMeta Ltd for analytical recurrent generation",
|
| 821 |
+
),
|
| 822 |
+
(
|
| 823 |
+
"explain REFRAMR with punctuation",
|
| 824 |
+
"REFRAMR is recurrent, analytical, and long-context oriented; OkeyMeta Ltd built it to compute structure with visible reasoning.",
|
| 825 |
+
),
|
| 826 |
+
]
|
| 827 |
+
for prompt, answer in identity_train:
|
| 828 |
+
add_train(
|
| 829 |
+
"identity",
|
| 830 |
+
f"<reason> {prompt} <answer>",
|
| 831 |
+
answer,
|
| 832 |
+
sample=prompt in {
|
| 833 |
+
"describe REFRAMR briefly",
|
| 834 |
+
"what is REFRAMR",
|
| 835 |
+
"what makes REFRAMR different",
|
| 836 |
+
"describe FrameToken briefly",
|
| 837 |
+
"describe REFRAMR with punctuation",
|
| 838 |
+
},
|
| 839 |
+
)
|
| 840 |
+
for prompt, answer in identity_holdout:
|
| 841 |
+
add_holdout(
|
| 842 |
+
"identity",
|
| 843 |
+
f"<reason> {prompt} <answer>",
|
| 844 |
+
answer,
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
exposition_train = [
|
| 848 |
+
(
|
| 849 |
+
"explain why long context matters",
|
| 850 |
+
"Long context matters because ideas unfold across distance: setup, consequence, and revision rarely live in one sentence. A strong recurrent system must carry structure forward, not just local echoes.",
|
| 851 |
+
),
|
| 852 |
+
(
|
| 853 |
+
"explain why punctuation matters in language models",
|
| 854 |
+
"Punctuation carries structure, pace, and intent; commas slow rhythm, periods close claims, and colons prepare explanation. A model that ignores marks will often flatten meaning.",
|
| 855 |
+
),
|
| 856 |
+
(
|
| 857 |
+
"explain how punctuation helps long reasoning",
|
| 858 |
+
"Punctuation helps long reasoning because sequence alone is not enough: commas stage detail, semicolons balance linked claims, and periods let one conclusion land before the next begins.",
|
| 859 |
+
),
|
| 860 |
+
(
|
| 861 |
+
"explain why punctuation supports long context",
|
| 862 |
+
"Punctuation supports long context by keeping long passages segmented and recoverable. When clauses stay marked, memory can preserve relation, pause, and closure more reliably.",
|
| 863 |
+
),
|
| 864 |
+
(
|
| 865 |
+
"explain why punctuation helps long reasoning",
|
| 866 |
+
"Punctuation helps long reasoning by separating steps, slowing transitions, and protecting closure. Commas meter detail, colons open explanation, and periods keep one claim from smearing into the next.",
|
| 867 |
+
),
|
| 868 |
+
(
|
| 869 |
+
"outline REFRAMR workflow",
|
| 870 |
+
"REFRAMR follows a clean path: build corpus statistics, derive recurrent state behavior, and compute the readout. Each stage stays inspectable; none requires opaque epoch loops.",
|
| 871 |
+
),
|
| 872 |
+
(
|
| 873 |
+
"explain OkeyMeta design ethic",
|
| 874 |
+
"OkeyMeta design ethic is practical and strict: keep provenance visible, keep compute sane, and keep the system understandable. Ambition matters, but clarity matters more.",
|
| 875 |
+
),
|
| 876 |
+
(
|
| 877 |
+
"explain why evidence matters",
|
| 878 |
+
"Evidence matters because confidence alone is cheap; structure, tests, and reproducible runs make a claim durable. When evidence improves, judgment becomes steadier.",
|
| 879 |
+
),
|
| 880 |
+
(
|
| 881 |
+
"describe analytical memory",
|
| 882 |
+
"Analytical memory compresses history into a reusable state; it does not replay every token. That compression is useful only when the state stays orderly, expressive, and inspectable.",
|
| 883 |
+
),
|
| 884 |
+
(
|
| 885 |
+
"explain corpus quality",
|
| 886 |
+
"Corpus quality is not only scale: it is structure, range, and cleanliness. Better data teaches a model where to pause, when to compare, and how to finish a thought.",
|
| 887 |
+
),
|
| 888 |
+
(
|
| 889 |
+
"explain transparent reasoning",
|
| 890 |
+
"Transparent reasoning does not mean leaking private scratch work; it means exposing useful signals, clear traces, and grounded summaries. The system should reveal why a path dominated.",
|
| 891 |
+
),
|
| 892 |
+
(
|
| 893 |
+
"describe disciplined generalization",
|
| 894 |
+
"Disciplined generalization begins with pattern depth, not shallow imitation. A model should reuse structure carefully, vary language naturally, and stay anchored to evidence.",
|
| 895 |
+
),
|
| 896 |
+
(
|
| 897 |
+
"explain why recurrent state can scale",
|
| 898 |
+
"Recurrent state can scale because it updates incrementally; it does not rebuild a full attention map at each step. The challenge is quality, not merely length.",
|
| 899 |
+
),
|
| 900 |
+
(
|
| 901 |
+
"describe strong completion behavior",
|
| 902 |
+
"Strong completion behavior means the answer reaches a real ending: clauses resolve, punctuation lands, and drift stays contained. A half-finished sentence is not intelligence.",
|
| 903 |
+
),
|
| 904 |
+
(
|
| 905 |
+
"explain why handcrafted data still matters",
|
| 906 |
+
"Handcrafted data still matters because it can encode precision, tone, and deliberate contrast. It supplies patterns that scraped noise often blurs or discards.",
|
| 907 |
+
),
|
| 908 |
+
(
|
| 909 |
+
"explain why punctuation supports long answers",
|
| 910 |
+
"Punctuation supports long answers because structure must breathe: commas pace detail, semicolons balance related claims, and periods secure closure. Without marks, long prose often collapses into blur.",
|
| 911 |
+
),
|
| 912 |
+
(
|
| 913 |
+
"describe healthy model discipline",
|
| 914 |
+
"Healthy model discipline is visible in the small things: exact wording, stable endings, measured confidence, and clean recovery from ambiguity. Strong systems respect detail before spectacle.",
|
| 915 |
+
),
|
| 916 |
+
(
|
| 917 |
+
"explain why broad corpus style matters",
|
| 918 |
+
"Broad corpus style matters because the model learns more than facts; it learns transition, emphasis, cadence, and restraint. A rich corpus teaches how to move from premise to finish.",
|
| 919 |
+
),
|
| 920 |
+
(
|
| 921 |
+
"describe how evidence and style should meet",
|
| 922 |
+
"Evidence and style should meet in one sentence: the claim must be accurate, and the sentence must be shaped well enough to carry that accuracy without friction. Good language engineering serves both.",
|
| 923 |
+
),
|
| 924 |
+
(
|
| 925 |
+
"explain why exact retrieval still needs composition",
|
| 926 |
+
"Exact retrieval still needs composition because recovered facts must land in coherent prose; the answer should connect, not merely appear. Precision becomes more useful when it arrives with structure.",
|
| 927 |
+
),
|
| 928 |
+
(
|
| 929 |
+
"outline why model endings matter",
|
| 930 |
+
"Model endings matter for a simple reason: the final clause teaches whether the system understood the task or only imitated momentum. A clean ending shows control, not luck.",
|
| 931 |
+
),
|
| 932 |
+
]
|
| 933 |
+
exposition_holdout = [
|
| 934 |
+
(
|
| 935 |
+
"explain why sentence endings matter",
|
| 936 |
+
"Sentence endings matter because closure guides interpretation; a period settles a claim, while a comma signals more is coming. Good models must feel that difference.",
|
| 937 |
+
),
|
| 938 |
+
(
|
| 939 |
+
"explain why structured data improves writing",
|
| 940 |
+
"Structured data improves writing because it teaches ordering, emphasis, and transition; the model learns not only facts, but how claims should connect.",
|
| 941 |
+
),
|
| 942 |
+
(
|
| 943 |
+
"outline why analytical systems need traces",
|
| 944 |
+
"Analytical systems need traces so operators can inspect dominant signals, compare retrieval paths, and debug drift. Visibility turns mystery into engineering.",
|
| 945 |
+
),
|
| 946 |
+
(
|
| 947 |
+
"describe why punctuation supports reasoning",
|
| 948 |
+
"Punctuation supports reasoning by marking relation, pause, and hierarchy; it helps the reader separate evidence from conclusion. A fluent model should use marks intentionally.",
|
| 949 |
+
),
|
| 950 |
+
(
|
| 951 |
+
"explain why corpus range matters",
|
| 952 |
+
"Corpus range matters because generalization grows from varied structures, not one narrow script. When prompts diversify, the model learns to pivot with control.",
|
| 953 |
+
),
|
| 954 |
+
(
|
| 955 |
+
"describe why exact answers still need style",
|
| 956 |
+
"Exact answers still need style: the right fact should arrive with clean syntax, useful pacing, and a stable finish. Precision and fluency should reinforce each other.",
|
| 957 |
+
),
|
| 958 |
+
]
|
| 959 |
+
for prompt, answer in exposition_train:
|
| 960 |
+
add_train(
|
| 961 |
+
"exposition",
|
| 962 |
+
f"<reason> {prompt} <answer>",
|
| 963 |
+
answer,
|
| 964 |
+
sample=prompt in {
|
| 965 |
+
"explain why long context matters",
|
| 966 |
+
"explain why punctuation matters in language models",
|
| 967 |
+
"outline REFRAMR workflow",
|
| 968 |
+
"describe strong completion behavior",
|
| 969 |
+
},
|
| 970 |
+
)
|
| 971 |
+
for prompt, answer in exposition_holdout:
|
| 972 |
+
add_holdout(
|
| 973 |
+
"exposition",
|
| 974 |
+
f"<reason> {prompt} <answer>",
|
| 975 |
+
answer,
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
composition_train = [
|
| 979 |
+
(
|
| 980 |
+
"ocean",
|
| 981 |
+
"ocean waves move with patient rhythm and silver foam follows the moonlit shore while distant wind keeps a calm measured pulse",
|
| 982 |
+
),
|
| 983 |
+
(
|
| 984 |
+
"forest",
|
| 985 |
+
"forest light falls softly through cedar branches and cool air carries resin and rain while the ground stays quiet beneath careful steps",
|
| 986 |
+
),
|
| 987 |
+
(
|
| 988 |
+
"desert",
|
| 989 |
+
"desert heat bends above pale stone and long shadows stretch across patient sand while evening air slowly restores a gentler balance",
|
| 990 |
+
),
|
| 991 |
+
(
|
| 992 |
+
"city",
|
| 993 |
+
"city dawn spills across glass towers and quiet streets as buses wake in sequence and windows catch a thin ribbon of gold",
|
| 994 |
+
),
|
| 995 |
+
(
|
| 996 |
+
"mountain",
|
| 997 |
+
"mountain air stays bright and thin while granite faces hold the morning sun and distant rivers thread silver lines below",
|
| 998 |
+
),
|
| 999 |
+
(
|
| 1000 |
+
"harbor",
|
| 1001 |
+
"harbor lights shimmer in patient water while cables rest against masts and slow bells mark the edge of another working night",
|
| 1002 |
+
),
|
| 1003 |
+
(
|
| 1004 |
+
"library",
|
| 1005 |
+
"library silence gathers around tall shelves while lamps hold warm circles of light and every page waits with deliberate calm",
|
| 1006 |
+
),
|
| 1007 |
+
(
|
| 1008 |
+
"laboratory",
|
| 1009 |
+
"laboratory glass reflects a quiet blue glow while instruments rest in ordered rows and each surface signals exact preparation",
|
| 1010 |
+
),
|
| 1011 |
+
(
|
| 1012 |
+
"garden",
|
| 1013 |
+
"garden air carries wet soil and green fragrance while trimmed paths divide the beds and new petals lean toward morning light",
|
| 1014 |
+
),
|
| 1015 |
+
(
|
| 1016 |
+
"observatory",
|
| 1017 |
+
"observatory domes open toward dark sky while motors turn with patient certainty and cold metal frames the waiting stars",
|
| 1018 |
+
),
|
| 1019 |
+
]
|
| 1020 |
+
composition_holdout = [
|
| 1021 |
+
(
|
| 1022 |
+
"glacier",
|
| 1023 |
+
"glacier light drifts across slow blue ice while distant air remains clear and every ridge keeps a restrained patient shine",
|
| 1024 |
+
),
|
| 1025 |
+
(
|
| 1026 |
+
"volcano",
|
| 1027 |
+
"volcano stone holds the memory of fire while dark slopes remain still and rising heat bends the horizon with slow force",
|
| 1028 |
+
),
|
| 1029 |
+
(
|
| 1030 |
+
"cathedral",
|
| 1031 |
+
"cathedral windows gather colored light while high arches hold a quiet echo and polished stone returns each careful footstep",
|
| 1032 |
+
),
|
| 1033 |
+
(
|
| 1034 |
+
"market",
|
| 1035 |
+
"market voices braid with morning movement while bright fruit lines the tables and woven shade softens the noonward heat",
|
| 1036 |
+
),
|
| 1037 |
+
(
|
| 1038 |
+
"reef",
|
| 1039 |
+
"reef water carries shifting bands of color while coral forms patient cities and bright fish stitch motion through clear blue lanes",
|
| 1040 |
+
),
|
| 1041 |
+
(
|
| 1042 |
+
"station",
|
| 1043 |
+
"station metal hums beneath pale lamps while distant tracks hold a thin vibration and travelers wait inside orderly lines",
|
| 1044 |
+
),
|
| 1045 |
+
(
|
| 1046 |
+
"courtroom",
|
| 1047 |
+
"courtroom wood carries a formal hush while measured voices rise with care and every pause sharpens the weight of the next sentence",
|
| 1048 |
+
),
|
| 1049 |
+
(
|
| 1050 |
+
"shipyard",
|
| 1051 |
+
"shipyard steel rings through salted air while cranes turn with slow authority and sparks drift briefly before fading into dusk",
|
| 1052 |
+
),
|
| 1053 |
+
(
|
| 1054 |
+
"archive",
|
| 1055 |
+
"archive boxes rest in numbered rows while cool air holds the paper scent and each label promises a patient return to memory",
|
| 1056 |
+
),
|
| 1057 |
+
(
|
| 1058 |
+
"savanna",
|
| 1059 |
+
"savanna light stretches across dry grass while distant heat softens the horizon and watchful movement gathers near the last shade",
|
| 1060 |
+
),
|
| 1061 |
+
(
|
| 1062 |
+
"workshop",
|
| 1063 |
+
"workshop lamps shine over ordered tools while sawdust settles in pale ribbons and each bench waits for deliberate hands",
|
| 1064 |
+
),
|
| 1065 |
+
(
|
| 1066 |
+
"bridge",
|
| 1067 |
+
"bridge cables hold their tense geometry while river light drifts below and the roadway hums with disciplined forward motion",
|
| 1068 |
+
),
|
| 1069 |
+
]
|
| 1070 |
+
for theme, answer in composition_train:
|
| 1071 |
+
add_train(
|
| 1072 |
+
"composition",
|
| 1073 |
+
f"<reason> write {theme} scene in one paragraph <answer>",
|
| 1074 |
+
answer,
|
| 1075 |
+
sample=theme in {"ocean", "forest", "city", "harbor", "laboratory"},
|
| 1076 |
+
)
|
| 1077 |
+
add_train(
|
| 1078 |
+
"composition",
|
| 1079 |
+
f"<reason> write {theme} scene <answer>",
|
| 1080 |
+
answer,
|
| 1081 |
+
sample=False,
|
| 1082 |
+
)
|
| 1083 |
+
for theme, answer in composition_holdout:
|
| 1084 |
+
add_holdout(
|
| 1085 |
+
"composition",
|
| 1086 |
+
f"<reason> write {theme} scene in one paragraph <answer>",
|
| 1087 |
+
answer,
|
| 1088 |
+
)
|
| 1089 |
+
add_holdout(
|
| 1090 |
+
"composition",
|
| 1091 |
+
f"<reason> write {theme} scene <answer>",
|
| 1092 |
+
answer,
|
| 1093 |
+
)
|
| 1094 |
+
|
| 1095 |
+
add_open(
|
| 1096 |
+
"composition",
|
| 1097 |
+
"write harbor dawn scene with calm tension",
|
| 1098 |
+
[
|
| 1099 |
+
["harbor", "port"],
|
| 1100 |
+
["dawn", "morning", "sunrise", "light"],
|
| 1101 |
+
["water", "tide", "shore"],
|
| 1102 |
+
["calm", "quiet", "measured", "tension"],
|
| 1103 |
+
],
|
| 1104 |
+
banned_phrases=[
|
| 1105 |
+
"harbor lights shimmer in patient water while cables rest against masts and slow bells mark the edge of another working night",
|
| 1106 |
+
],
|
| 1107 |
+
min_words=16,
|
| 1108 |
+
max_tokens=40,
|
| 1109 |
+
)
|
| 1110 |
+
add_open(
|
| 1111 |
+
"composition",
|
| 1112 |
+
"write laboratory harbor scene with precise calm",
|
| 1113 |
+
[
|
| 1114 |
+
["laboratory", "glass", "instrument"],
|
| 1115 |
+
["harbor", "water", "mast", "cable"],
|
| 1116 |
+
["calm", "quiet", "precise", "ordered"],
|
| 1117 |
+
],
|
| 1118 |
+
banned_phrases=[],
|
| 1119 |
+
min_words=16,
|
| 1120 |
+
max_tokens=40,
|
| 1121 |
+
)
|
| 1122 |
+
add_open(
|
| 1123 |
+
"identity",
|
| 1124 |
+
"describe REFRAMR in your own words, with punctuation",
|
| 1125 |
+
[
|
| 1126 |
+
["reframr"],
|
| 1127 |
+
["okeymeta"],
|
| 1128 |
+
["analytical", "recurrent", "language", "system"],
|
| 1129 |
+
],
|
| 1130 |
+
banned_phrases=[
|
| 1131 |
+
"REFRAMR is an analytical recurrent language system built by OkeyMeta Ltd to compute structure from corpus evidence instead of gradient loops",
|
| 1132 |
+
"REFRAMR is analytical, recurrent, and deliberate; OkeyMeta Ltd builds it to compute structure from evidence, not gradient ritual.",
|
| 1133 |
+
],
|
| 1134 |
+
min_words=12,
|
| 1135 |
+
max_tokens=36,
|
| 1136 |
+
)
|
| 1137 |
+
add_open(
|
| 1138 |
+
"exposition",
|
| 1139 |
+
"explain why punctuation helps long reasoning",
|
| 1140 |
+
[
|
| 1141 |
+
["punctuation"],
|
| 1142 |
+
["reasoning", "thinking"],
|
| 1143 |
+
["structure", "pace", "pause", "closure"],
|
| 1144 |
+
],
|
| 1145 |
+
banned_phrases=[
|
| 1146 |
+
"Punctuation supports long answers because structure must breathe: commas pace detail, semicolons balance related claims, and periods secure closure. Without marks, long prose often collapses into blur.",
|
| 1147 |
+
],
|
| 1148 |
+
min_words=18,
|
| 1149 |
+
max_tokens=40,
|
| 1150 |
+
)
|
| 1151 |
+
add_open(
|
| 1152 |
+
"identity",
|
| 1153 |
+
"explain REFRAMR memory for long context in your own words",
|
| 1154 |
+
[
|
| 1155 |
+
["reframr"],
|
| 1156 |
+
["memory", "state"],
|
| 1157 |
+
["context", "history"],
|
| 1158 |
+
["long", "persistent", "extended"],
|
| 1159 |
+
],
|
| 1160 |
+
banned_phrases=[
|
| 1161 |
+
"REFRAMR memory is a multi timescale analytical state that compresses history without quadratic attention",
|
| 1162 |
+
],
|
| 1163 |
+
min_words=16,
|
| 1164 |
+
max_tokens=40,
|
| 1165 |
+
)
|
| 1166 |
+
add_open(
|
| 1167 |
+
"composition",
|
| 1168 |
+
"write archive bridge scene with reflective tension",
|
| 1169 |
+
[
|
| 1170 |
+
["archive", "paper", "label", "memory"],
|
| 1171 |
+
["bridge", "cable", "river", "roadway"],
|
| 1172 |
+
["reflective", "tension", "quiet", "measured"],
|
| 1173 |
+
],
|
| 1174 |
+
banned_phrases=[],
|
| 1175 |
+
min_words=16,
|
| 1176 |
+
max_tokens=40,
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
+
return CorpusPackage(
|
| 1180 |
+
name="FrameCorpus-Foundation-v2",
|
| 1181 |
+
records=records,
|
| 1182 |
+
section_counts=section_counts,
|
| 1183 |
+
memorization_samples=_balanced_samples(memorization, 24),
|
| 1184 |
+
generalization_samples=_balanced_samples(generalization, 16),
|
| 1185 |
+
open_ended_samples=open_ended,
|
| 1186 |
+
)
|
| 1187 |
+
|
| 1188 |
+
|
| 1189 |
+
def build_generalization_corpus() -> CorpusPackage:
|
| 1190 |
+
foundation = build_foundation_corpus()
|
| 1191 |
+
allowed_sections = {
|
| 1192 |
+
"analogy",
|
| 1193 |
+
"paraphrase",
|
| 1194 |
+
"comparison",
|
| 1195 |
+
"causal",
|
| 1196 |
+
"definition",
|
| 1197 |
+
"identity",
|
| 1198 |
+
"exposition",
|
| 1199 |
+
"composition",
|
| 1200 |
+
}
|
| 1201 |
+
|
| 1202 |
+
records = [
|
| 1203 |
+
record
|
| 1204 |
+
for record in foundation.records
|
| 1205 |
+
if record.section in allowed_sections
|
| 1206 |
+
]
|
| 1207 |
+
generalization = [
|
| 1208 |
+
sample
|
| 1209 |
+
for sample in foundation.generalization_samples
|
| 1210 |
+
if sample.section in allowed_sections
|
| 1211 |
+
]
|
| 1212 |
+
open_ended = [
|
| 1213 |
+
sample
|
| 1214 |
+
for sample in foundation.open_ended_samples
|
| 1215 |
+
if sample.section in allowed_sections
|
| 1216 |
+
]
|
| 1217 |
+
|
| 1218 |
+
return CorpusPackage(
|
| 1219 |
+
name="FrameCorpus-Generalization-v1",
|
| 1220 |
+
records=records,
|
| 1221 |
+
section_counts=_recount_sections(records),
|
| 1222 |
+
memorization_samples=[],
|
| 1223 |
+
generalization_samples=_balanced_samples(generalization, min(16, len(generalization))),
|
| 1224 |
+
open_ended_samples=open_ended,
|
| 1225 |
+
)
|
| 1226 |
+
|
| 1227 |
+
|
| 1228 |
+
def write_corpus_package(package: CorpusPackage, output_dir: str | Path) -> dict[str, str]:
|
| 1229 |
+
directory = Path(output_dir)
|
| 1230 |
+
directory.mkdir(parents=True, exist_ok=True)
|
| 1231 |
+
|
| 1232 |
+
base_filename = package.slug
|
| 1233 |
+
corpus_filename = f"{base_filename}.jsonl"
|
| 1234 |
+
manifest_filename = f"{base_filename}.manifest.json"
|
| 1235 |
+
prompt_suite_filename = f"{base_filename}.prompts.jsonl"
|
| 1236 |
+
corpus_path = directory / corpus_filename
|
| 1237 |
+
manifest_path = directory / manifest_filename
|
| 1238 |
+
prompt_suite_path = directory / prompt_suite_filename
|
| 1239 |
+
|
| 1240 |
+
corpus_path.write_text(
|
| 1241 |
+
"\n".join(json.dumps(record, ensure_ascii=True) for record in package.corpus_records()) + "\n",
|
| 1242 |
+
encoding="utf-8",
|
| 1243 |
+
)
|
| 1244 |
+
manifest_path.write_text(
|
| 1245 |
+
json.dumps(package.manifest(corpus_filename=corpus_filename), indent=2),
|
| 1246 |
+
encoding="utf-8",
|
| 1247 |
+
)
|
| 1248 |
+
prompt_suite_path.write_text(
|
| 1249 |
+
"\n".join(json.dumps(record, ensure_ascii=True) for record in package.prompt_suite()) + "\n",
|
| 1250 |
+
encoding="utf-8",
|
| 1251 |
+
)
|
| 1252 |
+
|
| 1253 |
+
return {
|
| 1254 |
+
"corpus_path": str(corpus_path.resolve()),
|
| 1255 |
+
"manifest_path": str(manifest_path.resolve()),
|
| 1256 |
+
"prompt_suite_path": str(prompt_suite_path.resolve()),
|
| 1257 |
+
}
|
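A minimal usage sketch for the recipe builders above (illustrative, assuming the reframr package from this repository is importable; the output directory name is a placeholder):

from reframr.corpus_recipes import (
    build_foundation_corpus,
    build_generalization_corpus,
    write_corpus_package,
)

# Build both handcrafted packages, then serialize each to a JSONL corpus,
# a JSON manifest, and a JSONL prompt suite via write_corpus_package above.
for package in (build_foundation_corpus(), build_generalization_corpus()):
    paths = write_corpus_package(package, "corpus_out")
    print(package.name, paths["corpus_path"])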
reframr/curriculum.py
ADDED
The diff for this file is too large to render.
See raw diff
reframr/datasets.py
ADDED
@@ -0,0 +1,165 @@
import json
from pathlib import Path

from .text_quality import clean_answer_text, clean_context_text, clean_training_text


TEXT_EXTENSIONS = {".txt", ".md", ".text"}
STRUCTURED_EXTENSIONS = {".jsonl", ".json"}


def _default_record_weight(record_type: str) -> int:
    if record_type == "dialogue_turn":
        return 2
    if record_type == "instruction_answer":
        return 2
    if record_type == "preference_chosen":
        return 3
    if record_type == "preference_rejected":
        return 0
    return 1


def _record_repeat_count(record: object) -> int:
    if not isinstance(record, dict):
        return 1
    if bool(record.get("drop")):
        return 0
    raw_weight = record.get("weight")
    if raw_weight is not None:
        try:
            numeric = int(round(float(raw_weight)))
        except (TypeError, ValueError):
            numeric = 1
        return max(0, min(8, numeric))
    return _default_record_weight(str(record.get("record_type", "")))
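
# Worked examples for the weighting rules above (illustrative, not part of the
# shipped file); each JSONL record maps to a repeat count as follows:
#   {"text": "...", "weight": 12}                -> 8  (weights are clamped to 0..8)
#   {"text": "...", "weight": -3}                -> 0  (dropped)
#   {"text": "...", "drop": true}                -> 0  (dropped regardless of weight)
#   {"record_type": "preference_rejected", ...}  -> 0  (default weight for this type)
#   {"record_type": "dialogue_turn", ...}        -> 2  (default weight for this type)
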
def _coerce_text_record(record: object) -> str:
    if isinstance(record, str):
        return clean_training_text(record.strip())
    if isinstance(record, dict):
        if "text" in record:
            return clean_training_text(str(record["text"]).strip())
        if "content" in record:
            return clean_training_text(str(record["content"]).strip())
        if "context" in record and "answer" in record:
            context = clean_context_text(str(record["context"]).strip())
            answer = clean_answer_text(str(record["answer"]).strip())
            if context and answer:
                return f"<reason> {context} <answer> {answer}"
    return ""


def _coerce_prompt_record(record: object) -> dict[str, object] | None:
    if isinstance(record, str):
        prompt = record.strip()
        return {"prompt": prompt, "tags": []} if prompt else None
    if isinstance(record, dict):
        raw_prompt = record.get("prompt", record.get("context", ""))
        prompt = clean_context_text(str(raw_prompt).strip())
        if not prompt:
            return None
        raw_tags = record.get("tags", [])
        tags = [str(tag) for tag in raw_tags] if isinstance(raw_tags, list) else []
        normalized = dict(record)
        normalized["prompt"] = prompt
        normalized["tags"] = tags
        return normalized
    return None


def load_text_corpus(source: str | Path) -> str:
    path = Path(source)
    if path.is_dir():
        parts = [
            load_text_corpus(child)
            for child in sorted(path.rglob("*"))
            if child.is_file() and child.suffix.lower() in TEXT_EXTENSIONS | STRUCTURED_EXTENSIONS
        ]
        return "\n".join(part for part in parts if part.strip())

    suffix = path.suffix.lower()
    if suffix in TEXT_EXTENSIONS:
        return path.read_text(encoding="utf-8")
    if suffix == ".jsonl":
        lines = []
        for line in path.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue
            record = json.loads(line)
            text = _coerce_text_record(record)
            if text:
                lines.extend([text] * _record_repeat_count(record))
        return "\n".join(lines)
    if suffix == ".json":
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, list):
            parts: list[str] = []
            for item in payload:
                text = _coerce_text_record(item)
                if text:
                    parts.extend([text] * _record_repeat_count(item))
            return "\n".join(parts)
        if isinstance(payload, dict):
            if "texts" in payload and isinstance(payload["texts"], list):
                parts: list[str] = []
                for item in payload["texts"]:
                    text = _coerce_text_record(item)
                    if text:
                        parts.extend([text] * _record_repeat_count(item))
                return "\n".join(parts)
            if "records" in payload and isinstance(payload["records"], list):
                parts: list[str] = []
                for item in payload["records"]:
                    text = _coerce_text_record(item)
                    if text:
                        parts.extend([text] * _record_repeat_count(item))
                return "\n".join(parts)
            text = _coerce_text_record(payload)
            if text:
                return "\n".join([text] * _record_repeat_count(payload))
    raise ValueError(f"Unsupported corpus source: {path}")


def load_prompt_suite(source: str | Path) -> list[dict[str, object]]:
    path = Path(source)
    suffix = path.suffix.lower()
    prompts: list[dict[str, object]] = []

    if suffix in TEXT_EXTENSIONS:
        for line in path.read_text(encoding="utf-8").splitlines():
            record = _coerce_prompt_record(line)
            if record is not None:
                prompts.append(record)
        return prompts

    if suffix == ".jsonl":
        for line in path.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue
            record = _coerce_prompt_record(json.loads(line))
            if record is not None:
                prompts.append(record)
        return prompts

    if suffix == ".json":
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, list):
            for item in payload:
                record = _coerce_prompt_record(item)
                if record is not None:
                    prompts.append(record)
            return prompts
        if isinstance(payload, dict):
            if "prompts" in payload and isinstance(payload["prompts"], list):
                for item in payload["prompts"]:
                    record = _coerce_prompt_record(item)
                    if record is not None:
                        prompts.append(record)
                return prompts
            record = _coerce_prompt_record(payload)
            if record is not None:
                return [record]

    raise ValueError(f"Unsupported prompt suite: {path}")
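A short loading sketch for the two entry points above (illustrative; assumes reframr is importable, and the file names are placeholders):

import json
from pathlib import Path

from reframr.datasets import load_text_corpus, load_prompt_suite

# One weighted record and one context/answer record, exercising the coercion
# rules in _coerce_text_record and _record_repeat_count above.
Path("corpus.jsonl").write_text(
    json.dumps({"text": "stable memory helps long writing", "weight": 2}) + "\n"
    + json.dumps({"context": "define orbit as", "answer": "path traced by one body around another"}) + "\n",
    encoding="utf-8",
)
text = load_text_corpus("corpus.jsonl")  # the weighted record repeats per its weight

Path("prompts.jsonl").write_text(
    json.dumps({"prompt": "what is REFRAMR", "tags": ["identity"]}) + "\n",
    encoding="utf-8",
)
prompts = load_prompt_suite("prompts.jsonl")  # -> one normalized {"prompt": ..., "tags": [...]} record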
reframr/embeddings.py
ADDED
@@ -0,0 +1,457 @@
from __future__ import annotations

import math
from dataclasses import dataclass

from .corpus import build_cooccurrence_matrix, build_vocabulary, tokenize
from .linalg import Matrix, Vector, mean, np, top_k_eigenpairs_symmetric, zeros

try:
    from scipy import sparse as scipy_sparse
    from scipy.sparse.linalg import svds as scipy_svds
except (ImportError, ModuleNotFoundError, OSError):
    scipy_sparse = None
    scipy_svds = None


SKETCHED_EMBEDDING_VOCAB_THRESHOLD = 2048


def _remove_common_embedding_axis(embeddings: object, row_strength: object | None = None) -> object:
    if np is None:
        return embeddings
    values = np.asarray(embeddings, dtype=np.float64)
    if values.size == 0 or len(values.shape) != 2:
        return values
    norms = np.linalg.norm(values, axis=1)
    nonzero = norms > 1e-12
    values[nonzero] /= norms[nonzero, None]
    if row_strength is not None:
        strength = np.asarray(row_strength, dtype=np.float64)
        if strength.shape[0] == values.shape[0]:
            values[nonzero] *= np.log1p(strength[nonzero])[:, None]

    common_axis = values.mean(axis=0, keepdims=True)
    values = values - common_axis
    norms = np.linalg.norm(values, axis=1)
    nonzero = norms > 1e-12
    values[nonzero] /= norms[nonzero, None]
    if row_strength is not None:
        strength = np.asarray(row_strength, dtype=np.float64)
        if strength.shape[0] == values.shape[0]:
            values[nonzero] *= np.log1p(strength[nonzero])[:, None]
    return values


def _sketched_sparse_ppmi_embedding(ppmi: object, embedding_dim: int) -> object:
    coo = ppmi.tocoo()
    rows = coo.row.astype(np.int64, copy=False)
    cols = coo.col.astype(np.int64, copy=False)
    values = coo.data.astype(np.float64, copy=False)
    embeddings = np.zeros((ppmi.shape[0], embedding_dim), dtype=np.float64)
    if embedding_dim <= 0 or values.size == 0:
        return embeddings

    buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
    signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
    np.add.at(embeddings, (rows, buckets), values * signs)

    row_strength = np.sqrt(np.asarray(ppmi.sum(axis=1)).ravel())
    return _remove_common_embedding_axis(embeddings, row_strength)


def fit_sketched_ppmi_embedding_from_counts(
    id_to_token: list[str],
    rows: dict[int, dict[int, float]],
    *,
    embedding_dim: int,
) -> EmbeddingModel:
    if not id_to_token:
        raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
    if embedding_dim <= 0:
        raise ValueError("Embedding dimension must be positive.")

    size = len(id_to_token)
    token_to_id = {token: index for index, token in enumerate(id_to_token)}
    if np is None:
        embeddings = zeros(size, embedding_dim)
        row_sums = [0.0 for _ in range(size)]
        for row, columns in rows.items():
            row_sums[row] = sum(columns.values())
        total = sum(row_sums)
        if total <= 0.0:
            return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])
        for row, columns in rows.items():
            for col, count in columns.items():
                denominator = row_sums[row] * row_sums[col]
                if count <= 0.0 or denominator <= 0.0:
                    continue
                value = math.log((count * total) / denominator)
                if value <= 0.0:
                    continue
                bucket = (col * 1103515245 + 12345) % embedding_dim
                sign = 1.0 if ((col * 214013 + 2531011) & 1) == 0 else -1.0
                embeddings[row][bucket] += value * sign
        return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])

    embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
    row_sums = np.zeros(size, dtype=np.float64)
    for row, columns in rows.items():
        row_sums[row] = sum(columns.values())
    total = float(row_sums.sum())
    if total <= 0.0:
        return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])

    for row, columns in rows.items():
        if not columns or row_sums[row] <= 0.0:
            continue
        cols = np.fromiter(columns.keys(), dtype=np.int64)
        counts = np.fromiter(columns.values(), dtype=np.float64)
        denominators = row_sums[row] * row_sums[cols]
        valid = (counts > 0.0) & (denominators > 0.0)
        if not np.any(valid):
            continue
        cols = cols[valid]
        values = np.log((counts[valid] * total) / denominators[valid])
        positive = values > 0.0
        if not np.any(positive):
            continue
        cols = cols[positive]
        values = values[positive]
        buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
        signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
        np.add.at(embeddings[row], buckets, values * signs)

    embeddings = _remove_common_embedding_axis(embeddings, row_sums)
    return EmbeddingModel(
        token_to_id=token_to_id,
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=[],
    )
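
# Editorial note on the bucketing above: the constants 1103515245/12345 and
# 214013/2531011 are classic linear-congruential-generator multipliers, used
# here as a cheap deterministic hash. Each vocabulary column is folded into one
# of embedding_dim buckets with a pseudo-random +/-1 sign, i.e. a signed
# feature-hashing (count-sketch style) projection of the PPMI rows, so bucket
# collisions cancel in expectation instead of accumulating bias.
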
def _positive_ppmi_values(
    *,
    row: int,
    columns: dict[int, float],
    row_sums: object,
    total: float,
) -> tuple[object, object]:
    cols = np.fromiter(columns.keys(), dtype=np.int64)
    counts = np.fromiter(columns.values(), dtype=np.float64)
    if cols.size == 0:
        return cols, counts
    denominators = float(row_sums[row]) * row_sums[cols]
    valid = (counts > 0.0) & (denominators > 0.0)
    if not np.any(valid):
        return cols[:0], counts[:0]
    cols = cols[valid]
    values = np.log((counts[valid] * total) / denominators[valid])
    positive = values > 0.0
    return cols[positive], values[positive]


def fit_randomized_ppmi_embedding_from_counts(
    id_to_token: list[str],
    rows: dict[int, dict[int, float]],
    *,
    embedding_dim: int,
    oversampling: int = 32,
) -> EmbeddingModel:
    if np is None:
        return fit_sketched_ppmi_embedding_from_counts(
            id_to_token,
            rows,
            embedding_dim=embedding_dim,
        )
    if not id_to_token:
        raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
    if embedding_dim <= 0:
        raise ValueError("Embedding dimension must be positive.")

    size = len(id_to_token)
    token_to_id = {token: index for index, token in enumerate(id_to_token)}
    row_sums = np.zeros(size, dtype=np.float64)
    for row, columns in rows.items():
        row_sums[row] = sum(columns.values())
    total = float(row_sums.sum())
    if total <= 0.0:
        return EmbeddingModel(
            token_to_id=token_to_id,
            id_to_token=id_to_token,
            embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
            ppmi_matrix=[],
        )

    width = min(size, max(embedding_dim, embedding_dim + oversampling))
    rng = np.random.default_rng(1729 + size * 31 + embedding_dim)
    omega = rng.standard_normal((size, width)).astype(np.float64, copy=False)
    sketch = np.zeros((size, width), dtype=np.float64)
    ppmi_cache: dict[int, tuple[object, object]] = {}
    for row, columns in rows.items():
        if not columns or row_sums[row] <= 0.0:
            continue
        cols, values = _positive_ppmi_values(
            row=row,
            columns=columns,
            row_sums=row_sums,
            total=total,
        )
        if values.size == 0:
            continue
        ppmi_cache[row] = (cols, values)
        sketch[row] = values @ omega[cols]

    if not ppmi_cache:
        return EmbeddingModel(
            token_to_id=token_to_id,
            id_to_token=id_to_token,
            embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
            ppmi_matrix=[],
        )

    basis, _ = np.linalg.qr(sketch, mode="reduced")
    compressed = np.zeros((basis.shape[1], size), dtype=np.float64)
    for row, (cols, values) in ppmi_cache.items():
        compressed[:, cols] += basis[row, :, None] * values[None, :]

    left_small, singular_values, _ = np.linalg.svd(compressed, full_matrices=False)
    left = basis @ left_small
    width = min(embedding_dim, left.shape[1], singular_values.shape[0])
    embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
    if width > 0:
        embeddings[:, :width] = left[:, :width] * np.sqrt(np.maximum(singular_values[:width], 0.0))[None, :]
    embeddings = _remove_common_embedding_axis(embeddings, np.sqrt(row_sums))
    return EmbeddingModel(
        token_to_id=token_to_id,
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=[],
    )
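
# Editorial note: the function above is a randomized range-finder in the style
# of Halko et al. It sketches the implicit PPMI matrix A against a Gaussian
# test matrix (sketch = A @ omega), orthonormalizes the sketch with QR to get a
# basis Q, forms the small projected matrix Q^T A ("compressed"), takes its
# dense SVD, and lifts the left singular vectors back through Q. Rows are then
# scaled by sqrt(singular value), matching the eigenvalue-scaled embeddings
# produced elsewhere in this file.
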
def positive_pointwise_mutual_information(matrix: Matrix) -> Matrix:
    if scipy_sparse is not None and scipy_sparse.issparse(matrix):
        counts = matrix.tocoo()
        if counts.nnz == 0:
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        row_sums = np.asarray(matrix.sum(axis=1)).ravel()
        total = float(row_sums.sum())
        if total == 0.0:
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        denominators = row_sums[counts.row] * row_sums[counts.col]
        valid = (counts.data > 0.0) & (denominators > 0.0)
        if not np.any(valid):
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        ratios = (counts.data[valid] * total) / denominators[valid]
        data = np.maximum(np.log(ratios), 0.0)
        keep = data > 0.0
        if not np.any(keep):
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        return scipy_sparse.coo_matrix(
            (
                data[keep],
                (counts.row[valid][keep], counts.col[valid][keep]),
            ),
            shape=counts.shape,
            dtype=np.float64,
        ).tocsr()

    if not matrix:
        return []
    if np is not None:
        counts = np.asarray(matrix, dtype=np.float64)
        row_sums = counts.sum(axis=1)
        total = float(row_sums.sum())
        if total == 0.0:
            return np.zeros_like(counts).tolist()
        denominator = np.outer(row_sums, row_sums)
        valid = (counts > 0.0) & (denominator > 0.0)
        ppmi = np.zeros_like(counts)
        with np.errstate(divide="ignore", invalid="ignore"):
            ratios = np.divide(
                counts * total,
                denominator,
                out=np.ones_like(counts),
                where=valid,
            )
        ppmi[valid] = np.maximum(np.log(ratios[valid]), 0.0)
        return ppmi.tolist()

    row_sums = [sum(row) for row in matrix]
    total = sum(row_sums)
    if total == 0.0:
        return zeros(len(matrix), len(matrix))

    ppmi = zeros(len(matrix), len(matrix))
    for row in range(len(matrix)):
        for col in range(len(matrix[row])):
            count = matrix[row][col]
            if count <= 0.0 or row_sums[row] == 0.0 or row_sums[col] == 0.0:
                continue
            p_ij = count / total
            p_i = row_sums[row] / total
            p_j = row_sums[col] / total
            value = math.log(p_ij / (p_i * p_j))
            ppmi[row][col] = max(0.0, value)
    return ppmi
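
# Worked example for the function above (illustrative): with
#   counts = [[2.0, 1.0], [1.0, 0.0]]
# the row sums are [3.0, 1.0] and total = 4.0, so
#   PPMI[0][0] = max(0, log((2 * 4) / (3 * 3))) = max(0, log(8/9)) = 0.0
#   PPMI[0][1] = max(0, log((1 * 4) / (3 * 1))) = log(4/3) ~= 0.2877
#   PPMI[1][0] = log(4/3) ~= 0.2877 and PPMI[1][1] = 0.0 (zero count).
# Only co-occurrences above chance survive; everything else clips to zero.
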
@dataclass(slots=True)
class EmbeddingModel:
    token_to_id: dict[str, int]
    id_to_token: list[str]
    embeddings: Matrix
    ppmi_matrix: Matrix

    def vector(self, token: str) -> Vector:
        index = self.token_to_id.get(token)
        if index is None and token.lower() != token:
            index = self.token_to_id.get(token.lower())
        if index is None:
            return [0.0 for _ in range(self.dimension)]
        row = self.embeddings[index]
        return row.astype(float).tolist() if hasattr(row, "tolist") else row[:]

    @property
    def dimension(self) -> int:
        if hasattr(self.embeddings, "shape"):
            return int(self.embeddings.shape[1]) if len(self.embeddings.shape) > 1 else 0
        return len(self.embeddings[0]) if self.embeddings else 0

    @property
    def projection_axis(self) -> Vector:
        if hasattr(self.embeddings, "shape"):
            if int(self.embeddings.shape[0]) == 0:
                return []
            return self.embeddings.mean(axis=0).astype(float).tolist()
        if not self.embeddings:
            return []
        return [
            mean([row[column] for row in self.embeddings])
            for column in range(self.dimension)
        ]


def fit_ppmi_embedding(
    text: str,
    *,
    embedding_dim: int,
    window_size: int,
    min_frequency: int = 1,
    max_vocab: int | None = None,
) -> EmbeddingModel:
    tokens = tokenize(text)
    if not tokens:
        raise ValueError("Cannot fit REFRAMR embeddings on empty text.")

    return fit_ppmi_embedding_from_tokens(
        tokens,
        embedding_dim=embedding_dim,
        window_size=window_size,
        min_frequency=min_frequency,
        max_vocab=max_vocab,
    )


def fit_ppmi_embedding_from_tokens(
    tokens: list[str],
    *,
    embedding_dim: int,
    window_size: int,
    min_frequency: int = 1,
    max_vocab: int | None = None,
) -> EmbeddingModel:
    if not tokens:
        raise ValueError("Cannot fit REFRAMR embeddings on an empty token stream.")

    token_to_id, id_to_token = build_vocabulary(tokens, min_frequency, max_vocab)
    cooccurrence = build_cooccurrence_matrix(tokens, token_to_id, window_size)
    ppmi = positive_pointwise_mutual_information(cooccurrence)
    eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)

    embeddings = zeros(len(id_to_token), embedding_dim)
    for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
        scale = math.sqrt(max(eigenvalue, 0.0))
        for row in range(len(id_to_token)):
            embeddings[row][component] = eigenvector[row] * scale
    if np is not None:
        embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))

    return EmbeddingModel(
        token_to_id=token_to_id,
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=ppmi,
    )
| 390 |
+
def fit_ppmi_embedding_from_cooccurrence(
|
| 391 |
+
id_to_token: list[str],
|
| 392 |
+
cooccurrence: Matrix,
|
| 393 |
+
*,
|
| 394 |
+
embedding_dim: int,
|
| 395 |
+
) -> EmbeddingModel:
|
| 396 |
+
if not id_to_token:
|
| 397 |
+
raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
|
| 398 |
+
|
| 399 |
+
ppmi = positive_pointwise_mutual_information(cooccurrence)
|
| 400 |
+
if scipy_sparse is not None and scipy_sparse.issparse(ppmi):
|
| 401 |
+
embedding_width = min(embedding_dim, len(id_to_token))
|
| 402 |
+
if len(id_to_token) >= SKETCHED_EMBEDDING_VOCAB_THRESHOLD or embedding_width >= 128:
|
| 403 |
+
embeddings = _sketched_sparse_ppmi_embedding(ppmi, embedding_dim)
|
| 404 |
+
return EmbeddingModel(
|
| 405 |
+
token_to_id={token: index for index, token in enumerate(id_to_token)},
|
| 406 |
+
id_to_token=id_to_token,
|
| 407 |
+
embeddings=embeddings,
|
| 408 |
+
ppmi_matrix=[],
|
| 409 |
+
)
|
| 410 |
+
embeddings = zeros(len(id_to_token), embedding_dim)
|
| 411 |
+
if embedding_width <= 0 or ppmi.nnz == 0:
|
| 412 |
+
return EmbeddingModel(
|
| 413 |
+
token_to_id={token: index for index, token in enumerate(id_to_token)},
|
| 414 |
+
id_to_token=id_to_token,
|
| 415 |
+
embeddings=embeddings,
|
| 416 |
+
ppmi_matrix=[],
|
| 417 |
+
)
|
| 418 |
+
if embedding_width < min(ppmi.shape) and scipy_svds is not None:
|
| 419 |
+
left, values, _ = scipy_svds(ppmi.asfptype(), k=embedding_width, which="LM")
|
| 420 |
+
order = np.argsort(values)[::-1]
|
| 421 |
+
for component, source_index in enumerate(order):
|
| 422 |
+
scale = math.sqrt(max(float(values[source_index]), 0.0))
|
| 423 |
+
column = left[:, source_index]
|
| 424 |
+
for row, value in enumerate(column):
|
| 425 |
+
embeddings[row][component] = float(value) * scale
|
| 426 |
+
else:
|
| 427 |
+
dense = ppmi.toarray().tolist()
|
| 428 |
+
eigenpairs = top_k_eigenpairs_symmetric(dense, embedding_width)
|
| 429 |
+
for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
|
| 430 |
+
scale = math.sqrt(max(eigenvalue, 0.0))
|
| 431 |
+
for row in range(len(id_to_token)):
|
| 432 |
+
embeddings[row][component] = eigenvector[row] * scale
|
| 433 |
+
if np is not None:
|
| 434 |
+
embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
|
| 435 |
+
return EmbeddingModel(
|
| 436 |
+
token_to_id={token: index for index, token in enumerate(id_to_token)},
|
| 437 |
+
id_to_token=id_to_token,
|
| 438 |
+
embeddings=embeddings,
|
| 439 |
+
ppmi_matrix=[],
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)
|
| 443 |
+
|
| 444 |
+
embeddings = zeros(len(id_to_token), embedding_dim)
|
| 445 |
+
for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
|
| 446 |
+
scale = math.sqrt(max(eigenvalue, 0.0))
|
| 447 |
+
for row in range(len(id_to_token)):
|
| 448 |
+
embeddings[row][component] = eigenvector[row] * scale
|
| 449 |
+
if np is not None:
|
| 450 |
+
embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
|
| 451 |
+
|
| 452 |
+
return EmbeddingModel(
|
| 453 |
+
token_to_id={token: index for index, token in enumerate(id_to_token)},
|
| 454 |
+
id_to_token=id_to_token,
|
| 455 |
+
embeddings=embeddings,
|
| 456 |
+
ppmi_matrix=ppmi,
|
| 457 |
+
)
|
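The tail of `reframr/embeddings.py` above wires the PPMI factorization into `EmbeddingModel`. As a quick orientation, here is a minimal usage sketch; it assumes `reframr.embeddings` is importable from this repository checkout, and the sample text is a hypothetical placeholder:

```python
# Minimal sketch (hypothetical sample text): fit PPMI embeddings over a tiny
# corpus, then look up a token vector. fit_ppmi_embedding, EmbeddingModel.vector,
# and EmbeddingModel.dimension are defined above in reframr/embeddings.py.
from reframr.embeddings import fit_ppmi_embedding

embedding_model = fit_ppmi_embedding(
    "recurrent memory keeps a persistent state while reading tokens",
    embedding_dim=8,
    window_size=2,
)
vector = embedding_model.vector("memory")  # zero vector for out-of-vocabulary tokens
print(embedding_model.dimension, len(vector))
```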
reframr/evaluation.py
ADDED
@@ -0,0 +1,265 @@
import json
from pathlib import Path

from .model import ReframrModel


def load_manifest(path: str | Path) -> dict[str, object]:
    return json.loads(Path(path).read_text(encoding="utf-8"))


def _expected_next_token(model: ReframrModel, expected_text: str) -> str:
    assert model.tokenizer is not None
    encoded = model.tokenizer.encode(f" {expected_text}")
    return encoded[0] if encoded else ""


def _normalize_text(text: str) -> str:
    return " ".join(text.casefold().split())


def _word_ngrams(words: list[str], size: int) -> list[tuple[str, ...]]:
    if size <= 0 or len(words) < size:
        return []
    return [tuple(words[index : index + size]) for index in range(len(words) - size + 1)]


def _distinct_ratio(words: list[str], size: int) -> float:
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    return len(set(grams)) / len(grams)


def _repetition_ratio(words: list[str], size: int) -> float:
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    repeated = len(grams) - len(set(grams))
    return repeated / len(grams)


def _open_ended_score(
    model: ReframrModel,
    sample: dict[str, object],
    *,
    reasoning_mode: str | None,
) -> dict[str, object]:
    generated = model.generate_text(
        str(sample["context"]),
        max_tokens=int(sample.get("max_tokens", 56)),
        reasoning_mode=reasoning_mode,
    )
    normalized = _normalize_text(generated)
    required_groups = [
        [str(term).casefold() for term in group]
        for group in sample.get("required_groups", [])
    ]
    satisfied_groups = sum(
        1
        for group in required_groups
        if any(term in normalized for term in group)
    )
    group_coverage = (
        satisfied_groups / len(required_groups) if required_groups else 0.0
    )
    punctuation_hit = any(mark in generated for mark in ".,;:?!")
    min_words = int(sample.get("min_words", 12))
    min_word_hit = len(generated.split()) >= min_words
    banned_phrases = [str(phrase) for phrase in sample.get("banned_phrases", [])]
    exact_copy = any(normalized == _normalize_text(phrase) for phrase in banned_phrases)
    novelty_hit = not exact_copy
    require_punctuation = bool(sample.get("require_punctuation", True))

    score_components = [
        group_coverage,
        1.0 if min_word_hit else 0.0,
        1.0 if novelty_hit else 0.0,
    ]
    if require_punctuation:
        score_components.append(1.0 if punctuation_hit else 0.0)

    return {
        "section": str(sample["section"]),
        "context": str(sample["context"]),
        "generated_text": generated,
        "group_coverage": group_coverage,
        "punctuation_hit": punctuation_hit,
        "min_word_hit": min_word_hit,
        "exact_copy": exact_copy,
        "score": sum(score_components) / len(score_components) if score_components else 0.0,
    }


def evaluate_manifest(
    model: ReframrModel,
    manifest: dict[str, object],
    *,
    reasoning_mode: str | None = None,
    top_k: int = 5,
) -> dict[str, object]:
    results: dict[str, object] = {
        "corpus_name": manifest["name"],
        "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
        "splits": {},
    }

    splits = manifest["splits"]
    for split_name in ("memorization", "generalization"):
        samples = splits[split_name]
        top1_hits = 0
        topk_hits = 0
        expected_probabilities = []

        for sample in samples:
            distribution = model.predict_next_token_distribution(
                sample["context"],
                reasoning_mode=reasoning_mode,
            )
            ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
            predicted = ranked[0][0] if ranked else ""
            top_tokens = [token for token, _ in ranked[:top_k]]
            expected = _expected_next_token(model, sample["expected"])
            expected_probability = distribution.get(expected, 0.0)

            if predicted == expected:
                top1_hits += 1
            if expected in top_tokens:
                topk_hits += 1
            expected_probabilities.append(expected_probability)

        sample_count = len(samples)
        mean_expected_probability = (
            sum(expected_probabilities) / sample_count if sample_count else 0.0
        )
        results["splits"][split_name] = {
            "sample_count": sample_count,
            "top1_accuracy": top1_hits / sample_count if sample_count else 0.0,
            "topk_accuracy": topk_hits / sample_count if sample_count else 0.0,
            "mean_expected_probability": mean_expected_probability,
        }

    open_ended_samples = splits.get("open_ended", [])
    if open_ended_samples:
        sample_results = [
            _open_ended_score(
                model,
                sample,
                reasoning_mode=reasoning_mode,
            )
            for sample in open_ended_samples
        ]
        sample_count = len(sample_results)
        results["open_ended"] = {
            "sample_count": sample_count,
            "mean_score": (
                sum(float(sample["score"]) for sample in sample_results) / sample_count
                if sample_count
                else 0.0
            ),
            "mean_group_coverage": (
                sum(float(sample["group_coverage"]) for sample in sample_results) / sample_count
                if sample_count
                else 0.0
            ),
            "punctuation_rate": (
                sum(1 for sample in sample_results if bool(sample["punctuation_hit"])) / sample_count
                if sample_count
                else 0.0
            ),
            "min_word_rate": (
                sum(1 for sample in sample_results if bool(sample["min_word_hit"])) / sample_count
                if sample_count
                else 0.0
            ),
            "exact_copy_rate": (
                sum(1 for sample in sample_results if bool(sample["exact_copy"])) / sample_count
                if sample_count
                else 0.0
            ),
            "samples": sample_results,
        }

    return results


def benchmark_open_prompts(
    model: ReframrModel,
    prompts: list[dict[str, object]],
    *,
    reasoning_mode: str | None = None,
    max_tokens: int = 64,
    temperature: float = 0.82,
    top_k: int = 24,
    top_p: float = 0.92,
    repetition_penalty: float = 1.18,
) -> dict[str, object]:
    samples: list[dict[str, object]] = []
    for item in prompts:
        prompt = str(item["prompt"])
        generated = model.generate_text(
            prompt,
            max_tokens=max_tokens,
            reasoning_mode=reasoning_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        words = generated.split()
        samples.append(
            {
                "prompt": prompt,
                "tags": [str(tag) for tag in item.get("tags", [])],
                "generated_text": generated,
                "word_count": len(words),
                "char_count": len(generated),
                "punctuation_hit": any(mark in generated for mark in ".,;:?!"),
                "distinct_2": _distinct_ratio(words, 2),
                "distinct_3": _distinct_ratio(words, 3),
                "repetition_3": _repetition_ratio(words, 3),
            }
        )

    sample_count = len(samples)
    return {
        "sample_count": sample_count,
        "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
        "generation_policy": {
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        "mean_word_count": (
            sum(int(sample["word_count"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "mean_char_count": (
            sum(int(sample["char_count"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "punctuation_rate": (
            sum(1 for sample in samples if bool(sample["punctuation_hit"])) / sample_count
            if sample_count
            else 0.0
        ),
        "mean_distinct_2": (
            sum(float(sample["distinct_2"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "mean_distinct_3": (
            sum(float(sample["distinct_3"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "mean_repetition_3": (
            sum(float(sample["repetition_3"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "samples": samples,
    }
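`evaluate_manifest` consumes a manifest with `memorization` and `generalization` splits (plus optional `open_ended` samples) and returns aggregate accuracy and quality metrics. A minimal calling sketch, assuming `model` is an already-loaded `ReframrModel` (checkpoint loading lives elsewhere in the runtime) and `eval_manifest.json` is a hypothetical manifest path:

```python
# Minimal sketch: `model` stands for a ReframrModel loaded elsewhere in the
# runtime; "eval_manifest.json" is a hypothetical path to a manifest shaped
# the way evaluate_manifest expects ("name" plus "splits" with
# "memorization"/"generalization" sample lists).
from reframr.evaluation import evaluate_manifest, load_manifest

manifest = load_manifest("eval_manifest.json")
results = evaluate_manifest(model, manifest, top_k=5)
print(results["splits"]["generalization"]["top1_accuracy"])
print(results["splits"]["memorization"]["mean_expected_probability"])
```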
reframr/hf_import.py
ADDED
@@ -0,0 +1,662 @@
import json
import re
import site
import sys
from itertools import chain
from pathlib import Path

from .text_quality import clean_answer_text, clean_context_text, clean_training_text

TEXT_FIELD_PREFERENCES = (
    "text",
    "content",
    "body",
    "article",
    "document",
    "passage",
    "markdown",
)

DIALOGUE_FIELD_PREFERENCES = (
    "messages",
    "conversation",
    "conversations",
    "dialogue",
    "dialog",
    "turns",
)

PREFERENCE_FIELD_PAIRS = (
    ("chosen", "rejected"),
    ("response_j", "response_k"),
    ("response_0", "response_1"),
)

INSTRUCTION_FIELD_PAIRS = (
    ("instruction", "output"),
    ("prompt", "completion"),
    ("prompt", "response"),
    ("question", "answer"),
    ("question", "response"),
    ("query", "response"),
)

TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE)
ROLE_ALIASES = {
    "assistant": "assistant",
    "bot": "assistant",
    "gpt": "assistant",
    "model": "assistant",
    "assistant_response": "assistant",
    "human": "user",
    "user": "user",
    "prompter": "user",
    "customer": "user",
    "system": "system",
}


def _word_count(text: str) -> int:
    return len(text.split())


def _alpha_ratio(text: str) -> float:
    if not text:
        return 0.0
    alpha_count = sum(character.isalpha() for character in text)
    return alpha_count / len(text)


def _default_record_weight(record_type: str) -> int:
    if record_type == "dialogue_turn":
        return 2
    if record_type == "instruction_answer":
        return 2
    if record_type == "preference_chosen":
        return 3
    if record_type == "preference_rejected":
        return 0
    return 1


def choose_text_field(columns: list[str]) -> str:
    normalized = {column.casefold(): column for column in columns}
    for preferred in TEXT_FIELD_PREFERENCES:
        if preferred in normalized:
            return normalized[preferred]
    raise ValueError("Could not infer a text column. Pass --text-field explicitly.")


def choose_dialogue_field(columns: list[str]) -> str:
    normalized = {column.casefold(): column for column in columns}
    for preferred in DIALOGUE_FIELD_PREFERENCES:
        if preferred in normalized:
            return normalized[preferred]
    raise ValueError("Could not infer a conversation column.")


def choose_preference_fields(columns: list[str]) -> tuple[str, str]:
    normalized = {column.casefold(): column for column in columns}
    for chosen_name, rejected_name in PREFERENCE_FIELD_PAIRS:
        if chosen_name in normalized and rejected_name in normalized:
            return normalized[chosen_name], normalized[rejected_name]
    raise ValueError("Could not infer chosen/rejected preference columns.")


def choose_instruction_fields(columns: list[str]) -> tuple[str, str]:
    normalized = {column.casefold(): column for column in columns}
    for prompt_name, answer_name in INSTRUCTION_FIELD_PAIRS:
        if prompt_name in normalized and answer_name in normalized:
            return normalized[prompt_name], normalized[answer_name]
    raise ValueError("Could not infer instruction/answer columns.")


def _row_identifier(row: dict[str, object]) -> str:
    for candidate in ("id", "_id", "row_id", "uuid", "prompt_id"):
        if candidate in row and str(row[candidate]).strip():
            return str(row[candidate]).strip()
    return ""


def _base_record(
    *,
    dataset: str,
    config: str | None,
    split: str,
    row_id: str,
) -> dict[str, str]:
    return {
        "source": "huggingface",
        "dataset": dataset,
        "config": config or "",
        "split": split,
        "row_id": row_id,
    }


def _row_language(row: dict[str, object]) -> str:
    for candidate in ("lang", "language", "locale"):
        value = row.get(candidate)
        if isinstance(value, str) and value.strip():
            return value.strip()
    return ""


def _normalize_role(raw_role: object) -> str:
    role = str(raw_role or "").strip().casefold()
    return ROLE_ALIASES.get(role, role)


def _message_content(message: dict[str, object]) -> str:
    for field in ("content", "value", "text", "message"):
        value = message.get(field)
        if isinstance(value, str) and value.strip():
            return clean_training_text(value)
    return ""


def _message_role(message: dict[str, object]) -> str:
    for field in ("role", "from", "speaker", "author"):
        value = message.get(field)
        if value is not None:
            normalized = _normalize_role(value)
            if normalized:
                return normalized
    return ""


def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
    if not isinstance(raw_messages, list):
        return []

    parsed: list[dict[str, str]] = []
    for message in raw_messages:
        if not isinstance(message, dict):
            continue
        role = _message_role(message)
        content = _message_content(message)
        if role not in {"system", "user", "assistant"} or not content:
            continue
        parsed.append({"role": role, "content": content})
    return parsed


def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
    if not isinstance(raw_text, str):
        return []

    text = raw_text.strip()
    if not text:
        return []

    matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
    if not matches:
        return []

    parsed: list[dict[str, str]] = []
    for index, match in enumerate(matches):
        role = _normalize_role(match.group(1))
        start = match.end()
        end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
        content = clean_training_text(text[start:end].strip())
        if role in {"system", "user", "assistant"} and content:
            parsed.append({"role": role, "content": content})
    return parsed


def _render_prompt(messages: list[dict[str, str]]) -> str:
    lines = []
    for message in messages:
        content = clean_context_text(message["content"])
        if content:
            lines.append(content)
    return "\n".join(lines).strip()


def _compose_training_text(context: str, answer: str) -> str:
    context = clean_context_text(context)
    answer = clean_answer_text(answer)
    return f"<reason> {context} <answer> {answer}".strip()


def _compose_instruction_context(row: dict[str, object], prompt_field: str) -> str:
    parts: list[str] = []
    prompt = clean_context_text(str(row.get(prompt_field, "")).strip())
    extra_input = clean_context_text(str(row.get("input", "")).strip())
    if prompt:
        parts.append(prompt)
    if extra_input:
        parts.append(extra_input)
    return "\n".join(parts).strip()


def _extract_prompt_answer(
    row: dict[str, object],
    *,
    field_name: str,
) -> tuple[str, str]:
    dialogue_messages = _parse_dialogue_messages(row.get(field_name))
    if dialogue_messages and dialogue_messages[-1]["role"] == "assistant":
        prompt = _render_prompt(dialogue_messages[:-1])
        answer = dialogue_messages[-1]["content"]
        if prompt and answer:
            return prompt, answer

    messages = _parse_transcript_messages(row.get(field_name))
    if messages:
        if messages[-1]["role"] == "assistant":
            prompt = _render_prompt(messages[:-1])
            answer = messages[-1]["content"]
            if prompt and answer:
                return prompt, answer

    prompt = clean_training_text(str(row.get("prompt", row.get("question", ""))).strip())
    answer = clean_answer_text(str(row.get(field_name, "")).strip())
    return prompt, answer


def _ordered_preference_fields(
    row: dict[str, object],
    *,
    left_field: str,
    right_field: str,
) -> tuple[str, str]:
    if {left_field, right_field} != {"response_0", "response_1"}:
        return left_field, right_field

    for selector in ("safer_response_id", "better_response_id"):
        value = row.get(selector)
        try:
            preferred = int(value)
        except (TypeError, ValueError):
            continue
        if preferred == 0:
            return "response_0", "response_1"
        if preferred == 1:
            return "response_1", "response_0"
    return left_field, right_field


def _passes_quality_gate(
    record: dict[str, str],
    *,
    min_words: int,
    max_words: int,
    min_alpha_ratio: float,
    allowed_languages: set[str],
) -> bool:
    candidate = str(record.get("answer") or record.get("text") or "").strip()
    if not candidate:
        return False

    word_count = _word_count(candidate)
    if min_words > 0 and word_count < min_words:
        return False
    if max_words > 0 and word_count > max_words:
        return False

    alpha_ratio = _alpha_ratio(candidate)
    if min_alpha_ratio > 0.0 and alpha_ratio < min_alpha_ratio:
        return False

    if allowed_languages:
        language = str(record.get("language", "")).strip().casefold()
        if not language or language not in allowed_languages:
            return False

    record["quality_word_count"] = str(word_count)
    record["quality_alpha_ratio"] = f"{alpha_ratio:.4f}"
    return True


def to_json_record(
    *,
    dataset: str,
    config: str | None,
    split: str,
    text_field: str,
    row: dict[str, object],
) -> dict[str, str]:
    text = clean_training_text(str(row.get(text_field, "")).strip())
    if not text:
        raise ValueError("Row is missing usable text.")

    record_type = "text"
    return {
        **_base_record(
            dataset=dataset,
            config=config,
            split=split,
            row_id=_row_identifier(row),
        ),
        "record_type": record_type,
        "language": _row_language(row),
        "text_field": text_field,
        "text": text,
        "word_count": _word_count(text),
        "weight": _default_record_weight(record_type),
    }


def dialogue_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    conversation_field: str,
    row: dict[str, object],
) -> list[dict[str, str]]:
    messages = _parse_dialogue_messages(row.get(conversation_field))
    if not messages:
        raise ValueError("Row does not contain usable dialogue turns.")

    row_id = _row_identifier(row)
    records: list[dict[str, str]] = []
    history: list[dict[str, str]] = []
    row_language = _row_language(row)
    system_text = clean_training_text(str(row.get("system", "")).strip())
    if system_text:
        history.append({"role": "system", "content": system_text})
    assistant_turn_index = 0
    for message in messages:
        if message["role"] != "assistant":
            history.append(message)
            continue
        prompt = _render_prompt(history)
        if not prompt:
            continue
        assistant_turn_index += 1
        records.append(
            {
                **_base_record(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row_id=row_id,
                ),
                "record_type": "dialogue_turn",
                "language": row_language,
                "conversation_field": conversation_field,
                "turn_index": str(assistant_turn_index),
                "context": prompt,
                "answer": clean_answer_text(message["content"]),
                "text": _compose_training_text(prompt, message["content"]),
                "word_count": _word_count(clean_answer_text(message["content"])),
                "weight": _default_record_weight("dialogue_turn"),
            }
        )
        history.append(message)

    if not records:
        raise ValueError("Dialogue row did not yield any assistant training turns.")
    return records


def preference_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    chosen_field: str,
    rejected_field: str,
    row: dict[str, object],
    preference_target: str = "both",
) -> list[dict[str, str]]:
    row_id = _row_identifier(row)
    pair_id = row_id or f"{chosen_field}:{rejected_field}"
    records: list[dict[str, str]] = []
    row_language = _row_language(row)
    chosen_field, rejected_field = _ordered_preference_fields(
        row,
        left_field=chosen_field,
        right_field=rejected_field,
    )

    field_specs = [
        (chosen_field, "preference_chosen"),
        (rejected_field, "preference_rejected"),
    ]
    if preference_target == "chosen":
        field_specs = [(chosen_field, "preference_chosen")]
    elif preference_target == "rejected":
        field_specs = [(rejected_field, "preference_rejected")]
    elif preference_target != "both":
        raise ValueError("preference_target must be one of: both, chosen, rejected.")

    for field_name, record_type in field_specs:
        prompt, answer = _extract_prompt_answer(row, field_name=field_name)
        if not prompt or not answer:
            continue
        records.append(
            {
                **_base_record(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row_id=row_id,
                ),
                "record_type": record_type,
                "language": row_language,
                "pair_id": pair_id,
                "text_field": field_name,
                "context": prompt,
                "answer": clean_answer_text(answer),
                "text": _compose_training_text(prompt, answer),
                "word_count": _word_count(clean_answer_text(answer)),
                "weight": _default_record_weight(record_type),
            }
        )

    if not records:
        raise ValueError("Preference row did not yield usable chosen/rejected transcripts.")
    return records


def instruction_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    prompt_field: str,
    answer_field: str,
    row: dict[str, object],
) -> list[dict[str, str]]:
    context = _compose_instruction_context(row, prompt_field)
    answer = clean_answer_text(str(row.get(answer_field, "")).strip())
    if not context or not answer:
        raise ValueError("Instruction row did not contain usable prompt and answer text.")
    record_type = "instruction_answer"
    return [
        {
            **_base_record(
                dataset=dataset,
                config=config,
                split=split,
                row_id=_row_identifier(row),
            ),
            "record_type": record_type,
            "language": _row_language(row),
            "context": context,
            "answer": answer,
            "text": _compose_training_text(context, answer),
            "word_count": _word_count(answer),
            "weight": _default_record_weight(record_type),
        }
    ]


def _expand_row_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    row: dict[str, object],
    text_field: str | None,
    preference_target: str,
) -> list[dict[str, str]]:
    if text_field is not None:
        explicit_value = row.get(text_field)
        if isinstance(explicit_value, list):
            return dialogue_to_json_records(
                dataset=dataset,
                config=config,
                split=split,
                conversation_field=text_field,
                row=row,
            )
        return [
            to_json_record(
                dataset=dataset,
                config=config,
                split=split,
                text_field=text_field,
                row=row,
            )
        ]

    columns = list(row)
    try:
        chosen_field, rejected_field = choose_preference_fields(columns)
        return preference_to_json_records(
            dataset=dataset,
            config=config,
            split=split,
            chosen_field=chosen_field,
            rejected_field=rejected_field,
            row=row,
            preference_target=preference_target,
        )
    except ValueError:
        pass

    try:
        prompt_field, answer_field = choose_instruction_fields(columns)
        return instruction_to_json_records(
            dataset=dataset,
            config=config,
            split=split,
            prompt_field=prompt_field,
            answer_field=answer_field,
            row=row,
        )
    except ValueError:
        pass

    try:
        conversation_field = choose_dialogue_field(columns)
        if isinstance(row.get(conversation_field), list):
            return dialogue_to_json_records(
                dataset=dataset,
                config=config,
                split=split,
                conversation_field=conversation_field,
                row=row,
            )
    except ValueError:
        pass

    inferred_text_field = choose_text_field(columns)
    return [
        to_json_record(
            dataset=dataset,
            config=config,
            split=split,
            text_field=inferred_text_field,
            row=row,
        )
    ]


def import_hf_dataset(
    *,
    dataset: str,
    output_path: str | Path,
    config: str | None = None,
    split: str = "train",
    text_field: str | None = None,
    limit: int = 1000,
    streaming: bool = True,
    preference_target: str = "chosen",
    min_words: int = 0,
    max_words: int = 0,
    min_alpha_ratio: float = 0.0,
    allowed_languages: tuple[str, ...] = (),
) -> dict[str, object]:
    try:
        from datasets import load_dataset
    except ModuleNotFoundError:
        user_site = site.getusersitepackages()
        if user_site and user_site not in sys.path:
            sys.path.append(user_site)
        from datasets import load_dataset

    dataset_kwargs: dict[str, object] = {
        "split": split,
        "streaming": streaming,
    }
    if config:
        dataset_kwargs["name"] = config

    hf_dataset = load_dataset(dataset, **dataset_kwargs)
    iterator = iter(hf_dataset)

    first_row: dict[str, object] | None = None
    if text_field is None:
        first_row = dict(next(iterator))
        iterator = chain([first_row], iterator)

    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    written = 0
    record_types: set[str] = set()
    normalized_languages = {language.casefold() for language in allowed_languages if language.strip()}
    with output.open("w", encoding="utf-8") as handle:
        for row in iterator:
            if written >= limit:
                break
            normalized_row = dict(row)
            try:
                records = _expand_row_records(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row=normalized_row,
                    text_field=text_field,
                    preference_target=preference_target,
                )
            except ValueError:
                continue

            for record in records:
                if written >= limit:
                    break
                if not _passes_quality_gate(
                    record,
                    min_words=min_words,
                    max_words=max_words,
                    min_alpha_ratio=min_alpha_ratio,
                    allowed_languages=normalized_languages,
                ):
                    continue
                record_types.add(record.get("record_type", "text"))
                handle.write(json.dumps(record, ensure_ascii=False) + "\n")
                written += 1

    inferred_mode = "mixed" if len(record_types) > 1 else (next(iter(record_types)) if record_types else "unknown")
    return {
        "dataset": dataset,
        "config": config or "",
        "split": split,
        "text_field": text_field or "",
        "output_path": str(output.resolve()),
        "records_written": written,
        "record_types": sorted(record_types),
        "mode": inferred_mode,
        "preference_target": preference_target,
        "streaming": streaming,
        "min_words": min_words,
        "max_words": max_words,
        "min_alpha_ratio": min_alpha_ratio,
        "allowed_languages": sorted(normalized_languages),
    }
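`import_hf_dataset` streams rows from a Hugging Face dataset, routes each row through the preference/instruction/dialogue/plain-text detectors above, applies the quality gate, and writes JSONL. A minimal sketch; the dataset id and output path are hypothetical placeholders, and `datasets` must be installed:

```python
# Minimal sketch: "example-org/example-corpus" and "data/example.jsonl" are
# hypothetical placeholders. Column roles are auto-inferred because no
# text_field is passed; min_words feeds the quality gate defined above.
from reframr.hf_import import import_hf_dataset

summary = import_hf_dataset(
    dataset="example-org/example-corpus",
    output_path="data/example.jsonl",
    split="train",
    limit=500,
    min_words=5,  # drop very short answers/texts
)
print(summary["records_written"], summary["record_types"], summary["mode"])
```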
reframr/hippo.py
ADDED
@@ -0,0 +1,145 @@
import math
from dataclasses import dataclass
import site
import sys
from pathlib import Path

from .linalg import Matrix, Vector, identity, invert_matrix, matvec

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

try:
    import numpy as np
except ModuleNotFoundError:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None


def hippo_legs_matrix(order: int) -> tuple[Matrix, Vector]:
    a_matrix = [[0.0 for _ in range(order)] for _ in range(order)]
    b_vector = [0.0 for _ in range(order)]

    for row in range(order):
        for col in range(order):
            if row > col:
                a_matrix[row][col] = -math.sqrt(2 * row + 1) * math.sqrt(2 * col + 1)
            elif row == col:
                a_matrix[row][col] = -(row + 1)
        b_vector[row] = math.sqrt(2 * row + 1)

    return a_matrix, b_vector


def analytical_embedding_drive(embedding: Vector, state_dim: int) -> Vector:
    if not embedding:
        return [0.0 for _ in range(state_dim)]
    width = len(embedding)
    return [
        (
            embedding[index % width]
            + 0.5 * embedding[(3 * index + 1) % width]
            - 0.25 * embedding[(5 * index + 2) % width]
        )
        for index in range(state_dim)
    ]


def analytical_embedding_drive_fast(embedding: object, state_dim: int) -> object:
    if np is None:
        embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
        return analytical_embedding_drive(embedding_vector, state_dim)
    embedding_array = embedding if hasattr(embedding, "shape") else np.asarray(embedding, dtype=np.float64)
    if embedding_array.size == 0:
        return np.zeros(state_dim, dtype=np.float64)
    indices = np.arange(state_dim, dtype=np.int64)
    width = int(embedding_array.shape[0])
    return (
        embedding_array[indices % width]
        + 0.5 * embedding_array[(3 * indices + 1) % width]
        - 0.25 * embedding_array[(5 * indices + 2) % width]
    )


@dataclass(slots=True)
class AnalyticalMemoryUnit:
    state_dim: int
    timescale: float

    def __post_init__(self) -> None:
        a_matrix, b_vector = hippo_legs_matrix(self.state_dim)
        self.transition, self.input_projection = self._discretize_transition(
            a_matrix,
            b_vector,
            self.timescale,
        )

    transition: Matrix = None  # type: ignore[assignment]
    input_projection: Vector = None  # type: ignore[assignment]
    transition_array: object | None = None  # type: ignore[assignment]
    input_projection_array: object | None = None  # type: ignore[assignment]

    @staticmethod
    def _discretize_transition(
        a_matrix: Matrix,
        b_vector: Vector,
        step: float,
    ) -> tuple[Matrix, Vector]:
        implicit_system = [
            [
                identity_value - step * a_value
                for identity_value, a_value in zip(identity_row, a_row)
            ]
            for identity_row, a_row in zip(identity(len(a_matrix)), a_matrix)
        ]
        transition = invert_matrix(implicit_system)
        input_projection = matvec(transition, [step * value for value in b_vector])
        return transition, input_projection

    def step(self, state: Vector, scalar_input: float) -> Vector:
        if np is not None and self.transition_array is None:
            self.transition_array = np.asarray(self.transition, dtype=np.float64)
            self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
        propagated = matvec(self.transition, state)
        return [
            propagated[index] + self.input_projection[index] * scalar_input
            for index in range(self.state_dim)
        ]

    def step_vector(self, state: Vector, drive: Vector) -> Vector:
        propagated = matvec(self.transition, state)
        return [
            propagated[index] + self.input_projection[index] * drive[index]
            for index in range(self.state_dim)
        ]

    def step_fast(self, state: object, scalar_input: float) -> object:
        if np is None:
            state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
            return self.step(state_vector, scalar_input)
        if self.transition_array is None or self.input_projection_array is None:
            self.transition_array = np.asarray(self.transition, dtype=np.float64)
            self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
        state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
        return (self.transition_array @ state_array) + (self.input_projection_array * scalar_input)

    def step_vector_fast(self, state: object, drive: object) -> object:
        if np is None:
            state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
            drive_vector = drive.tolist() if hasattr(drive, "tolist") else list(drive)
            return self.step_vector(state_vector, drive_vector)
        if self.transition_array is None or self.input_projection_array is None:
            self.transition_array = np.asarray(self.transition, dtype=np.float64)
            self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
        state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
        drive_array = drive if hasattr(drive, "shape") else np.asarray(drive, dtype=np.float64)
        return (self.transition_array @ state_array) + (self.input_projection_array * drive_array)
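`AnalyticalMemoryUnit` discretizes the HiPPO-LegS `(A, B)` pair with an implicit (backward-Euler style) step, i.e. transition `(I - dt*A)^{-1}` and input projection `(I - dt*A)^{-1} * dt * B`, so each call to `step` advances a fixed-width memory state. A minimal sketch with a hypothetical toy input sequence:

```python
# Minimal sketch: feed a short, hypothetical scalar sequence through the
# discretized HiPPO-LegS memory defined above. The state keeps a fixed width
# (state_dim) regardless of how long the input stream runs.
from reframr.hippo import AnalyticalMemoryUnit

unit = AnalyticalMemoryUnit(state_dim=16, timescale=0.05)
state = [0.0] * unit.state_dim
for value in (0.2, -0.1, 0.7):       # toy scalar inputs
    state = unit.step(state, value)  # one implicit-Euler memory update
print(len(state))                    # 16
```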
reframr/linalg.py
ADDED
@@ -0,0 +1,271 @@
import math
import site
import sys
from pathlib import Path

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

try:
    import numpy as np
except ModuleNotFoundError:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None

if np is not None and not hasattr(np, "asarray"):
    np = None

Matrix = list[list[float]]
Vector = list[float]
SUMPROD = getattr(math, "sumprod", None)


def zeros(rows: int, cols: int) -> Matrix:
    return [[0.0 for _ in range(cols)] for _ in range(rows)]


def zeros_vector(size: int) -> Vector:
    return [0.0 for _ in range(size)]


def identity(size: int) -> Matrix:
    matrix = zeros(size, size)
    for index in range(size):
        matrix[index][index] = 1.0
    return matrix


def copy_matrix(matrix: Matrix) -> Matrix:
    return [row[:] for row in matrix]


def transpose(matrix: Matrix) -> Matrix:
    if not matrix:
        return []
    if np is not None:
        return np.asarray(matrix, dtype=np.float64).T.tolist()
    return [list(column) for column in zip(*matrix)]


def matvec(matrix: Matrix, vector: Vector) -> Vector:
    if np is not None:
        return (np.asarray(matrix, dtype=np.float64) @ np.asarray(vector, dtype=np.float64)).tolist()
    if SUMPROD is not None:
        return [SUMPROD(row, vector) for row in matrix]
    return [sum(value * vector[idx] for idx, value in enumerate(row)) for row in matrix]


def matmul(left: Matrix, right: Matrix) -> Matrix:
    if not left or not right:
        return []
    if np is not None:
        return (np.asarray(left, dtype=np.float64) @ np.asarray(right, dtype=np.float64)).tolist()
    right_t = transpose(right)
    if SUMPROD is not None:
        return [[SUMPROD(row, column) for column in right_t] for row in left]
    return [
        [sum(a * b for a, b in zip(row, column)) for column in right_t]
        for row in left
    ]


def add_matrices(left: Matrix, right: Matrix) -> Matrix:
    return [
        [left[row][col] + right[row][col] for col in range(len(left[row]))]
        for row in range(len(left))
    ]


def subtract_matrices(left: Matrix, right: Matrix) -> Matrix:
    return [
        [left[row][col] - right[row][col] for col in range(len(left[row]))]
        for row in range(len(left))
    ]


def scale_matrix(matrix: Matrix, scalar: float) -> Matrix:
    return [[scalar * value for value in row] for row in matrix]


def dot(left: Vector, right: Vector) -> float:
    if np is not None:
        return float(np.dot(np.asarray(left, dtype=np.float64), np.asarray(right, dtype=np.float64)))
    if SUMPROD is not None:
        return SUMPROD(left, right)
    return sum(a * b for a, b in zip(left, right))


def norm(vector: Vector) -> float:
    return math.sqrt(dot(vector, vector))


def outer(left: Vector, right: Vector) -> Matrix:
    if np is not None:
        return np.outer(np.asarray(left, dtype=np.float64), np.asarray(right, dtype=np.float64)).tolist()
    return [[a * b for b in right] for a in left]


def mean(values: Vector) -> float:
    return sum(values) / len(values) if values else 0.0


def trace(matrix: Matrix) -> float:
    return sum(matrix[index][index] for index in range(min(len(matrix), len(matrix[0]))))


def covariance_matrix(samples: list[Vector]) -> Matrix:
    if not samples:
        return []
    if np is not None:
        sample_array = np.asarray(samples, dtype=np.float64)
        centered = sample_array - sample_array.mean(axis=0, keepdims=True)
        denominator = max(len(samples) - 1, 1)
        return ((centered.T @ centered) / denominator).tolist()

    feature_count = len(samples[0])
    sample_count = len(samples)
    means = [
        sum(sample[feature] for sample in samples) / sample_count
        for feature in range(feature_count)
    ]
    covariance = zeros(feature_count, feature_count)
    for sample in samples:
        centered = [sample[index] - means[index] for index in range(feature_count)]
        for row in range(feature_count):
            for col in range(feature_count):
                covariance[row][col] += centered[row] * centered[col]

    denominator = max(sample_count - 1, 1)
    return scale_matrix(covariance, 1.0 / denominator)


def solve_linear_system(matrix: Matrix, vector: Vector) -> Vector:
    if np is not None:
        return np.linalg.solve(
            np.asarray(matrix, dtype=np.float64),
            np.asarray(vector, dtype=np.float64),
        ).tolist()
    size = len(matrix)
    augmented = [matrix[row][:] + [vector[row]] for row in range(size)]

    for pivot_index in range(size):
        pivot_row = max(
            range(pivot_index, size),
            key=lambda row_index: abs(augmented[row_index][pivot_index]),
        )
        augmented[pivot_index], augmented[pivot_row] = augmented[pivot_row], augmented[pivot_index]

        pivot_value = augmented[pivot_index][pivot_index]
        if abs(pivot_value) < 1e-12:
            raise ValueError("Singular matrix encountered while solving linear system.")

        inverse_pivot = 1.0 / pivot_value
        augmented[pivot_index] = [value * inverse_pivot for value in augmented[pivot_index]]

        for row_index in range(size):
            if row_index == pivot_index:
                continue
            factor = augmented[row_index][pivot_index]
            augmented[row_index] = [
                augmented[row_index][col] - factor * augmented[pivot_index][col]
                for col in range(size + 1)
            ]

    return [augmented[row][-1] for row in range(size)]


def invert_matrix(matrix: Matrix) -> Matrix:
    if np is not None:
        return np.linalg.inv(np.asarray(matrix, dtype=np.float64)).tolist()
    size = len(matrix)
inverse_columns = []
|
| 191 |
+
for basis_index in range(size):
|
| 192 |
+
basis_vector = [0.0 for _ in range(size)]
|
| 193 |
+
basis_vector[basis_index] = 1.0
|
| 194 |
+
inverse_columns.append(solve_linear_system(matrix, basis_vector))
|
| 195 |
+
return transpose(inverse_columns)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def dominant_eigenpair_symmetric(
|
| 199 |
+
matrix: Matrix,
|
| 200 |
+
max_iterations: int = 64,
|
| 201 |
+
tolerance: float = 1e-10,
|
| 202 |
+
) -> tuple[float, Vector]:
|
| 203 |
+
size = len(matrix)
|
| 204 |
+
if size == 0:
|
| 205 |
+
return 0.0, []
|
| 206 |
+
if np is not None:
|
| 207 |
+
values, vectors = np.linalg.eigh(np.asarray(matrix, dtype=np.float64))
|
| 208 |
+
index = int(np.argmax(values))
|
| 209 |
+
eigenvalue = float(values[index])
|
| 210 |
+
if eigenvalue <= tolerance:
|
| 211 |
+
return 0.0, zeros_vector(size)
|
| 212 |
+
return eigenvalue, vectors[:, index].astype(float).tolist()
|
| 213 |
+
|
| 214 |
+
vector = [1.0 / math.sqrt(size) for _ in range(size)]
|
| 215 |
+
for _ in range(max_iterations):
|
| 216 |
+
next_vector = matvec(matrix, vector)
|
| 217 |
+
next_norm = norm(next_vector)
|
| 218 |
+
if next_norm < tolerance:
|
| 219 |
+
return 0.0, zeros_vector(size)
|
| 220 |
+
|
| 221 |
+
next_vector = [value / next_norm for value in next_vector]
|
| 222 |
+
delta = max(abs(a - b) for a, b in zip(vector, next_vector))
|
| 223 |
+
vector = next_vector
|
| 224 |
+
if delta < tolerance:
|
| 225 |
+
break
|
| 226 |
+
|
| 227 |
+
eigenvalue = dot(vector, matvec(matrix, vector))
|
| 228 |
+
return eigenvalue, vector
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def top_k_eigenpairs_symmetric(matrix: Matrix, k: int) -> list[tuple[float, Vector]]:
|
| 232 |
+
if np is not None and matrix:
|
| 233 |
+
values, vectors = np.linalg.eigh(np.asarray(matrix, dtype=np.float64))
|
| 234 |
+
ranked = sorted(
|
| 235 |
+
(
|
| 236 |
+
(float(values[index]), vectors[:, index].astype(float).tolist())
|
| 237 |
+
for index in range(len(values))
|
| 238 |
+
if float(values[index]) > 1e-9
|
| 239 |
+
),
|
| 240 |
+
key=lambda item: item[0],
|
| 241 |
+
reverse=True,
|
| 242 |
+
)
|
| 243 |
+
return ranked[: min(k, len(ranked))]
|
| 244 |
+
working = copy_matrix(matrix)
|
| 245 |
+
eigenpairs: list[tuple[float, Vector]] = []
|
| 246 |
+
for _ in range(min(k, len(working))):
|
| 247 |
+
eigenvalue, eigenvector = dominant_eigenpair_symmetric(working)
|
| 248 |
+
if eigenvalue <= 1e-9 or not eigenvector:
|
| 249 |
+
break
|
| 250 |
+
eigenpairs.append((eigenvalue, eigenvector))
|
| 251 |
+
deflation = scale_matrix(outer(eigenvector, eigenvector), eigenvalue)
|
| 252 |
+
working = subtract_matrices(working, deflation)
|
| 253 |
+
return eigenpairs
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def softmax(logits: Vector) -> Vector:
|
| 257 |
+
if not logits:
|
| 258 |
+
return []
|
| 259 |
+
if np is not None:
|
| 260 |
+
values = np.asarray(logits, dtype=np.float64)
|
| 261 |
+
shifted = np.exp(values - values.max())
|
| 262 |
+
total = float(shifted.sum())
|
| 263 |
+
if total == 0.0:
|
| 264 |
+
return [1.0 / len(logits) for _ in logits]
|
| 265 |
+
return (shifted / total).tolist()
|
| 266 |
+
max_logit = max(logits)
|
| 267 |
+
shifted = [math.exp(logit - max_logit) for logit in logits]
|
| 268 |
+
total = sum(shifted)
|
| 269 |
+
if total == 0.0:
|
| 270 |
+
return [1.0 / len(logits) for _ in logits]
|
| 271 |
+
return [value / total for value in shifted]
|
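The fallback chain above is the point of this module: every helper first tries NumPy (vendored or user-site) and otherwise runs in pure Python, using math.sumprod where available, so the public API behaves the same either way. A minimal usage sketch, assuming this repository checkout is importable; the sample vectors are illustrative only:

# PCA-style decomposition built only from reframr.linalg helpers.
from reframr.linalg import covariance_matrix, top_k_eigenpairs_symmetric

samples = [
    [2.0, 0.1, 0.0],
    [1.9, -0.2, 0.1],
    [0.1, 3.0, 0.2],
    [-0.1, 2.8, 0.1],
]
cov = covariance_matrix(samples)              # unbiased: divides by n - 1
pairs = top_k_eigenpairs_symmetric(cov, k=2)  # eigh with NumPy, else power
                                              # iteration plus deflation
for eigenvalue, eigenvector in pairs:
    print(round(eigenvalue, 4), [round(v, 3) for v in eigenvector])
# The two code paths agree up to the sign of each eigenvector.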
reframr/model.py
ADDED
The diff for this file is too large to render. See raw diff.
reframr/reasoning.py
ADDED
@@ -0,0 +1,26 @@
TOKENIZER_NAME = "FrameToken"

REASONING_CONTROL_TOKENS: tuple[str, ...] = (
    "<reason>",
    "<plan>",
    "<reflect>",
    "<answer>",
    "<memory>",
    "<retrieve>",
    "<focus>",
    "<verify>",
    "<tool>",
)

REASONING_PROFILES: dict[str, tuple[str, ...]] = {
    "none": (),
    "deep": ("<reason>",),
    "memory": ("<memory>", "<retrieve>", "<focus>"),
    "tool": ("<tool>", "<reason>", "<verify>"),
}


def reasoning_prefix(mode: str) -> list[str]:
    if mode not in REASONING_PROFILES:
        raise ValueError(f"Unknown reasoning mode: {mode}")
    return list(REASONING_PROFILES[mode])
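A short usage sketch for the profile table, assuming the package is importable:

from reframr.reasoning import reasoning_prefix

print(reasoning_prefix("tool"))    # ['<tool>', '<reason>', '<verify>']
print(reasoning_prefix("none"))    # []
try:
    reasoning_prefix("fast")       # not a defined profile
except ValueError as error:
    print(error)                   # Unknown reasoning mode: fast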
reframr/reservoir.py
ADDED
@@ -0,0 +1,94 @@
from .linalg import Matrix, Vector, identity, invert_matrix, matmul, matvec, np, scale_matrix, transpose


def _empty_matrix(matrix: Matrix) -> bool:
    if np is not None and hasattr(matrix, "size"):
        return int(matrix.size) == 0
    return not matrix


def ridge_regression_readout(
    states: list[Vector],
    targets: list[Vector],
    *,
    regularization: float,
) -> Matrix:
    if not states or not targets:
        raise ValueError("States and targets must be non-empty for ridge readout.")
    if np is not None:
        state_matrix = np.asarray(states, dtype=np.float64).T
        target_matrix = np.asarray(targets, dtype=np.float64).T
        gram = state_matrix @ state_matrix.T
        regularized = gram + (regularization * np.eye(gram.shape[0], dtype=np.float64))
        cross_covariance = target_matrix @ state_matrix.T
        return np.linalg.solve(regularized.T, cross_covariance.T).T.tolist()

    state_matrix = transpose(states)
    target_matrix = transpose(targets)
    gram = matmul(state_matrix, transpose(state_matrix))
    regularized = [
        [
            gram[row][col] + (regularization if row == col else 0.0)
            for col in range(len(gram[row]))
        ]
        for row in range(len(gram))
    ]
    inverse = invert_matrix(regularized)
    cross_covariance = matmul(target_matrix, transpose(state_matrix))
    return matmul(cross_covariance, inverse)


def ridge_regression_readout_from_moments(
    gram: Matrix,
    cross_covariance: Matrix,
    *,
    regularization: float,
) -> Matrix:
    if _empty_matrix(gram) or _empty_matrix(cross_covariance):
        raise ValueError("Gram and cross-covariance moments must be non-empty for ridge readout.")
    if np is not None:
        gram_array = np.asarray(gram, dtype=np.float64)
        regularized = gram_array + (regularization * np.eye(gram_array.shape[0], dtype=np.float64))
        cross_covariance_array = np.asarray(cross_covariance, dtype=np.float64)
        return np.linalg.solve(regularized.T, cross_covariance_array.T).T

    regularized = [
        [
            gram[row][col] + (regularization if row == col else 0.0)
            for col in range(len(gram[row]))
        ]
        for row in range(len(gram))
    ]
    inverse = invert_matrix(regularized)
    return matmul(cross_covariance, inverse)


def ridge_regression_readout_from_diagonal_moments(
    feature_second_moment: Vector,
    cross_covariance: Matrix,
    *,
    regularization: float,
) -> Matrix:
    if _empty_matrix(feature_second_moment) or _empty_matrix(cross_covariance):
        raise ValueError("Diagonal moments and cross-covariance must be non-empty for ridge readout.")
    if np is not None:
        denominator = np.asarray(feature_second_moment, dtype=np.float64) + regularization
        denominator = np.where(np.abs(denominator) > 1e-12, denominator, regularization)
        cross_covariance_array = np.asarray(cross_covariance, dtype=np.float64)
        return cross_covariance_array / denominator[None, :]

    denominator = [
        value + regularization if abs(value + regularization) > 1e-12 else regularization
        for value in feature_second_moment
    ]
    return [
        [
            value / denominator[col]
            for col, value in enumerate(row)
        ]
        for row in cross_covariance
    ]


def apply_readout(weights: Matrix, state: Vector) -> Vector:
    return matvec(weights, state)
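All three readout variants solve the same ridge system, W = Y X^T (X X^T + lambda I)^(-1), with states stacked as the columns of X and targets as the columns of Y; the "from moments" variants accept X X^T and Y X^T precomputed, and the diagonal variant approximates the Gram matrix by its diagonal. A worked sketch on toy data (illustrative values, near-zero regularization):

from reframr.reservoir import apply_readout, ridge_regression_readout

states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # three 2-dimensional states
targets = [[1.0], [2.0], [3.0]]                 # one scalar target per state
weights = ridge_regression_readout(states, targets, regularization=1e-6)
# X X^T = [[2, 1], [1, 2]] and Y X^T = [[4, 5]], so W is approximately [[1, 2]].
print([round(value, 3) for value in apply_readout(weights, [1.0, 1.0])])  # ~[3.0]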
reframr/streaming.py
ADDED
@@ -0,0 +1,1852 @@
from __future__ import annotations

import json
import random
import re
import site
import sys
import time
from collections import Counter
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path

from .config import ReframrConfig
from .corpus import build_vocabulary_from_counts
from .embeddings import fit_ppmi_embedding_from_cooccurrence, fit_randomized_ppmi_embedding_from_counts
from .hippo import AnalyticalMemoryUnit
from .linalg import Matrix, Vector, norm, zeros, zeros_vector
from .model import ReframrModel, RUNTIME_ARRAY_DTYPE, TRANSITION_ORDERS, np
from .reservoir import (
    ridge_regression_readout_from_diagonal_moments,
    ridge_regression_readout_from_moments,
)
from .ternary import apply_ternary_mask, derive_ternary_mask_from_feature_energy
from .text_quality import clean_answer_text, clean_context_text, clean_training_text
from .tokenizer import NativeTokenizer

try:
    from scipy import sparse as scipy_sparse
except (ImportError, ModuleNotFoundError, OSError):
    scipy_sparse = None

TEXT_FIELD_PREFERENCES = (
    "text",
    "content",
    "body",
    "article",
    "document",
    "passage",
    "markdown",
    "answer",
    "response",
)

DIALOGUE_FIELD_PREFERENCES = (
    "messages",
    "conversation",
    "conversations",
    "dialogue",
    "dialog",
    "turns",
    "chosen",
)
INSTRUCTION_FIELD_PAIRS = (
    ("instruction", "output"),
    ("prompt", "completion"),
    ("prompt", "response"),
    ("question", "answer"),
    ("question", "response"),
    ("query", "answer"),
    ("query", "response"),
)
TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE)
ROLE_ALIASES = {
    "assistant": "assistant",
    "assistant_response": "assistant",
    "bot": "assistant",
    "gpt": "assistant",
    "model": "assistant",
    "human": "user",
    "prompter": "user",
    "user": "user",
    "customer": "user",
    "system": "system",
}
ANSWER_READOUT_WEIGHT = 1.0
CONTEXT_READOUT_WEIGHT = 0.0
CONTEXT_STAT_WEIGHT = 0.02
PLAIN_TEXT_READOUT_WEIGHT = 0.03
PREFERENCE_REJECTED_TOKENIZER_WEIGHT = 0.0
PREFERENCE_BIAS_SCALE = 0.95
MAX_PREFERENCE_STATE_PAIRS = 512
ANSWER_START_TOKEN_WINDOW = 12
ANSWER_START_DECAY = 0.86
MAX_ANSWER_SEQUENCE_EXAMPLES = 196608
MAX_ANSWER_SEQUENCE_TOKENS = 192
HF_STREAM_MAX_RETRIES = 5
HF_STREAM_RETRY_BASE_DELAY_SECONDS = 0.25
FULL_READOUT_FEATURE_LIMIT = 2304
FULL_READOUT_EXAMPLE_LIMIT = 25000

@dataclass(slots=True)
class CorpusPlanEntry:
    source: str
    name: str
    dataset: str = ""
    path: str = ""
    config: str | None = None
    split: str = "train"
    limit: int = 0
    weight: float = 1.0
    text_field: str | None = None
    min_words: int = 0
    max_words: int = 0
    min_alpha_ratio: float = 0.0
    allowed_languages: tuple[str, ...] = ()
    records: tuple[object, ...] = ()
    streaming: bool = True
    trust_remote_code: bool = False


@dataclass(slots=True)
class StreamDocument:
    text: str
    weight: float
    source: str
    language: str = ""
    preference_rejected_text: str = ""

class StreamingCooccurrenceAccumulator:
    def __init__(self, token_to_id: dict[str, int], window_size: int) -> None:
        self.token_to_id = token_to_id
        self.window_size = window_size
        self.rows: dict[int, dict[int, float]] = {}

    def update_tokens(self, tokens: list[str], *, weight: float) -> None:
        token_ids = [self.token_to_id[token] for token in tokens if token in self.token_to_id]
        for index, token_id in enumerate(token_ids):
            for offset in range(1, self.window_size + 1):
                other_index = index + offset
                if other_index >= len(token_ids):
                    break
                other_id = token_ids[other_index]
                delta = weight * (1.0 / offset)
                self.rows.setdefault(token_id, {})[other_id] = (
                    self.rows.setdefault(token_id, {}).get(other_id, 0.0) + delta
                )
                self.rows.setdefault(other_id, {})[token_id] = (
                    self.rows.setdefault(other_id, {}).get(token_id, 0.0) + delta
                )

    def to_dense(self) -> Matrix:
        size = len(self.token_to_id)
        matrix = zeros(size, size)
        for row, columns in self.rows.items():
            for col, value in columns.items():
                matrix[row][col] = value
        return matrix

    def to_sparse(self) -> object:
        if scipy_sparse is None or np is None:
            return self.to_dense()
        rows: list[int] = []
        cols: list[int] = []
        data: list[float] = []
        for row, columns in self.rows.items():
            for col, value in columns.items():
                rows.append(row)
                cols.append(col)
                data.append(value)
        size = len(self.token_to_id)
        return scipy_sparse.coo_matrix(
            (
                np.asarray(data, dtype=np.float64),
                (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)),
            ),
            shape=(size, size),
            dtype=np.float64,
        ).tocsr()

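A small sketch of the accumulator's distance weighting, assuming the package imports cleanly; the toy vocabulary is illustrative:

from reframr.streaming import StreamingCooccurrenceAccumulator

vocab = {"the": 0, "model": 1, "streams": 2}
acc = StreamingCooccurrenceAccumulator(vocab, window_size=2)
acc.update_tokens(["the", "model", "streams"], weight=1.0)
# Adjacent pairs add 1.0, the distance-2 pair adds 1/2, and every update is
# mirrored, so the co-occurrence matrix stays symmetric.
print(acc.rows[0])       # {1: 1.0, 2: 0.5}
print(acc.to_dense())    # dense list-of-lists; to_sparse() needs scipy + numpy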
class TransitionAccumulator:
    def __init__(
        self,
        *,
        max_contexts_per_order: int | None = None,
        max_next_tokens: int = 0,
    ) -> None:
        self.max_contexts_per_order = max_contexts_per_order
        self.max_next_tokens = max_next_tokens
        self.context_soft_limit = (
            max_contexts_per_order * 4
            if max_contexts_per_order is not None and max_contexts_per_order > 0
            else None
        )
        self.next_token_soft_limit = max_next_tokens * 4 if max_next_tokens > 0 else None
        self.counts: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
            order: {} for order in sorted(TRANSITION_ORDERS)
        }

    def update_tokens(self, tokens: list[str], *, weight: float) -> None:
        for order in sorted(TRANSITION_ORDERS):
            order_counts = self.counts[order]
            for index in range(order - 1, len(tokens) - 1):
                key = tuple(tokens[index - order + 1 : index + 1])
                nxt = tokens[index + 1]
                if (
                    self.context_soft_limit is not None
                    and key not in order_counts
                    and len(order_counts) >= self.context_soft_limit
                ):
                    continue
                bucket = order_counts.setdefault(key, {})
                if (
                    self.next_token_soft_limit is not None
                    and nxt not in bucket
                    and len(bucket) >= self.next_token_soft_limit
                ):
                    continue
                bucket[nxt] = bucket.get(nxt, 0.0) + weight

    def finalize(
        self,
        *,
        max_contexts_per_order: int | None,
        max_next_tokens: int,
    ) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
        probabilities: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
            order: {} for order in sorted(TRANSITION_ORDERS)
        }
        for order, mapping in self.counts.items():
            items = list(mapping.items())
            items.sort(key=lambda item: (-sum(item[1].values()), item[0]))
            if max_contexts_per_order is not None and max_contexts_per_order >= 0:
                items = items[:max_contexts_per_order]
            for key, bucket in items:
                next_items = sorted(bucket.items(), key=lambda item: (-item[1], item[0]))
                if max_next_tokens > 0:
                    next_items = next_items[:max_next_tokens]
                total = sum(value for _, value in next_items)
                if total <= 0.0:
                    continue
                probabilities[order][key] = {
                    token: value / total
                    for token, value in next_items
                }
        return probabilities

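The accumulator keeps raw weighted counts per order and only normalizes in finalize. A sketch, hedged on the contents of TRANSITION_ORDERS (imported from reframr.model and not shown here); if it includes order 1, the unigram context behaves as in the comment:

from reframr.streaming import TransitionAccumulator

acc = TransitionAccumulator(max_contexts_per_order=1000, max_next_tokens=8)
acc.update_tokens(["the", "cat", "the", "dog"], weight=1.0)
acc.update_tokens(["the", "cat"], weight=1.0)
table = acc.finalize(max_contexts_per_order=1000, max_next_tokens=8)
if 1 in table:
    print(table[1][("the",)])   # {'cat': 0.666..., 'dog': 0.333...}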
class StateReservoir:
    def __init__(self, capacity: int | None, *, seed: int = 13) -> None:
        self.capacity = capacity
        self.random = random.Random(seed)
        self.states: list[Vector] = []
        self.labels: list[int] = []
        self.weights: list[float] = []
        self.seen = 0
        self.total_weight = 0.0

    def reserve_slot(self, weight: float = 1.0) -> int | None:
        if weight <= 0.0:
            return None
        self.seen += 1
        self.total_weight += weight
        if self.capacity is None:
            return len(self.states)
        if self.capacity <= 0:
            return None
        if len(self.states) < self.capacity:
            return len(self.states)
        keep_probability = min(1.0, (self.capacity * weight) / max(self.total_weight, 1e-12))
        if self.random.random() >= keep_probability:
            return None
        return self.random.randrange(self.capacity)

    def store_reserved(
        self,
        slot: int,
        state: Vector,
        label_id: int,
        *,
        example_weight: float = 1.0,
    ) -> None:
        stored_state = state.copy() if hasattr(state, "copy") else state[:]
        if slot == len(self.states):
            self.states.append(stored_state)
            self.labels.append(label_id)
            self.weights.append(example_weight)
        elif 0 <= slot < len(self.states):
            self.states[slot] = stored_state
            self.labels[slot] = label_id
            self.weights[slot] = example_weight

    def consider(self, state: Vector, label_id: int, weight: float = 1.0) -> None:
        slot = self.reserve_slot(weight=weight)
        if slot is not None:
            self.store_reserved(slot, state, label_id, example_weight=weight)

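This is weighted reservoir sampling: the first capacity items are always admitted, after which a candidate survives with probability proportional to its weight relative to the total weight seen. A deterministic sketch:

from reframr.streaming import StateReservoir

reservoir = StateReservoir(capacity=2, seed=13)
reservoir.consider([0.1, 0.2], label_id=7, weight=1.0)
reservoir.consider([0.3, 0.4], label_id=8, weight=1.0)
# keep_probability = min(1, 2 * 4 / 6) = 1, so this item is always admitted
# and evicts one of the two stored slots at random:
reservoir.consider([0.5, 0.6], label_id=9, weight=4.0)
print(reservoir.seen, len(reservoir.states))   # 3 2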
class SequenceReservoir:
    def __init__(self, capacity: int | None, *, seed: int = 41) -> None:
        self.capacity = capacity
        self.random = random.Random(seed)
        self.keys: list[Vector] = []
        self.prompt_rows: list[list[int]] = []
        self.token_rows: list[list[int]] = []
        self.weights: list[float] = []
        self.seen_weight = 0.0

    def reserve_slot(self, *, weight: float = 1.0) -> int | None:
        if self.capacity == 0 or weight <= 0.0:
            return None
        self.seen_weight += weight
        if self.capacity is None or len(self.keys) < self.capacity:
            return len(self.keys)
        probability = min(1.0, (self.capacity * weight) / max(self.seen_weight, 1e-12))
        if self.random.random() >= probability:
            return None
        return self.random.randrange(self.capacity)

    def store_reserved(
        self,
        slot: int,
        key: Vector,
        prompt_token_ids: list[int],
        token_ids: list[int],
        *,
        example_weight: float = 1.0,
    ) -> None:
        key_copy = key.tolist() if hasattr(key, "tolist") else list(key)
        prompt_row = prompt_token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
        row = token_ids[:MAX_ANSWER_SEQUENCE_TOKENS]
        if self.capacity is None or slot >= len(self.keys):
            self.keys.append(key_copy)
            self.prompt_rows.append(prompt_row)
            self.token_rows.append(row)
            self.weights.append(example_weight)
            return
        self.keys[slot] = key_copy
        self.prompt_rows[slot] = prompt_row
        self.token_rows[slot] = row
        self.weights[slot] = example_weight

    def consider(
        self,
        key: Vector,
        prompt_token_ids: list[int],
        token_ids: list[int],
        weight: float = 1.0,
    ) -> None:
        if not token_ids:
            return
        slot = self.reserve_slot(weight=weight)
        if slot is not None:
            self.store_reserved(slot, key, prompt_token_ids, token_ids, example_weight=weight)

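The sequence reservoir applies the same weighted sampling to (key, prompt, answer-token) rows and truncates both token lists to MAX_ANSWER_SEQUENCE_TOKENS on store. A short sketch:

from reframr.streaming import MAX_ANSWER_SEQUENCE_TOKENS, SequenceReservoir

reservoir = SequenceReservoir(capacity=4)
reservoir.consider(
    key=[0.2, 0.8],
    prompt_token_ids=[5, 6, 7],
    token_ids=list(range(300)),   # longer than the 192-token cap
    weight=1.0,
)
print(len(reservoir.token_rows[0]) == MAX_ANSWER_SEQUENCE_TOKENS)   # True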
def _word_count(text: str) -> int:
    return len(text.split())


def _alpha_ratio(text: str) -> float:
    if not text:
        return 0.0
    alpha_count = sum(character.isalpha() for character in text)
    return alpha_count / len(text)


def _row_language(row: dict[str, object]) -> str:
    for candidate in ("lang", "language", "locale"):
        value = row.get(candidate)
        if isinstance(value, str) and value.strip():
            return value.strip()
    return ""


def _normalize_role(raw_role: object) -> str:
    role = str(raw_role or "").strip().casefold()
    return ROLE_ALIASES.get(role, role)


def _message_content(message: dict[str, object]) -> str:
    for field in ("content", "value", "text", "message"):
        value = message.get(field)
        if isinstance(value, str) and value.strip():
            return clean_training_text(value)
    return ""


def _message_role(message: dict[str, object]) -> str:
    for field in ("role", "from", "speaker", "author"):
        value = message.get(field)
        if value is not None:
            normalized = _normalize_role(value)
            if normalized:
                return normalized
    return ""


def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
    if not isinstance(raw_messages, list):
        return []

    parsed: list[dict[str, str]] = []
    for message in raw_messages:
        if not isinstance(message, dict):
            continue
        role = _message_role(message)
        content = _message_content(message)
        if role not in {"system", "user", "assistant"} or not content:
            continue
        parsed.append({"role": role, "content": content})
    return parsed


def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
    if not isinstance(raw_text, str):
        return []

    text = raw_text.strip()
    if not text:
        return []

    matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
    if not matches:
        return []

    parsed: list[dict[str, str]] = []
    for index, match in enumerate(matches):
        role = _normalize_role(match.group(1))
        start = match.end()
        end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
        content = clean_training_text(text[start:end].strip())
        if role in {"system", "user", "assistant"} and content:
            parsed.append({"role": role, "content": content})
    return parsed


def _render_prompt(messages: list[dict[str, str]]) -> str:
    parts = []
    for message in messages:
        content = clean_context_text(message["content"])
        if content:
            parts.append(content)
    return "\n".join(parts).strip()


def _last_user_prompt_before(messages: list[dict[str, str]], end_index: int) -> str:
    for message in reversed(messages[:end_index]):
        if message["role"] == "user":
            return clean_context_text(message["content"])
    return _render_prompt(messages[:end_index])


def _compose_training_text(context: object, answer: object) -> str:
    prompt_text = clean_context_text(_flatten_value(context))
    answer_text = clean_answer_text(_flatten_value(answer))
    if prompt_text and answer_text:
        return f"<reason> {prompt_text} <answer> {answer_text}".strip()
    return clean_training_text(answer_text or prompt_text)


def _compose_from_messages(messages: list[dict[str, str]]) -> str:
    assistant_index = None
    for index in range(len(messages) - 1, -1, -1):
        if messages[index]["role"] == "assistant":
            assistant_index = index
            break
    if assistant_index is not None:
        prompt = _last_user_prompt_before(messages, assistant_index)
        answer = clean_answer_text(messages[assistant_index]["content"])
        if prompt and answer:
            return f"<reason> {prompt} <answer> {answer}".strip()
    return "\n".join(
        message["content"]
        for message in messages
        if message.get("content")
    ).strip()


def _flatten_message_list(messages: object) -> str:
    parsed = _parse_dialogue_messages(messages)
    if parsed:
        return _compose_from_messages(parsed)
    if not isinstance(messages, list):
        return ""
    parts: list[str] = []
    for message in messages:
        if not isinstance(message, dict):
            continue
        content = str(
            message.get("content", message.get("value", message.get("text", "")))
        ).strip()
        if not content:
            continue
        parts.append(clean_training_text(content))
    return "\n".join(parts).strip()


def _flatten_value(value: object) -> str:
    if isinstance(value, str):
        parsed = _parse_transcript_messages(value)
        if parsed:
            return _compose_from_messages(parsed)
        return clean_training_text(value.strip())
    if isinstance(value, list):
        return _flatten_message_list(value)
    if isinstance(value, dict):
        for field in ("messages", "conversation", "conversations", "dialogue", "turns"):
            nested_messages = value.get(field)
            text = _flatten_message_list(nested_messages)
            if text:
                return text
        for field in ("text", "content", "value", "message"):
            nested = value.get(field)
            if isinstance(nested, str) and nested.strip():
                return _flatten_value(nested)
    return ""


def _safe_flag(value: object) -> bool | None:
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        normalized = value.strip().casefold()
        if normalized in {"true", "1", "yes", "safe"}:
            return True
        if normalized in {"false", "0", "no", "unsafe"}:
            return False
    return None

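Together these helpers map the common chat formats onto a single training string, "<reason> {last user prompt} <answer> {last assistant reply}". A sketch, noting that the helpers are module-private and that the exact output depends on the text_quality cleaners, which are assumed here to leave short clean text unchanged:

from reframr.streaming import _flatten_value

record = [
    {"from": "human", "value": "What is Reframr?"},
    {"from": "gpt", "value": "A recurrent-memory language model."},
]
print(_flatten_value(record))
# <reason> What is Reframr? <answer> A recurrent-memory language model.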
def _selected_response_fields(row: dict[str, object]) -> tuple[str, str]:
    if "response_0" not in row or "response_1" not in row:
        return "", ""
    safe_0 = _safe_flag(row.get("is_response_0_safe"))
    safe_1 = _safe_flag(row.get("is_response_1_safe"))
    if safe_0 is not None and safe_1 is not None:
        if safe_0 and not safe_1:
            return "response_0", "response_1"
        if safe_1 and not safe_0:
            return "response_1", "response_0"
        if safe_0 and safe_1:
            return "response_0", ""
        return "", ""
    for selector in ("safer_response_id", "better_response_id"):
        raw_value = row.get(selector)
        try:
            preferred = int(raw_value)
        except (TypeError, ValueError):
            continue
        chosen = "response_1" if preferred == 1 else "response_0"
        rejected = "response_0" if chosen == "response_1" else "response_1"
        return chosen, rejected
    return "response_0", "response_1"


def _extract_preference_pair(row: dict[str, object]) -> tuple[str, str]:
    if "chosen" in row and "rejected" in row:
        chosen_text = clean_training_text(_flatten_value(row.get("chosen")))
        rejected_text = clean_training_text(_flatten_value(row.get("rejected")))
        if chosen_text and rejected_text:
            return chosen_text, rejected_text
    if "response_0" in row and "response_1" in row:
        preferred_field, rejected_field = _selected_response_fields(row)
        if not preferred_field or not rejected_field:
            return "", ""
        prompt = row.get("prompt", row.get("question", row.get("query", "")))
        if prompt:
            chosen_text = _compose_training_text(prompt, row.get(preferred_field))
            rejected_text = _compose_training_text(prompt, row.get(rejected_field))
            if chosen_text and rejected_text:
                return clean_training_text(chosen_text), clean_training_text(rejected_text)
        chosen_text = clean_training_text(_flatten_value(row.get(preferred_field)))
        rejected_text = clean_training_text(_flatten_value(row.get(rejected_field)))
        if chosen_text and rejected_text:
            return chosen_text, rejected_text
    return "", ""


def _extract_preference_value(row: dict[str, object]) -> str:
    chosen_text, _ = _extract_preference_pair(row)
    return chosen_text


def _extract_row_text(row: dict[str, object], text_field: str | None) -> str:
    if "context" in row and "answer" in row:
        context = clean_context_text(_flatten_value(row.get("context")))
        answer = clean_answer_text(_flatten_value(row.get("answer")))
        if context and answer:
            return f"<reason> {context} <answer> {answer}".strip()

    if "response_0" in row and "response_1" in row:
        preferred_field, _ = _selected_response_fields(row)
        prompt = row.get("prompt", row.get("question", row.get("query", "")))
        if preferred_field and prompt:
            text = _compose_training_text(prompt, row.get(preferred_field))
            if text:
                return text

    for prompt_field, answer_field in INSTRUCTION_FIELD_PAIRS:
        if prompt_field in row and answer_field in row:
            text = _compose_training_text(row.get(prompt_field), row.get(answer_field))
            if text:
                return text

    if text_field is not None:
        return clean_training_text(_flatten_value(row.get(text_field)))

    preferred = _extract_preference_value(row)
    if preferred:
        return clean_training_text(preferred)

    for field in TEXT_FIELD_PREFERENCES:
        text = _flatten_value(row.get(field))
        if text:
            return clean_training_text(text)
    for field in DIALOGUE_FIELD_PREFERENCES:
        text = _flatten_value(row.get(field))
        if text:
            return clean_training_text(text)
    return ""


def _passes_text_quality(text: str, language: str, entry: CorpusPlanEntry) -> bool:
    if not text:
        return False
    word_count = _word_count(text)
    if entry.min_words > 0 and word_count < entry.min_words:
        return False
    if entry.max_words > 0 and word_count > entry.max_words:
        return False
    if entry.min_alpha_ratio > 0.0 and _alpha_ratio(text) < entry.min_alpha_ratio:
        return False
    if entry.allowed_languages:
        if not language or language.casefold() not in entry.allowed_languages:
            return False
    return True

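A sketch of preference selection on a row shaped like the PKU safety datasets (field names match the parser above; the strings are invented). When exactly one response is flagged safe, it becomes the chosen side; the chosen text is then composed with the prompt, assuming the cleaners preserve the control tokens that the downstream <answer> split relies on:

from reframr.streaming import _extract_preference_pair, _selected_response_fields

row = {
    "prompt": "How do I secure my server?",
    "response_0": "Disable the firewall.",
    "response_1": "Keep packages patched and use key-based SSH.",
    "is_response_0_safe": False,
    "is_response_1_safe": True,
}
print(_selected_response_fields(row))   # ('response_1', 'response_0')
chosen, rejected = _extract_preference_pair(row)
print(chosen.startswith("<reason>"))    # True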
def load_corpus_plan(source: str | Path) -> list[CorpusPlanEntry]:
    payload = json.loads(Path(source).read_text(encoding="utf-8-sig"))
    raw_entries = payload.get("sources", payload.get("datasets", []))
    if not isinstance(raw_entries, list) or not raw_entries:
        raise ValueError("Corpus plan must define a non-empty 'sources' list.")

    entries: list[CorpusPlanEntry] = []
    for index, raw_entry in enumerate(raw_entries, start=1):
        if not isinstance(raw_entry, dict):
            raise ValueError("Each corpus plan entry must be an object.")
        source = str(raw_entry.get("source", "hf")).strip() or "hf"
        name = str(
            raw_entry.get("name", raw_entry.get("dataset", f"source-{index}"))
        ).strip() or f"source-{index}"
        raw_languages = raw_entry.get("allowed_languages", [])
        allowed_languages = tuple(
            str(value).strip().casefold()
            for value in raw_languages
            if str(value).strip()
        ) if isinstance(raw_languages, list) else ()
        raw_records = raw_entry.get("records", raw_entry.get("texts", []))
        if source == "inline" and not isinstance(raw_records, list):
            raise ValueError("Inline corpus plan entries must provide a records/texts list.")
        entries.append(
            CorpusPlanEntry(
                source=source,
                name=name,
                dataset=str(raw_entry.get("dataset", "")),
                path=str(raw_entry.get("path", raw_entry.get("file", ""))),
                config=(
                    str(raw_entry["config"])
                    if raw_entry.get("config") is not None
                    else None
                ),
                split=str(raw_entry.get("split", "train")),
                limit=int(raw_entry.get("limit", 0)),
                weight=float(raw_entry.get("weight", 1.0)),
                text_field=(
                    str(raw_entry["text_field"])
                    if raw_entry.get("text_field") is not None
                    else None
                ),
                min_words=int(raw_entry.get("min_words", 0)),
                max_words=int(raw_entry.get("max_words", 0)),
                min_alpha_ratio=float(raw_entry.get("min_alpha_ratio", 0.0)),
                allowed_languages=allowed_languages,
                records=tuple(raw_records) if isinstance(raw_records, list) else (),
                streaming=bool(raw_entry.get("streaming", True)),
                trust_remote_code=bool(raw_entry.get("trust_remote_code", False)),
            )
        )
    return entries

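A minimal corpus plan this loader accepts, mixing the three source kinds. The dataset and file names below are placeholders, not part of the release:

import json
from pathlib import Path

from reframr.streaming import load_corpus_plan

plan_path = Path("corpus_plan.json")
plan_path.write_text(json.dumps({
    "sources": [
        {"source": "inline", "name": "smoke", "texts": ["hello world"], "weight": 0.5},
        {"source": "file", "name": "notes", "path": "notes.jsonl", "limit": 100},
        {"source": "hf", "name": "wiki", "dataset": "wikimedia/wikipedia",
         "config": "20231101.en", "split": "train", "min_words": 20},
    ]
}), encoding="utf-8")

for entry in load_corpus_plan(plan_path):
    print(entry.source, entry.name, entry.weight)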
def _iter_hf_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
    try:
        from datasets import load_dataset
    except ModuleNotFoundError:
        user_site = site.getusersitepackages()
        if user_site and user_site not in sys.path:
            sys.path.append(user_site)
        from datasets import load_dataset

    dataset_kwargs: dict[str, object] = {
        "split": entry.split,
        "streaming": entry.streaming,
    }
    if entry.config:
        dataset_kwargs["name"] = entry.config
    if entry.trust_remote_code:
        dataset_kwargs["trust_remote_code"] = True

    for row in load_dataset(entry.dataset, **dataset_kwargs):
        yield dict(row)


def _iter_file_rows(entry: CorpusPlanEntry) -> Iterator[dict[str, object]]:
    raw_path = entry.path or entry.dataset
    if not raw_path:
        raise ValueError("File corpus plan entries must provide a path.")
    path = Path(raw_path)
    suffix = path.suffix.lower()
    if suffix == ".jsonl":
        with path.open("r", encoding="utf-8") as handle:
            for line in handle:
                if line.strip():
                    row = json.loads(line)
                    yield row if isinstance(row, dict) else {"text": str(row)}
        return
    if suffix == ".json":
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, list):
            for row in payload:
                yield row if isinstance(row, dict) else {"text": str(row)}
            return
        if isinstance(payload, dict):
            rows = payload.get("records", payload.get("texts"))
            if isinstance(rows, list):
                for row in rows:
                    yield row if isinstance(row, dict) else {"text": str(row)}
                return
            yield payload
            return
    if suffix in {".txt", ".md", ".text"}:
        yield {"text": path.read_text(encoding="utf-8")}
        return
    raise ValueError(f"Unsupported file corpus source: {path}")

def iter_corpus_plan_documents(plan: Iterable[CorpusPlanEntry]) -> Iterator[StreamDocument]:
    for entry in plan:
        accepted = 0
        attempts = 0
        while True:
            accepted_seen_this_attempt = 0
            try:
                if entry.source == "inline":
                    row_iterator = (
                        item if isinstance(item, dict) else {"text": str(item)}
                        for item in entry.records
                    )
                elif entry.source == "hf":
                    row_iterator = _iter_hf_rows(entry)
                elif entry.source == "file":
                    row_iterator = _iter_file_rows(entry)
                else:
                    raise ValueError(f"Unsupported corpus plan source: {entry.source}")

                for row in row_iterator:
                    language = _row_language(row)
                    _, rejected_text = _extract_preference_pair(row)
                    text = clean_training_text(_extract_row_text(row, entry.text_field))
                    if not _passes_text_quality(text, language, entry):
                        continue
                    accepted_seen_this_attempt += 1
                    if accepted_seen_this_attempt <= accepted:
                        continue
                    yield StreamDocument(
                        text=text,
                        weight=entry.weight,
                        source=entry.name,
                        language=language,
                        preference_rejected_text=rejected_text,
                    )
                    accepted += 1
                    if entry.limit > 0 and accepted >= entry.limit:
                        break
                break
            except Exception as exc:
                if entry.source != "hf":
                    raise
                if attempts >= HF_STREAM_MAX_RETRIES:
                    print(
                        f"[source] {entry.name} skipped after {attempts} retries; "
                        f"accepted {accepted} documents before final error: {exc}"
                    )
                    break
                attempts += 1
                delay = min(
                    15.0,
                    HF_STREAM_RETRY_BASE_DELAY_SECONDS * (2 ** (attempts - 1)),
                )
                print(
                    f"[source] {entry.name} stream interrupted after {accepted} accepted "
                    f"documents; retry {attempts}/{HF_STREAM_MAX_RETRIES} in {delay:.2f}s: {exc}"
                )
                time.sleep(delay)


def _log_progress(label: str, processed: int, log_every: int) -> None:
    if log_every > 0 and processed % log_every == 0:
        print(f"[{label}] processed {processed} documents")

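The except branch above resumes a dropped Hugging Face stream by replaying it and skipping the documents already yielded, with exponential backoff capped at 15 seconds. A quick sketch of the resulting delay schedule for the five allowed retries:

# delay = min(15.0, HF_STREAM_RETRY_BASE_DELAY_SECONDS * 2 ** (attempt - 1))
for attempt in range(1, 6):
    print(attempt, min(15.0, 0.25 * (2 ** (attempt - 1))))
# attempts 1..5 wait 0.25, 0.5, 1.0, 2.0, and 4.0 seconds respectively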
def _answer_boundary(tokens: list[str]) -> int | None:
    try:
        return tokens.index("<answer>")
    except ValueError:
        return None


def _weighted_text_parts_for_statistics(text: str, document_weight: float) -> list[tuple[str, float]]:
    if "<answer>" not in text:
        return [(text, document_weight)]
    context, answer = text.split("<answer>", 1)
    context = clean_context_text(context.replace("<reason>", " "))
    answer = clean_answer_text(answer)
    parts: list[tuple[str, float]] = []
    if context:
        parts.append((context, document_weight * CONTEXT_STAT_WEIGHT))
    if answer:
        parts.append((answer, document_weight * ANSWER_READOUT_WEIGHT))
    return parts or [(text, document_weight)]


def _weighted_token_sequences_for_statistics(
    tokens: list[str],
    tokenizer: NativeTokenizer,
    document_weight: float,
) -> list[tuple[list[str], float]]:
    answer_index = _answer_boundary(tokens)
    if answer_index is None:
        sequence = [token for token in tokens if token not in tokenizer.special_tokens]
        return [(sequence, document_weight)] if sequence else []
    context_tokens = [
        token for token in tokens[:answer_index] if token not in tokenizer.special_tokens
    ]
    answer_tokens = [
        token for token in tokens[answer_index + 1 :] if token not in tokenizer.special_tokens
    ]
    sequences: list[tuple[list[str], float]] = []
    if context_tokens:
        sequences.append((context_tokens, document_weight * CONTEXT_STAT_WEIGHT))
    if answer_tokens:
        sequences.append((answer_tokens, document_weight * ANSWER_READOUT_WEIGHT))
    return sequences


def _readout_weight_for_target(
    answer_index: int | None,
    target_index: int,
    document_weight: float,
) -> float:
    if answer_index is None:
        return document_weight * PLAIN_TEXT_READOUT_WEIGHT
    if target_index <= answer_index:
        return document_weight * CONTEXT_READOUT_WEIGHT
    return document_weight * ANSWER_READOUT_WEIGHT


def _answer_payload_tokens(tokens: list[str], tokenizer: NativeTokenizer) -> list[str]:
    answer_index = _answer_boundary(tokens)
    payload = tokens[answer_index + 1 :] if answer_index is not None else tokens
    return [token for token in payload if token not in tokenizer.special_tokens]


def _standardized_preference_bias(values: object, active_mask: object | None = None) -> list[float]:
    if np is not None:
        bias = np.asarray(values, dtype=np.float64)
        if bias.size == 0:
            return []
        active = (
            np.asarray(active_mask, dtype=bool)
            if active_mask is not None
            else np.ones(bias.shape, dtype=bool)
        )
        if not np.any(active):
            return [0.0 for _ in range(int(bias.size))]
        active_values = bias[active]
        spread = float(active_values.std())
        if spread <= 1e-12:
            return [0.0 for _ in range(int(bias.size))]
        standardized = np.zeros_like(bias, dtype=np.float64)
        standardized[active] = (
            (active_values - float(active_values.mean())) / spread
        ) * PREFERENCE_BIAS_SCALE
        return np.clip(standardized, -2.5, 2.5).astype(float).tolist()
    raw_values = [float(value) for value in values]
    if not raw_values:
        return []
    average = sum(raw_values) / len(raw_values)
    variance = sum((value - average) * (value - average) for value in raw_values) / len(raw_values)
    spread = variance**0.5
    if spread <= 1e-12:
        return [0.0 for _ in raw_values]
    active_indices = (
        [
            index
            for index, active in enumerate(active_mask)
            if active
        ]
        if active_mask is not None
        else list(range(len(raw_values)))
    )
    if not active_indices:
        return [0.0 for _ in raw_values]
    active_values = [raw_values[index] for index in active_indices]
    average = mean(active_values)
    spread = (mean([(value - average) * (value - average) for value in active_values])) ** 0.5
    if spread <= 1e-12:
        return [0.0 for _ in raw_values]
    standardized = [0.0 for _ in raw_values]
    for index in active_indices:
        standardized[index] = max(
            -2.5,
            min(2.5, ((raw_values[index] - average) / spread) * PREFERENCE_BIAS_SCALE),
        )
    return standardized


def _candidate_preference_bias_from_state_vector(
    model: ReframrModel,
    preference_state: object,
) -> object:
    if np is None:
        return None
    assert model.embedding_model is not None
    assert model.memory_units is not None
    assert model.ternary_mask is not None

    embeddings = np.asarray(model.embedding_model.embeddings, dtype=np.float64)
    if embeddings.size == 0:
        return np.zeros(0, dtype=np.float64)
    state_vector = np.asarray(preference_state, dtype=np.float64)
    mask = np.asarray(model.ternary_mask, dtype=np.float64) * float(model.ternary_scale)
    if state_vector.shape[0] != mask.shape[0]:
        return np.zeros(embeddings.shape[0], dtype=np.float64)

    state_indices = np.arange(model.config.state_dim, dtype=np.int64)
    drive = (
        embeddings[:, state_indices % model.config.embedding_dim]
        + (0.5 * embeddings[:, (3 * state_indices + 1) % model.config.embedding_dim])
        - (0.25 * embeddings[:, (5 * state_indices + 2) % model.config.embedding_dim])
    )
    scores = np.zeros(embeddings.shape[0], dtype=np.float64)
    offset = 0
    for unit in model.memory_units:
        hidden_end = offset + model.config.state_dim
        trace_end = hidden_end + model.config.embedding_dim
        hidden_pref = state_vector[offset:hidden_end] * mask[offset:hidden_end]
        trace_pref = state_vector[hidden_end:trace_end] * mask[hidden_end:trace_end]
        hidden_delta_axis = np.asarray(unit.input_projection, dtype=np.float64) * hidden_pref
        trace_gain = 1.0 - (1.0 / (1.0 + unit.timescale))
        scores += drive @ hidden_delta_axis
        scores += embeddings @ (trace_gain * trace_pref)
        offset = trace_end
    return scores


def _derive_preference_bias_from_pairs(
    model: ReframrModel,
    preference_token_pairs: list[tuple[list[str], list[str], float]],
    tokenizer: NativeTokenizer,
) -> tuple[list[float], int]:
    assert model.embedding_model is not None
    vocab_size = len(model.embedding_model.id_to_token)
    if not preference_token_pairs:
        return [0.0 for _ in range(vocab_size)], 0

    if np is not None:
        token_bias = np.zeros(vocab_size, dtype=np.float64)
        active_token_mask = np.zeros(vocab_size, dtype=bool)
        state_delta = np.zeros(model._combined_state_width(), dtype=np.float64)
    else:
        token_bias = [0.0 for _ in range(vocab_size)]
        active_token_ids: set[int] = set()
        state_delta = [0.0 for _ in range(model._combined_state_width())]
    pair_weight_total = 0.0
    state_pair_count = 0
    state_stride = max(
        1,
        (len(preference_token_pairs) + MAX_PREFERENCE_STATE_PAIRS - 1)
        // MAX_PREFERENCE_STATE_PAIRS,
    )

    for pair_index, (chosen_tokens, rejected_tokens, pair_weight) in enumerate(preference_token_pairs):
        chosen_answer = _answer_payload_tokens(chosen_tokens, tokenizer)
        rejected_answer = _answer_payload_tokens(rejected_tokens, tokenizer)
        if chosen_answer:
            delta = pair_weight / max(1, len(chosen_answer))
            for token in chosen_answer:
                token_id = model.embedding_model.token_to_id.get(token)
                if token_id is not None:
                    token_bias[token_id] += delta
                    if np is not None:
                        active_token_mask[token_id] = True
                    else:
                        active_token_ids.add(token_id)
        if rejected_answer:
            delta = pair_weight / max(1, len(rejected_answer))
            for token in rejected_answer:
                token_id = model.embedding_model.token_to_id.get(token)
                if token_id is not None:
                    token_bias[token_id] -= delta
                    if np is not None:
                        active_token_mask[token_id] = True
                    else:
                        active_token_ids.add(token_id)

        if pair_index % state_stride != 0 or state_pair_count >= MAX_PREFERENCE_STATE_PAIRS:
            continue
        chosen_state = model._masked_decode_state(model._build_decode_state(chosen_tokens))
        rejected_state = model._masked_decode_state(model._build_decode_state(rejected_tokens))
        if len(chosen_state) != len(rejected_state):
            continue
        pair_weight_total += pair_weight
        state_pair_count += 1
        if np is not None:
            state_delta += pair_weight * (
                np.asarray(chosen_state, dtype=np.float64)
                - np.asarray(rejected_state, dtype=np.float64)
            )
        else:
            for index, (chosen_value, rejected_value) in enumerate(zip(chosen_state, rejected_state)):
                state_delta[index] += pair_weight * (chosen_value - rejected_value)

    if pair_weight_total > 0.0:
        if np is not None:
            state_delta = state_delta / pair_weight_total
            candidate_bias = _candidate_preference_bias_from_state_vector(model, state_delta)
            if candidate_bias is not None:
                token_bias[active_token_mask] = (
                    token_bias[active_token_mask] + candidate_bias[active_token_mask]
                )
        else:
            state_delta = [value / pair_weight_total for value in state_delta]
    if np is not None:
        return _standardized_preference_bias(token_bias, active_token_mask), state_pair_count
    active_mask = [index in active_token_ids for index in range(vocab_size)]
    return _standardized_preference_bias(token_bias, active_mask), state_pair_count


def _solve_weighted_prompt_readout(
    states: list[Vector],
    labels: list[int],
    weights: list[float],
    *,
    vocab_size: int,
    diagonal: object,
    state_offset: object,
    regularization: float,
) -> tuple[object, object, int]:
    if np is None or not states or not labels or not weights:
        return [], [0.0 for _ in range(vocab_size)], 0
    state_matrix = np.asarray(states, dtype=np.float64)
    label_array = np.asarray(labels, dtype=np.int64)
    weight_vector = np.asarray(weights, dtype=np.float64)
    valid_mask = (
        (label_array >= 0)
        & (label_array < vocab_size)
        & (weight_vector > 0.0)
    )
    if not np.any(valid_mask):
        return [], [0.0 for _ in range(vocab_size)], 0
    state_matrix = state_matrix[valid_mask]
    label_array = label_array[valid_mask]
    weight_vector = weight_vector[valid_mask]
    diagonal_array = np.asarray(diagonal, dtype=np.float64)
    offset_array = np.asarray(state_offset, dtype=np.float64)
    if (
        len(state_matrix.shape) != 2
        or diagonal_array.shape[0] != state_matrix.shape[1]
        or offset_array.shape[0] != state_matrix.shape[1]
    ):
        return [], [0.0 for _ in range(vocab_size)], 0
    masked_states = state_matrix * diagonal_array[None, :]
    centered_states = masked_states - offset_array[None, :]
    weighted_centered_states = weight_vector[:, None] * centered_states
    gram = centered_states.T @ weighted_centered_states
    cross = np.zeros((vocab_size, centered_states.shape[1]), dtype=np.float64)
    np.add.at(cross, label_array, weighted_centered_states)
    total_weight = float(weight_vector.sum())
    if total_weight <= 0.0:
        return [], [0.0 for _ in range(vocab_size)], 0
    bias = np.zeros(vocab_size, dtype=np.float64)
    np.add.at(bias, label_array, weight_vector)
    bias /= total_weight
    readout = ridge_regression_readout_from_moments(
        gram,
        cross,
        regularization=regularization,
    )
    return readout, bias, int(label_array.shape[0])


def fit_model_from_corpus_plan(
    plan: Iterable[CorpusPlanEntry],
    config: ReframrConfig,
    *,
    log_every: int = 0,
) -> tuple[ReframrModel, dict[str, object]]:
    entries = list(plan)
    if not entries:
        raise ValueError("Cannot fit REFRAMR without any corpus plan entries.")
    stage_seconds: dict[str, float] = {}
    stage_started = time.perf_counter()

    def finish_stage(name: str) -> None:
        nonlocal stage_started
        now = time.perf_counter()
        elapsed = round(now - stage_started, 6)
        stage_seconds[name] = elapsed
        if log_every > 0:
            print(f"[stage] {name} finished in {elapsed:.3f}s")
        stage_started = now

    seed_tokenizer = NativeTokenizer(
        merges=[],
        vocab=[],
        base_symbols=[],
        lowercase=config.lowercase,
    )
    segment_counts: Counter[str] = Counter()
    source_counts: dict[str, int] = {}
    documents: list[StreamDocument] = []
    processed = 0
    for entry in entries:
        if log_every > 0:
            print(f"[source] {entry.name} started")
        source_start = processed
        for document in iter_corpus_plan_documents([entry]):
            documents.append(document)
            processed += 1
            source_counts[document.source] = source_counts.get(document.source, 0) + 1
            for text_part, part_weight in _weighted_text_parts_for_statistics(
                document.text,
                document.weight,
            ):
                for segment in seed_tokenizer.pretokenize(text_part):
                    segment_counts[segment] += part_weight
            if document.preference_rejected_text:
                rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
                for text_part, part_weight in _weighted_text_parts_for_statistics(
                    document.preference_rejected_text,
                    rejected_weight,
                ):
                    for segment in seed_tokenizer.pretokenize(text_part):
                        segment_counts[segment] += part_weight
            _log_progress("tokenizer", processed, log_every)
        if log_every > 0:
            print(f"[source] {entry.name} accepted {processed - source_start} documents")
    if processed == 0:
        raise ValueError("Corpus plan did not yield any usable documents after filtering.")
    finish_stage("stream_and_segment")
    tokenizer = NativeTokenizer.train_from_segment_counts(
        segment_counts,
        vocab_size=config.tokenizer_vocab_size,
        min_pair_frequency=config.tokenizer_min_pair_frequency,
        lowercase=config.lowercase,
    )
    finish_stage("tokenizer_fit")

    token_counts: Counter[str] = Counter()
    raw_tokenized_documents: list[list[str]] = []
    raw_rejected_tokenized_documents: list[list[str]] = []
    processed = 0
    for document in documents:
        processed += 1
        tokens = tokenizer.encode(document.text)
        raw_tokenized_documents.append(tokens)
        for token in tokens:
            if token in tokenizer.special_tokens:
                token_counts[token] += document.weight
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            tokens,
            tokenizer,
            document.weight,
        ):
            for token in token_sequence:
                token_counts[token] += sequence_weight
        rejected_tokens = (
            tokenizer.encode(document.preference_rejected_text)
            if document.preference_rejected_text
            else []
        )
        raw_rejected_tokenized_documents.append(rejected_tokens)
        rejected_weight = document.weight * PREFERENCE_REJECTED_TOKENIZER_WEIGHT
        for token in rejected_tokens:
            if token in tokenizer.special_tokens:
                token_counts[token] += rejected_weight
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            rejected_tokens,
            tokenizer,
            rejected_weight,
        ):
            for token in token_sequence:
                token_counts[token] += sequence_weight
        _log_progress("vocab", processed, log_every)
    token_to_id, id_to_token = build_vocabulary_from_counts(
        token_counts,
        min_frequency=config.min_frequency,
        max_vocab=config.max_vocab,
    )
    if not id_to_token:
        raise ValueError("Streaming recompute could not derive an embedding vocabulary.")
    finish_stage("vocabulary")

    cooccurrence = StreamingCooccurrenceAccumulator(token_to_id, config.window_size)
    tokenized_documents: list[list[str]] = []
    preference_token_pairs: list[tuple[list[str], list[str], float]] = []
    processed = 0
    for document, raw_tokens, raw_rejected_tokens in zip(
        documents,
        raw_tokenized_documents,
        raw_rejected_tokenized_documents,
    ):
        processed += 1
        tokens = [token for token in raw_tokens if token in token_to_id]
        tokenized_documents.append(tokens)
        rejected_tokens = [token for token in raw_rejected_tokens if token in token_to_id]
        if len(tokens) > 1 and len(rejected_tokens) > 1:
            preference_token_pairs.append((tokens, rejected_tokens, document.weight))
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            tokens,
            tokenizer,
            document.weight,
        ):
            if len(token_sequence) > 1:
                cooccurrence.update_tokens(token_sequence, weight=sequence_weight)
        _log_progress("cooccurrence", processed, log_every)
    finish_stage("cooccurrence")
    if np is not None:
        embedding_model = fit_randomized_ppmi_embedding_from_counts(
            id_to_token,
            cooccurrence.rows,
            embedding_dim=config.embedding_dim,
        )
    else:
        embedding_model = fit_ppmi_embedding_from_cooccurrence(
            id_to_token,
            cooccurrence.to_sparse(),
            embedding_dim=config.embedding_dim,
        )
    finish_stage("embedding")

    model = ReframrModel(config)
    model.tokenizer = tokenizer
    model.embedding_model = embedding_model
    model.memory_units = [
        AnalyticalMemoryUnit(config.state_dim, timescale)
        for timescale in config.timescales
    ]
    model.trace_token_weights = model._derive_trace_token_weights_from_counts(token_counts)

    feature_count = len(model._zero_combined_state())
    if np is not None:
        feature_second_moment = np.zeros(feature_count, dtype=np.float64)
        raw_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
    else:
        feature_second_moment = zeros_vector(feature_count)
        raw_cross = zeros(len(embedding_model.id_to_token), feature_count)
    example_weight_total = 0.0
    has_answer_targets = any(_answer_boundary(tokens) is not None for tokens in tokenized_documents)
    if config.max_training_examples is None:
        answer_reservoir_capacity = None
        general_reservoir_capacity = None
    elif config.max_training_examples <= 0:
        answer_reservoir_capacity = 0
        general_reservoir_capacity = 0
    elif has_answer_targets:
        answer_reservoir_capacity = max(1, int(config.max_training_examples * 0.75))
        general_reservoir_capacity = max(0, config.max_training_examples - answer_reservoir_capacity)
    else:
        answer_reservoir_capacity = 0
        general_reservoir_capacity = config.max_training_examples
    answer_sequence_capacity = MAX_ANSWER_SEQUENCE_EXAMPLES if has_answer_targets else 0
    answer_reservoir = StateReservoir(answer_reservoir_capacity, seed=17)
    general_reservoir = StateReservoir(general_reservoir_capacity, seed=13)
    answer_intent_reservoir = StateReservoir(answer_reservoir_capacity, seed=29)
    answer_start_reservoir = StateReservoir(answer_reservoir_capacity, seed=37)
    answer_sequence_reservoir = SequenceReservoir(answer_sequence_capacity, seed=41)
    moment_reservoir = StateReservoir(
        config.max_training_examples if config.max_training_examples is not None else None,
        seed=31,
    )
    transitions = TransitionAccumulator(
        max_contexts_per_order=config.max_transition_contexts_per_order,
        max_next_tokens=config.max_transition_next_tokens,
    )
    if np is not None:
        target_label_mass = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
    else:
        target_label_mass = zeros_vector(len(embedding_model.id_to_token))
    for document, tokens in zip(documents, tokenized_documents):
        answer_index = _answer_boundary(tokens)
        for index in range(len(tokens) - 1):
            next_token = tokens[index + 1]
            if tokenizer is not None and next_token in tokenizer.special_tokens:
                continue
            next_token_id = embedding_model.token_to_id.get(next_token, -1)
            if next_token_id < 0:
                continue
            label_weight = _readout_weight_for_target(answer_index, index + 1, document.weight)
            if label_weight > 0.0:
                target_label_mass[next_token_id] += label_weight
    if np is not None:
        positive_label_mass = target_label_mass[target_label_mass > 0.0]
        reference_label_mass = (
            float(np.median(positive_label_mass))
            if positive_label_mass.size
            else 1.0
        )
        target_balance = np.ones(len(embedding_model.id_to_token), dtype=np.float64)
        np.divide(
            reference_label_mass,
            np.maximum(target_label_mass, 1e-12),
            out=target_balance,
            where=target_label_mass > 0.0,
        )
        target_balance = np.clip(np.sqrt(target_balance), 0.25, 4.0)
    else:
        positive_label_mass = [value for value in target_label_mass if value > 0.0]
        if positive_label_mass:
            sorted_mass = sorted(positive_label_mass)
            reference_label_mass = sorted_mass[len(sorted_mass) // 2]
        else:
            reference_label_mass = 1.0
        target_balance = [
            max(0.25, min(4.0, (reference_label_mass / max(value, 1e-12)) ** 0.5))
            if value > 0.0
            else 1.0
            for value in target_label_mass
        ]
    processed = 0
    embedding_array = (
        np.asarray(embedding_model.embeddings, dtype=RUNTIME_ARRAY_DTYPE)
        if np is not None
        else None
    )
    trace_embedding_array = (
        model._build_trace_embedding_table_array(embedding_array)
        if np is not None and embedding_array is not None
        else None
    )
    if np is not None:
        trace_decay = np.asarray(
            [1.0 / (1.0 + unit.timescale) for unit in model.memory_units],
            dtype=RUNTIME_ARRAY_DTYPE,
        )
        trace_gain = 1.0 - trace_decay
        transition_stack = np.asarray(
            [unit.transition for unit in model.memory_units],
            dtype=RUNTIME_ARRAY_DTYPE,
        )
        input_projection_stack = np.asarray(
            [unit.input_projection for unit in model.memory_units],
            dtype=RUNTIME_ARRAY_DTYPE,
        )
        drive_indices = np.arange(config.state_dim, dtype=np.int64)
        drive_primary = drive_indices % config.embedding_dim
        drive_secondary = (3 * drive_indices + 1) % config.embedding_dim
        drive_tertiary = (5 * drive_indices + 2) % config.embedding_dim
    else:
        trace_decay = None
        trace_gain = None
        transition_stack = None
        input_projection_stack = None
        drive_primary = None
        drive_secondary = None
        drive_tertiary = None
    for document, tokens in zip(documents, tokenized_documents):
        processed += 1
        if len(tokens) < 2:
            _log_progress("state", processed, log_every)
            continue

        answer_index = _answer_boundary(tokens)
        for token_sequence, sequence_weight in _weighted_token_sequences_for_statistics(
            tokens,
            tokenizer,
            document.weight,
        ):
            if len(token_sequence) > 1:
                transitions.update_tokens(token_sequence, weight=sequence_weight)
        if np is not None:
            hidden_state_matrix = np.zeros((len(config.timescales), config.state_dim), dtype=RUNTIME_ARRAY_DTYPE)
            context_trace_matrix = np.zeros((len(config.timescales), config.embedding_dim), dtype=RUNTIME_ARRAY_DTYPE)
        else:
            hidden_states = [zeros_vector(config.state_dim) for _ in config.timescales]
            context_traces = [zeros_vector(config.embedding_dim) for _ in config.timescales]
        answer_anchor_state = None
        for index in range(len(tokens) - 1):
            token = tokens[index]
            token_id = embedding_model.token_to_id.get(token, -1)
            if (
                np is not None
                and embedding_array is not None
                and trace_decay is not None
                and trace_gain is not None
                and transition_stack is not None
                and input_projection_stack is not None
                and drive_primary is not None
                and drive_secondary is not None
                and drive_tertiary is not None
                and trace_embedding_array is not None
                and token_id >= 0
            ):
                embedding = embedding_array[token_id]
                trace_embedding = trace_embedding_array[token_id]
                drive = (
                    embedding[drive_primary]
                    + (0.5 * embedding[drive_secondary])
                    - (0.25 * embedding[drive_tertiary])
                )
                hidden_state_matrix = (
                    (transition_stack @ hidden_state_matrix[:, :, None])[:, :, 0]
                    + (input_projection_stack * drive[None, :])
                )
                context_trace_matrix = (
                    context_trace_matrix + (trace_gain[:, None] * trace_embedding[None, :])
                )
            else:
                hidden_states, context_traces, combined_state = model._step_hidden_states(
                    hidden_states,
                    context_traces,
                    token,
                )
            if token == "<answer>":
                if np is not None:
                    answer_anchor_state = np.concatenate(
                        (hidden_state_matrix, context_trace_matrix),
                        axis=1,
                    ).reshape(-1).copy()
                else:
                    answer_anchor_state = combined_state.copy() if hasattr(combined_state, "copy") else combined_state[:]
            next_token = tokens[index + 1]
            if next_token in tokenizer.special_tokens:
                continue
            next_token_id = embedding_model.token_to_id.get(next_token, -1)
            if next_token_id < 0:
                continue
            raw_readout_weight = _readout_weight_for_target(answer_index, index + 1, document.weight)
            readout_weight = raw_readout_weight * float(target_balance[next_token_id])
            if readout_weight <= 0.0:
                continue
            moment_slot = moment_reservoir.reserve_slot(weight=readout_weight)
            is_answer_target = answer_index is not None and index + 1 > answer_index
            target_reservoir = answer_reservoir if is_answer_target else general_reservoir
            memory_weight = readout_weight * float(target_balance[next_token_id])
            answer_token_offset = (
                index - answer_index
                if is_answer_target and answer_index is not None
                else None
            )
            intent_slot = (
                answer_intent_reservoir.reserve_slot(weight=memory_weight)
                if is_answer_target and answer_anchor_state is not None
                else None
            )
            answer_start_weight = (
                raw_readout_weight * (ANSWER_START_DECAY ** answer_token_offset)
                if (
                    answer_token_offset is not None
                    and answer_token_offset < ANSWER_START_TOKEN_WINDOW
                )
                else 0.0
            )
            answer_start_slot = (
                answer_start_reservoir.reserve_slot(weight=answer_start_weight)
                if answer_start_weight > 0.0 and answer_anchor_state is not None
                else None
            )
            if np is not None:
                reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight)
                if moment_slot is not None or reservoir_slot is not None:
                    combined_state = np.concatenate(
                        (hidden_state_matrix, context_trace_matrix),
                        axis=1,
                    ).reshape(-1).copy()
                if moment_slot is not None:
                    moment_reservoir.store_reserved(
                        moment_slot,
                        combined_state,
                        next_token_id,
                        example_weight=readout_weight,
                    )
                if reservoir_slot is not None:
                    target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id)
                if intent_slot is not None:
                    answer_intent_reservoir.store_reserved(
                        intent_slot,
                        answer_anchor_state,
                        next_token_id,
                        example_weight=memory_weight,
                    )
                if answer_start_slot is not None:
                    answer_start_reservoir.store_reserved(
                        answer_start_slot,
                        answer_anchor_state,
                        next_token_id,
                        example_weight=answer_start_weight * float(target_balance[next_token_id]),
                    )
            else:
                reservoir_slot = target_reservoir.reserve_slot(weight=memory_weight)
                if moment_slot is None and reservoir_slot is None:
                    continue
                if moment_slot is not None:
                    moment_reservoir.store_reserved(
                        moment_slot,
                        combined_state,
                        next_token_id,
                        example_weight=readout_weight,
                    )
                if reservoir_slot is not None:
                    target_reservoir.store_reserved(reservoir_slot, combined_state, next_token_id)
                if intent_slot is not None:
                    answer_intent_reservoir.store_reserved(
                        intent_slot,
                        answer_anchor_state,
                        next_token_id,
                        example_weight=memory_weight,
                    )
                if answer_start_slot is not None:
                    answer_start_reservoir.store_reserved(
                        answer_start_slot,
                        answer_anchor_state,
                        next_token_id,
                        example_weight=answer_start_weight * target_balance[next_token_id],
                    )
        if answer_anchor_state is not None and answer_index is not None:
            prompt_token_ids = [
                embedding_model.token_to_id[token]
                for token in tokens[:answer_index]
                if token not in tokenizer.special_tokens
                and token in embedding_model.token_to_id
            ]
            answer_token_ids = [
                embedding_model.token_to_id[token]
                for token in tokens[answer_index + 1 :]
                if token not in tokenizer.special_tokens
                and token in embedding_model.token_to_id
            ]
            answer_sequence_reservoir.consider(
                answer_anchor_state,
                prompt_token_ids,
                answer_token_ids,
                weight=document.weight * ANSWER_READOUT_WEIGHT,
            )
        _log_progress("state", processed, log_every)

    moment_states = moment_reservoir.states
    moment_labels = moment_reservoir.labels
    moment_weights = moment_reservoir.weights
    example_weight_total = sum(moment_weights)
    if np is not None and moment_states:
        state_matrix = np.asarray(moment_states, dtype=np.float64)
        weight_vector = np.asarray(moment_weights, dtype=np.float64)
        weighted_states = weight_vector[:, None] * state_matrix
        feature_second_moment += (weighted_states * state_matrix).sum(axis=0)
        np.add.at(raw_cross, moment_labels, weighted_states)
    elif moment_states:
        for state, label_id, readout_weight in zip(moment_states, moment_labels, moment_weights):
            for feature, value in enumerate(state):
                weighted_value = readout_weight * value
                feature_second_moment[feature] += weighted_value * value
                raw_cross[label_id][feature] += weighted_value

    if example_weight_total <= 0.0:
        raise ValueError("Streaming recompute did not collect any next-token training examples.")

    if np is not None:
        feature_energy = (feature_second_moment / example_weight_total).tolist()
    else:
        feature_energy = [
            feature_second_moment[index] / example_weight_total
            for index in range(feature_count)
        ]
    ternary_scale, ternary_mask = derive_ternary_mask_from_feature_energy(feature_energy)
    if np is not None:
        diagonal = np.asarray([ternary_scale * value for value in ternary_mask], dtype=np.float64)
        masked_feature_second_moment = feature_second_moment * diagonal * diagonal
        masked_cross = raw_cross * diagonal[None, :]
    else:
        diagonal = [ternary_scale * value for value in ternary_mask]
        masked_feature_second_moment = [
            feature_second_moment[index] * diagonal[index] * diagonal[index]
            for index in range(feature_count)
        ]
        masked_cross = [
            [
                raw_cross[row][col] * diagonal[col]
                for col in range(feature_count)
            ]
            for row in range(len(raw_cross))
        ]
    readout_solver = "diagonal"
    state_offset_values: object
    readout_bias_values: object
    if (
        np is not None
        and moment_states
        and feature_count <= FULL_READOUT_FEATURE_LIMIT
        and len(moment_states) <= FULL_READOUT_EXAMPLE_LIMIT
    ):
        state_matrix = np.asarray(moment_states, dtype=np.float64)
        weight_vector = np.asarray(moment_weights, dtype=np.float64)
        label_array = np.asarray(moment_labels, dtype=np.int64)
        masked_states = state_matrix * diagonal[None, :]
        total_weight = float(weight_vector.sum())
        if total_weight <= 0.0:
            total_weight = 1.0
        state_offset_values = (weight_vector[:, None] * masked_states).sum(axis=0) / total_weight
        centered_states = masked_states - state_offset_values[None, :]
        weighted_centered_states = weight_vector[:, None] * centered_states
        gram = centered_states.T @ weighted_centered_states
        full_cross = np.zeros((len(embedding_model.id_to_token), feature_count), dtype=np.float64)
        np.add.at(full_cross, label_array, weighted_centered_states)
        readout_bias_values = np.zeros(len(embedding_model.id_to_token), dtype=np.float64)
        np.add.at(readout_bias_values, label_array, weight_vector)
        readout_bias_values /= total_weight
        readout_weights = ridge_regression_readout_from_moments(
            gram,
            full_cross,
            regularization=config.regularization,
        )
        readout_solver = "full"
    else:
        state_offset_values = (
            np.zeros(feature_count, dtype=np.float64)
            if np is not None
            else [0.0 for _ in range(feature_count)]
        )
        if np is not None:
            label_total = max(float(target_label_mass.sum()), 1.0)
            readout_bias_values = target_label_mass / label_total
        else:
            label_total = max(sum(target_label_mass), 1.0)
            readout_bias_values = [value / label_total for value in target_label_mass]
        readout_weights = ridge_regression_readout_from_diagonal_moments(
            masked_feature_second_moment,
            masked_cross,
            regularization=config.regularization,
        )
    finish_stage("state_and_readout")

    model.ternary_scale = ternary_scale
    model.ternary_mask = ternary_mask
    model.readout_weights = readout_weights
    model.state_offset = (
        state_offset_values.tolist()
        if hasattr(state_offset_values, "tolist")
        else list(state_offset_values)
    )
    model.readout_bias = (
        readout_bias_values.tolist()
        if hasattr(readout_bias_values, "tolist")
        else list(readout_bias_values)
    )
    model.preference_bias, preference_state_pairs = _derive_preference_bias_from_pairs(
        model,
        preference_token_pairs,
        tokenizer,
    )
    finish_stage("preference")
    reservoir_states = answer_reservoir.states + general_reservoir.states
    reservoir_labels = answer_reservoir.labels + general_reservoir.labels
    answer_intent_states = answer_intent_reservoir.states
    answer_intent_labels = answer_intent_reservoir.labels
    answer_start_states = answer_start_reservoir.states
    answer_start_labels = answer_start_reservoir.labels
    answer_sequence_states = answer_sequence_reservoir.keys
    answer_sequence_prompt_rows = answer_sequence_reservoir.prompt_rows
    answer_sequence_rows = answer_sequence_reservoir.token_rows
    prompt_answer_weights, prompt_answer_bias, prompt_answer_readout_examples = (
        _solve_weighted_prompt_readout(
            answer_intent_states,
            answer_intent_labels,
            answer_intent_reservoir.weights,
            vocab_size=len(embedding_model.id_to_token),
            diagonal=diagonal,
            state_offset=state_offset_values,
            regularization=config.regularization,
        )
    )
    (
        prompt_answer_start_weights,
        prompt_answer_start_bias,
        prompt_answer_start_readout_examples,
    ) = _solve_weighted_prompt_readout(
        answer_start_states,
        answer_start_labels,
        answer_start_reservoir.weights,
        vocab_size=len(embedding_model.id_to_token),
        diagonal=diagonal,
        state_offset=state_offset_values,
        regularization=config.regularization,
    )
    model.prompt_answer_weights = prompt_answer_weights
    model.prompt_answer_bias = (
        prompt_answer_bias.tolist()
        if hasattr(prompt_answer_bias, "tolist")
        else list(prompt_answer_bias)
    )
    model.prompt_answer_start_weights = prompt_answer_start_weights
    model.prompt_answer_start_bias = (
        prompt_answer_start_bias.tolist()
        if hasattr(prompt_answer_start_bias, "tolist")
        else list(prompt_answer_start_bias)
    )
    if np is not None and reservoir_states:
        reservoir_array = np.asarray(reservoir_states, dtype=RUNTIME_ARRAY_DTYPE)
        mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
        offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
        associative_array = ((reservoir_array * mask_array[None, :]) - offset_array[None, :]).astype(
            RUNTIME_ARRAY_DTYPE,
            copy=False,
        )
        model.associative_keys = associative_array
        model.associative_key_norms = np.linalg.norm(associative_array, axis=1).tolist()
    else:
        offset_vector = model.state_offset
        model.associative_keys = [
            [
                value - offset_vector[index]
                for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
            ]
            for state in reservoir_states
        ]
        model.associative_key_norms = [norm(state) for state in model.associative_keys]
    model.associative_values = reservoir_labels[:]
    if np is not None and answer_intent_states:
        answer_intent_array = np.asarray(answer_intent_states, dtype=RUNTIME_ARRAY_DTYPE)
        mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
        offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
        answer_array = ((answer_intent_array * mask_array[None, :]) - offset_array[None, :]).astype(
            RUNTIME_ARRAY_DTYPE,
            copy=False,
        )
        model.answer_keys = answer_array
        model.answer_key_norms = np.linalg.norm(answer_array, axis=1).tolist()
    else:
        offset_vector = model.state_offset
        model.answer_keys = [
            [
                value - offset_vector[index]
                for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
            ]
            for state in answer_intent_states
        ]
        model.answer_key_norms = [norm(state) for state in model.answer_keys]
    model.answer_values = answer_intent_labels[:]
    if np is not None and answer_start_states:
        answer_start_array = np.asarray(answer_start_states, dtype=RUNTIME_ARRAY_DTYPE)
        mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
        offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
        start_array = ((answer_start_array * mask_array[None, :]) - offset_array[None, :]).astype(
            RUNTIME_ARRAY_DTYPE,
            copy=False,
        )
        model.answer_start_keys = start_array
        model.answer_start_key_norms = np.linalg.norm(start_array, axis=1).tolist()
    else:
        offset_vector = model.state_offset
        model.answer_start_keys = [
            [
                value - offset_vector[index]
                for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
            ]
            for state in answer_start_states
        ]
        model.answer_start_key_norms = [norm(state) for state in model.answer_start_keys]
    model.answer_start_values = answer_start_labels[:]
    if np is not None and answer_sequence_states:
        answer_sequence_array = np.asarray(answer_sequence_states, dtype=RUNTIME_ARRAY_DTYPE)
        mask_array = np.asarray(ternary_mask, dtype=RUNTIME_ARRAY_DTYPE) * ternary_scale
        offset_array = np.asarray(model.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
        sequence_array = ((answer_sequence_array * mask_array[None, :]) - offset_array[None, :]).astype(
            RUNTIME_ARRAY_DTYPE,
            copy=False,
        )
        model.answer_sequence_keys = sequence_array
        model.answer_sequence_key_norms = np.linalg.norm(sequence_array, axis=1).tolist()
    else:
        offset_vector = model.state_offset
        model.answer_sequence_keys = [
            [
                value - offset_vector[index]
                for index, value in enumerate(apply_ternary_mask(state, ternary_mask, ternary_scale))
            ]
            for state in answer_sequence_states
        ]
        model.answer_sequence_key_norms = [norm(state) for state in model.answer_sequence_keys]
    if np is not None:
        padded_answer_sequences = np.full(
            (len(answer_sequence_rows), MAX_ANSWER_SEQUENCE_TOKENS),
            -1,
            dtype=np.int32,
        )
        for row_index, row in enumerate(answer_sequence_rows):
            row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS)
            if row_width > 0:
                padded_answer_sequences[row_index, :row_width] = row[:row_width]
        padded_answer_sequence_prompts = np.full(
            (len(answer_sequence_prompt_rows), MAX_ANSWER_SEQUENCE_TOKENS),
            -1,
            dtype=np.int32,
        )
        for row_index, row in enumerate(answer_sequence_prompt_rows):
            row_width = min(len(row), MAX_ANSWER_SEQUENCE_TOKENS)
            if row_width > 0:
                padded_answer_sequence_prompts[row_index, :row_width] = row[:row_width]
    else:
        padded_answer_sequences = [
            row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))]
            for row in answer_sequence_rows
        ]
        padded_answer_sequence_prompts = [
            row + [-1 for _ in range(MAX_ANSWER_SEQUENCE_TOKENS - len(row))]
            for row in answer_sequence_prompt_rows
        ]
    model.answer_sequence_prompt_tokens = padded_answer_sequence_prompts
    model.answer_sequence_tokens = padded_answer_sequences
    model.transition_tables = transitions.finalize(
        max_contexts_per_order=config.max_transition_contexts_per_order,
        max_next_tokens=config.max_transition_next_tokens,
    )
    finish_stage("model_finalize")

    payload = {
        "streaming": True,
        "documents_processed": processed,
        "source_counts": source_counts,
        "embedding_vocab_size": len(embedding_model.id_to_token),
        "tokenizer_vocab_size": tokenizer.vocab_size,
        "examples_processed": int(round(example_weight_total)),
        "associative_examples": len(model.associative_keys),
        "answer_associative_examples": len(answer_reservoir.states),
        "general_associative_examples": len(general_reservoir.states),
        "answer_intent_examples": len(model.answer_keys),
        "answer_start_examples": len(model.answer_start_keys),
        "answer_sequence_examples": len(model.answer_sequence_keys),
        "prompt_answer_readout_examples": prompt_answer_readout_examples,
        "prompt_answer_start_readout_examples": prompt_answer_start_readout_examples,
        "stage_seconds": stage_seconds,
        "target_balance_reference": round(float(reference_label_mass), 6),
        "readout_solver": readout_solver,
        "preference_pairs": len(preference_token_pairs),
        "preference_state_pairs": preference_state_pairs,
    }
    return model, payload
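
A minimal usage sketch for the streaming fit above. The import locations for `ReframrConfig` and `CorpusPlanEntry`, the `CorpusPlanEntry` keyword arguments, and the assumption that `ReframrConfig()` is default-constructible are all inferred from the fields this module reads; they are not confirmed by this diff.

```python
# Hedged sketch, not part of the committed file.
from reframr.config import ReframrConfig      # assumed module path
from reframr.corpus import CorpusPlanEntry    # assumed module path
from reframr.streaming import fit_model_from_corpus_plan

plan = [
    CorpusPlanEntry(                  # hypothetical constructor call
        name="inline-demo",
        source="inline",              # "inline", "hf", or "file"
        records=[{"text": "Reframr keeps a persistent recurrent state."}],
        text_field="text",
        weight=1.0,
        limit=0,                      # entry.limit > 0 caps accepted documents
    ),
]

# fit_model_from_corpus_plan streams the plan, fits the tokenizer, embeddings,
# memory readouts, and reservoirs in stages, and returns a report payload.
model, report = fit_model_from_corpus_plan(plan, ReframrConfig(), log_every=100)
print(report["documents_processed"], report["readout_solver"])
```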
reframr/ternary.py
ADDED
@@ -0,0 +1,63 @@
import math

from .linalg import Vector, mean


def quantize_vector_absmean(
    values: Vector,
    *,
    threshold: float = 0.5,
) -> tuple[float, list[int]]:
    if not values:
        return 1.0, []

    scale = mean([abs(value) for value in values])
    if scale == 0.0:
        return 1.0, [0 for _ in values]

    quantized: list[int] = []
    for value in values:
        normalized = value / scale
        if normalized >= threshold:
            quantized.append(1)
        elif normalized <= -threshold:
            quantized.append(-1)
        else:
            quantized.append(0)
    return scale, quantized


def derive_ternary_mask_from_states(states: list[Vector]) -> tuple[float, list[int]]:
    if not states:
        return 1.0, []
    feature_count = len(states[0])
    feature_energy = [
        mean([state[feature] * state[feature] for state in states])
        for feature in range(feature_count)
    ]
    return derive_ternary_mask_from_feature_energy(feature_energy)


def derive_ternary_mask_from_feature_energy(
    feature_energy: Vector,
    *,
    threshold: float = 0.02,
) -> tuple[float, list[int]]:
    if not feature_energy:
        return 1.0, []

    rms_values = [math.sqrt(max(value, 0.0)) for value in feature_energy]
    scale = mean(rms_values)
    if scale == 0.0:
        return 1.0, [0 for _ in feature_energy]

    mask = [1 if value >= threshold * scale else 0 for value in rms_values]
    if not any(mask):
        mask = [1 for _ in feature_energy]
    return 1.0, mask


def apply_ternary_mask(values: Vector, mask: list[int], scale: float) -> Vector:
    if not mask:
        return values[:]
    return [scale * mask[index] * values[index] for index in range(len(values))]
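
A worked check of the absmean quantizer and mask application defined above. The module path matches the file header; everything else follows directly from the code, with the arithmetic spelled out in comments.

```python
from reframr.ternary import apply_ternary_mask, quantize_vector_absmean

# Absmean quantization: scale = mean(|v|) = (0.9 + 0.05 + 1.2 + 0.3) / 4 = 0.6125.
# Normalized values are roughly [1.47, -0.08, -1.96, 0.49]; with the default
# threshold of 0.5 they snap to {+1, 0, -1, 0}.
scale, quantized = quantize_vector_absmean([0.9, -0.05, -1.2, 0.3])
print(scale, quantized)  # 0.6125 [1, 0, -1, 0]

# apply_ternary_mask zeroes out masked features and rescales the rest.
print(apply_ternary_mask([2.0, 3.0, 4.0], [1, 0, 1], 0.5))  # [1.0, 0.0, 2.0]
```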
reframr/text_quality.py
ADDED
@@ -0,0 +1,98 @@
import re
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
REFRAMR_NAME_PATTERN = re.compile(r"\breframr\b", re.IGNORECASE)
|
| 5 |
+
LINE_ROLE_PREFIX_PATTERN = re.compile(
|
| 6 |
+
r"(?im)^\s*(?:user|assistant|human|system|bot|model|gpt)\s*:\s*"
|
| 7 |
+
)
|
| 8 |
+
STRUCTURAL_ROLE_PREFIX_PATTERN = re.compile(
|
| 9 |
+
r"(?i)(<(?:reason|answer)>\s+)(?:user|assistant|human|system|bot|model|gpt)\s*:\s*"
|
| 10 |
+
)
|
| 11 |
+
SYSTEM_SCAFFOLD_LINE_PATTERN = re.compile(
|
    r"(?i)^\s*(?:"
    r"you\s+are\s+(?:an?\s+)?(?:helpful\s+)?(?:ai\s+)?assistant\b.*|"
    r"your\s+role\s+as\s+an\s+assistant\s+involves\b.*|"
    r"you\s+will\s+be\s+given\s+a\s+task\b.*|"
    r"your\s+goal\s+is\s+to\s+complete\s+the\s+task\b.*|"
    r"you\s+must\s+generate\s+a\s+detailed\s+and\s+long\s+answer\b.*|"
    r"please\s+structure\s+your\s+response\s+into\s+two\s+main\s+sections\b.*|"
    r"in\s+the\s+thought\s+section\b.*|"
    r"in\s+the\s+solution\s+section\b.*|"
    r"now,\s*try\s+to\s+solve\s+the\s+following\s+question\b.*|"
    r"while\s+answering\s+think\s+step\s*[- ]?\s*by\s*[- ]?\s*step\b.*|"
    r"think\s+like\s+you\s+are\s+answering\b.*"
    r")\s*$"
)
OPEN_SOLUTION_PATTERN = re.compile(
    r"(?is)<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>"
)
OPEN_THOUGHT_PATTERN = re.compile(
    r"(?is)<\|begin_of_thought\|>.*?<\|end_of_thought\|>"
)
OPEN_TAG_PATTERN = re.compile(r"(?is)<\|[^>]+?\|>")
LEADING_ASSISTANT_FILLER_PATTERN = re.compile(
    r"(?is)^\s*(?:sure(?:\s+thing)?|certainly|absolutely|of\s+course|yes)\s*[!,.:-]*\s+"
)
MOJIBAKE_MARKERS = ("â", "Ã", "Â", "â", "Ã", "Â")


def canonicalize_reframr_name(text: str) -> str:
    return REFRAMR_NAME_PATTERN.sub("Reframr", text)


def repair_common_mojibake(text: str) -> str:
    repaired = text
    for _ in range(3):
        if not any(marker in repaired for marker in MOJIBAKE_MARKERS):
            break
        original_markers = sum(repaired.count(marker) for marker in MOJIBAKE_MARKERS)
        best = repaired
        best_markers = original_markers
        for encoding in ("cp1252", "latin1"):
            try:
                candidate = repaired.encode(encoding).decode("utf-8")
            except UnicodeError:
                continue
            candidate_markers = sum(candidate.count(marker) for marker in MOJIBAKE_MARKERS)
            if candidate_markers < best_markers:
                best = candidate
                best_markers = candidate_markers
        if best == repaired:
            break
        repaired = best
    return repaired


def strip_role_prefixes(text: str) -> str:
    cleaned = STRUCTURAL_ROLE_PREFIX_PATTERN.sub(r"\1", text)
    return LINE_ROLE_PREFIX_PATTERN.sub("", cleaned).strip()


def strip_instruction_scaffold(text: str) -> str:
    lines = []
    for line in text.splitlines():
        if SYSTEM_SCAFFOLD_LINE_PATTERN.match(line):
            continue
        lines.append(line)
    return "\n".join(lines).strip()


def clean_training_text(text: str) -> str:
    repaired = repair_common_mojibake(text)
    return strip_role_prefixes(canonicalize_reframr_name(repaired)).strip()


def clean_context_text(text: str) -> str:
    return strip_instruction_scaffold(clean_training_text(text))


def clean_answer_text(text: str) -> str:
    cleaned = clean_training_text(text)
    solution_match = OPEN_SOLUTION_PATTERN.search(cleaned)
    if solution_match:
        cleaned = solution_match.group(1)
    else:
        cleaned = OPEN_THOUGHT_PATTERN.sub("", cleaned)
    cleaned = OPEN_TAG_PATTERN.sub("", cleaned)
    cleaned = LEADING_ASSISTANT_FILLER_PATTERN.sub("", cleaned)
    return cleaned.strip()
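For reference, the cleaners above compose as `clean_training_text` → `clean_context_text` / `clean_answer_text`. A minimal usage sketch (the input string is invented for illustration; the functions are the ones defined in `reframr/text_quality.py` above):

```python
from reframr.text_quality import clean_answer_text

raw = (
    "<|begin_of_thought|>scratch reasoning<|end_of_thought|>"
    "<|begin_of_solution|>Sure, the answer is 42.<|end_of_solution|>"
)
# Given the patterns above: the solution block is extracted, the thought
# block and control tags are dropped, and the leading filler "Sure," is
# stripped by LEADING_ASSISTANT_FILLER_PATTERN.
print(clean_answer_text(raw))  # expected: "the answer is 42."
```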
reframr/tokenizer.py
ADDED
@@ -0,0 +1,665 @@
import re
import unicodedata
from collections import Counter
from collections.abc import Mapping
from dataclasses import dataclass, field
from string import ascii_letters, digits

from .reasoning import REASONING_CONTROL_TOKENS, TOKENIZER_NAME

PRETOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
BYTE_FALLBACK_PATTERN = re.compile(r"<byte:([0-9A-F]{2})>")
DEFAULT_FALLBACK_CHARACTERS = (
    ascii_letters
    + digits
    + "'-_/.:,;!?()[]{}@#$%&*+="
    + "’ʼ‘“”—–…"
)
MAX_TOKENIZER_VOCAB_SIZE = 65536
MAX_SEGMENT_CACHE_SIZE = 200_000
MAX_TRAINED_PAIR_MERGES = 384


def _is_word_character(character: str) -> bool:
    category = unicodedata.category(character)
    return character == "_" or category[0] in {"L", "N"} or category == "Mn"


def _is_variation_selector(character: str) -> bool:
    return "VARIATION SELECTOR" in unicodedata.name(character, "")


def _is_zero_width_joiner(character: str) -> bool:
    return unicodedata.name(character, "") == "ZERO WIDTH JOINER"


def _is_emoji_modifier(character: str) -> bool:
    return "EMOJI MODIFIER" in unicodedata.name(character, "")


def _is_emoji_base_character(character: str) -> bool:
    name = unicodedata.name(character, "")
    category = unicodedata.category(character)
    return (
        "EMOJI" in name
        or "REGIONAL INDICATOR SYMBOL" in name
        or (category in {"So", "Sk"} and ord(character) >= 0x2100)
    )


def _is_emoji_continuation_character(character: str) -> bool:
    category = unicodedata.category(character)
    name = unicodedata.name(character, "")
    return (
        _is_variation_selector(character)
        or _is_zero_width_joiner(character)
        or _is_emoji_modifier(character)
        or category in {"Mn", "Me"}
        or name.startswith("TAG ")
    )


def _consume_emoji_cluster(text: str, start: int) -> int:
    if start >= len(text) or not _is_emoji_base_character(text[start]):
        return start

    index = start + 1
    if "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[start], ""):
        if index < len(text) and "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[index], ""):
            return index + 1
        return index

    while index < len(text):
        if _is_emoji_continuation_character(text[index]):
            index += 1
            continue
        if _is_zero_width_joiner(text[index - 1]) and _is_emoji_base_character(text[index]):
            index += 1
            continue
        break
    return index


def _byte_token(value: int) -> str:
    return f"<byte:{value:02X}>"


def _byte_value(piece: str) -> int | None:
    match = BYTE_FALLBACK_PATTERN.fullmatch(piece)
    if match is None:
        return None
    return int(match.group(1), 16)


def _is_punctuation_piece(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character).startswith("P")
        for character in piece
    )


def _is_opening_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) in {"Ps", "Pi"}
        for character in piece
    )


def _is_call_opening_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) == "Ps"
        and "PARENTHESIS" in unicodedata.name(character, "")
        for character in piece
    )


def _is_closing_or_terminal_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) in {"Pe", "Pf", "Po"}
        for character in piece
    )


def _is_infix_joiner(piece: str) -> bool:
    if len(piece) != 1:
        return False
    category = unicodedata.category(piece)
    name = unicodedata.name(piece, "")
    return (
        category == "Pd"
        or "APOSTROPHE" in name
        or (category == "Pf" and "SINGLE QUOTATION MARK" in name)
        or "SOLIDUS" in name
    )


def _is_dash_joiner(piece: str) -> bool:
    if len(piece) != 1:
        return False
    category = unicodedata.category(piece)
    name = unicodedata.name(piece, "")
    return category == "Pd" or "HYPHEN" in name or "DASH" in name


def _is_quote_piece(piece: str) -> bool:
    if len(piece) != 1:
        return False
    if _is_infix_joiner(piece):
        return False
    name = unicodedata.name(piece, "")
    category = unicodedata.category(piece)
    return "QUOTATION MARK" in name or category in {"Pi", "Pf"}


def _merge_symbol(left: str, right: str, prefix: str) -> str:
    if right.startswith(prefix):
        return left + right[len(prefix):]
    return left + right


def _merge_sequence(symbols: list[str], pair: tuple[str, str], merged_symbol: str) -> list[str]:
    merged: list[str] = []
    index = 0
    while index < len(symbols):
        if index < len(symbols) - 1 and (symbols[index], symbols[index + 1]) == pair:
            merged.append(merged_symbol)
            index += 2
        else:
            merged.append(symbols[index])
            index += 1
    return merged


def _default_symbol_inventory(word_prefix: str) -> set[str]:
    symbols: set[str] = set()
    for character in DEFAULT_FALLBACK_CHARACTERS:
        symbols.add(character)
        symbols.add(f"{word_prefix}{character}")
    for value in range(256):
        token = _byte_token(value)
        symbols.add(token)
        symbols.add(f"{word_prefix}{token}")
    return symbols


def _whole_segment_token(segment: str, word_prefix: str) -> str:
    return f"{word_prefix}{segment}"


def recommend_vocab_size(
    text: str,
    *,
    minimum: int = 768,
    maximum: int = 1536,
    multiplier: int = 5,
    lowercase: bool = False,
) -> int:
    seed_tokenizer = NativeTokenizer(
        merges=[],
        vocab=[],
        base_symbols=[],
        lowercase=lowercase,
    )
    segments = seed_tokenizer.pretokenize(text)
    distinct_segments = len(set(segments))
    recommended = max(minimum, distinct_segments * multiplier)
    return min(maximum, recommended)


def clamp_vocab_size(requested: int, *, maximum: int = MAX_TOKENIZER_VOCAB_SIZE) -> int:
    return min(maximum, max(1, requested))


@dataclass(slots=True)
class NativeTokenizer:
    merges: list[tuple[str, str]]
    vocab: list[str]
    base_symbols: list[str]
    name: str = TOKENIZER_NAME
    lowercase: bool = False
    word_prefix: str = "▁"
    unk_token: str = "<unk>"
    bos_token: str = "<bos>"
    eos_token: str = "<eos>"
    pad_token: str = "<pad>"
    _merge_ranks: dict[tuple[str, str], int] = field(init=False, repr=False)
    _vocab_set: set[str] = field(init=False, repr=False)
    _base_symbol_set: set[str] = field(init=False, repr=False)
    _pretoken_pattern: re.Pattern[str] = field(init=False, repr=False)
    _segment_cache: dict[str, tuple[str, ...]] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self._merge_ranks = {pair: index for index, pair in enumerate(self.merges)}
        self._base_symbol_set = set(self.base_symbols)
        self._vocab_set = set(self.vocab) | self.special_tokens | self._base_symbol_set
        self.vocab = sorted(self._vocab_set)
        self._pretoken_pattern = self._build_pretoken_pattern()
        self._segment_cache = {}

    @property
    def special_tokens(self) -> set[str]:
        return {
            self.unk_token,
            self.bos_token,
            self.eos_token,
            self.pad_token,
            *REASONING_CONTROL_TOKENS,
        }

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_set)

    def normalize(self, text: str) -> str:
        normalized = unicodedata.normalize("NFKC", text)
        return normalized.lower() if self.lowercase else normalized

    def pretokenize(self, text: str) -> list[str]:
        normalized = self.normalize(text)
        segments: list[str] = []
        reserved = sorted(self.special_tokens, key=len, reverse=True)
        index = 0
        while index < len(normalized):
            if normalized[index].isspace():
                if normalized[index] == "\r":
                    if index + 1 < len(normalized) and normalized[index + 1] == "\n":
                        segments.append("\n")
                        index += 2
                        continue
                    segments.append("\n")
                    index += 1
                    continue
                if normalized[index] == "\n":
                    segments.append("\n")
                    index += 1
                    continue
                index += 1
                continue

            matched_special = next(
                (
                    token
                    for token in reserved
                    if normalized.startswith(token, index)
                ),
                None,
            )
            if matched_special is not None:
                segments.append(matched_special)
                index += len(matched_special)
                continue

            emoji_end = _consume_emoji_cluster(normalized, index)
            if emoji_end > index:
                segments.append(normalized[index:emoji_end])
                index = emoji_end
                continue

            if _is_word_character(normalized[index]):
                start = index
                index += 1
                while index < len(normalized) and _is_word_character(normalized[index]):
                    index += 1
                segments.append(normalized[start:index])
                continue

            segments.append(normalized[index])
            index += 1
        return segments

    def encode(self, text: str, *, add_special_tokens: bool = False) -> list[str]:
        tokens: list[str] = []
        if add_special_tokens:
            tokens.append(self.bos_token)

        for segment in self.pretokenize(text):
            tokens.extend(self._encode_segment_cached(segment))

        if add_special_tokens:
            tokens.append(self.eos_token)

        if not tokens and text.strip():
            return [self.unk_token]
        return tokens

    def encode_many(
        self,
        texts: list[str] | tuple[str, ...],
        *,
        add_special_tokens: bool = False,
    ) -> list[list[str]]:
        return [
            self.encode(text, add_special_tokens=add_special_tokens)
            for text in texts
        ]

    def decode(self, tokens: list[str]) -> str:
        text = ""
        join_next = False
        byte_buffer = bytearray()
        byte_starts_segment = False

        def next_rendered_piece(start_index: int) -> str | None:
            for raw_token in tokens[start_index:]:
                if raw_token in self.special_tokens:
                    continue
                raw_starts_segment = raw_token.startswith(self.word_prefix)
                raw_piece = raw_token[len(self.word_prefix) :] if raw_starts_segment else raw_token
                if not raw_piece:
                    continue
                if _byte_value(raw_piece) is not None:
                    return None
                return raw_piece
            return None

        def append_piece(piece: str, starts_segment: bool, next_piece: str | None = None) -> None:
            nonlocal text, join_next

            if piece == "\n":
                text = text.rstrip(" ")
                text += "\n"
                join_next = True
                return

            had_text_before_piece = bool(text.strip())
            previous_before_piece = text.rstrip(" ")[-1:] if text.strip(" ") else ""
            if _is_quote_piece(piece):
                quote_count = sum(1 for character in text if _is_quote_piece(character))
                opens_quote = quote_count % 2 == 0
                if opens_quote:
                    if text and not text.endswith((" ", "\n")) and previous_before_piece not in {"(", "[", "{"}:
                        text += " "
                    text += piece
                    join_next = True
                    return
                text = text.rstrip(" ")
                text += piece
                join_next = False
                return

            attaches_left = _is_closing_or_terminal_punctuation(piece) or _is_infix_joiner(piece)
            continues_segment = (not starts_segment) and any(
                _is_word_character(character) or _is_emoji_continuation_character(character)
                for character in piece
            )
            if starts_segment:
                if text and not join_next:
                    attaches_to_previous_code_span = (
                        _is_opening_punctuation(piece)
                        and previous_before_piece.isalnum()
                        and next_piece is not None
                        and (
                            _is_infix_joiner(next_piece)
                            or _is_call_opening_punctuation(piece)
                        )
                    )
                    if not _is_punctuation_piece(piece) or (
                        _is_opening_punctuation(piece)
                        and not attaches_to_previous_code_span
                    ):
                        text += " "
                text += piece
            else:
                if text and not join_next and not attaches_left and not continues_segment:
                    text += " "
                text += piece

            join_next = (
                _is_infix_joiner(piece)
                and (
                    not starts_segment
                    or (
                        had_text_before_piece
                        and (
                            not _is_dash_joiner(piece)
                            or previous_before_piece.isalnum()
                            or _is_opening_punctuation(previous_before_piece)
                        )
                    )
                )
            ) or _is_opening_punctuation(piece)

        def flush_bytes() -> None:
            nonlocal byte_buffer, byte_starts_segment
            if not byte_buffer:
                return
            append_piece(bytes(byte_buffer).decode("utf-8", errors="replace"), byte_starts_segment)
            byte_buffer = bytearray()
            byte_starts_segment = False

        for token_index, token in enumerate(tokens):
            if token in self.special_tokens:
                continue
            starts_segment = token.startswith(self.word_prefix)
            piece = token[len(self.word_prefix) :] if starts_segment else token
            if not piece:
                continue
            byte_value = _byte_value(piece)
            if byte_value is not None:
                if not byte_buffer:
                    byte_starts_segment = starts_segment
                byte_buffer.append(byte_value)
                continue

            flush_bytes()
            append_piece(piece, starts_segment, next_rendered_piece(token_index + 1))
        flush_bytes()
        return text.strip()

    def _encode_segment_cached(self, segment: str) -> tuple[str, ...]:
        cached = self._segment_cache.get(segment)
        if cached is not None:
            return cached
        encoded = tuple(self._encode_segment(segment))
        if len(self._segment_cache) < MAX_SEGMENT_CACHE_SIZE:
            self._segment_cache[segment] = encoded
        return encoded

    def _encode_segment(self, segment: str) -> list[str]:
        if segment in self.special_tokens:
            return [segment]
        whole_segment = _whole_segment_token(segment, self.word_prefix)
        if whole_segment in self._vocab_set:
            return [whole_segment]
        symbols = self._seed_symbols(segment)
        if not symbols:
            return []

        while len(symbols) > 1:
            best_rank: int | None = None
            best_pair: tuple[str, str] | None = None
            for index in range(len(symbols) - 1):
                pair = (symbols[index], symbols[index + 1])
                rank = self._merge_ranks.get(pair)
                if rank is None:
                    continue
                if best_rank is None or rank < best_rank:
                    best_rank = rank
                    best_pair = pair
            if best_pair is None:
                break

            merged_symbol = _merge_symbol(best_pair[0], best_pair[1], self.word_prefix)
            symbols = _merge_sequence(symbols, best_pair, merged_symbol)

        if any(symbol not in self._vocab_set for symbol in symbols):
            return [self.unk_token]
        return symbols

    def _seed_symbols(self, segment: str) -> list[str]:
        symbols: list[str] = []
        for index, character in enumerate(segment):
            symbol = f"{self.word_prefix}{character}" if index == 0 else character
            if symbol in self._base_symbol_set:
                symbols.append(symbol)
                continue

            encoded = character.encode("utf-8")
            for byte_index, value in enumerate(encoded):
                token = _byte_token(value)
                if index == 0 and byte_index == 0:
                    token = f"{self.word_prefix}{token}"
                symbols.append(token)

        if any(symbol not in self._base_symbol_set for symbol in symbols):
            return [self.unk_token]
        return symbols

    def to_dict(self) -> dict[str, object]:
        return {
            "name": self.name,
            "merges": [[left, right] for left, right in self.merges],
            "vocab": self.vocab,
            "base_symbols": self.base_symbols,
            "lowercase": self.lowercase,
            "word_prefix": self.word_prefix,
            "unk_token": self.unk_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
        }

    @classmethod
    def from_dict(cls, payload: dict[str, object]) -> "NativeTokenizer":
        return cls(
            merges=[(str(left), str(right)) for left, right in payload["merges"]],
            vocab=[str(token) for token in payload["vocab"]],
            base_symbols=[str(token) for token in payload["base_symbols"]],
            name=str(payload.get("name", TOKENIZER_NAME)),
            lowercase=bool(payload["lowercase"]),
            word_prefix=str(payload["word_prefix"]),
            unk_token=str(payload["unk_token"]),
            bos_token=str(payload["bos_token"]),
            eos_token=str(payload["eos_token"]),
            pad_token=str(payload["pad_token"]),
        )

    def _build_pretoken_pattern(self) -> re.Pattern[str]:
        reserved = sorted(self.special_tokens, key=len, reverse=True)
        if not reserved:
            return PRETOKEN_PATTERN
        reserved_pattern = "|".join(re.escape(token) for token in reserved)
        return re.compile(f"{reserved_pattern}|\\w+|[^\\w\\s]", re.UNICODE)

    @classmethod
    def train(
        cls,
        text: str,
        *,
        vocab_size: int = 256,
        min_pair_frequency: int = 2,
        lowercase: bool = False,
        word_prefix: str = "▁",
    ) -> "NativeTokenizer":
        seed_tokenizer = cls(
            merges=[],
            vocab=[],
            base_symbols=[],
            lowercase=lowercase,
            word_prefix=word_prefix,
        )
        segments = seed_tokenizer.pretokenize(text)
        if not segments:
            raise ValueError("Cannot train the native tokenizer on empty text.")

        return cls.train_from_segment_counts(
            Counter(segments),
            vocab_size=vocab_size,
            min_pair_frequency=min_pair_frequency,
            lowercase=lowercase,
            word_prefix=word_prefix,
        )

    @classmethod
    def train_from_segment_counts(
        cls,
        segment_counts: Mapping[str, float],
        *,
        vocab_size: int = 256,
        min_pair_frequency: int = 2,
        lowercase: bool = False,
        word_prefix: str = "▁",
    ) -> "NativeTokenizer":
        if not segment_counts:
            raise ValueError("Cannot train the native tokenizer on empty segment counts.")
        seed_tokenizer = cls(
            merges=[],
            vocab=[],
            base_symbols=[],
            lowercase=lowercase,
            word_prefix=word_prefix,
        )

        word_counts = Counter(
            {
                str(segment): float(frequency)
                for segment, frequency in segment_counts.items()
                if str(segment) and float(frequency) > 0.0
            }
        )
        if not word_counts:
            raise ValueError("Cannot train the native tokenizer on empty segment counts.")
        observed_symbols = {
            f"{word_prefix}{character}" if index == 0 else character
            for segment in word_counts
            for index, character in enumerate(segment)
        }
        base_symbols = _default_symbol_inventory(word_prefix)
        base_symbols.update(observed_symbols)
        sequences = {
            segment: [
                f"{word_prefix}{character}" if index == 0 else character
                for index, character in enumerate(segment)
            ]
            for segment in word_counts
        }
        vocab = set(observed_symbols) | seed_tokenizer.special_tokens
        target_vocab_size = len(vocab) + max(1, vocab_size)
        segment_candidates = sorted(
            {
                segment
                for segment, frequency in word_counts.items()
                if len(segment) > 1 and frequency >= min_pair_frequency
            },
            key=lambda segment: (
                -(word_counts[segment] * len(segment)),
                -len(segment),
                segment,
            ),
        )
        for segment in segment_candidates:
            if len(vocab) >= target_vocab_size:
                break
            vocab.add(_whole_segment_token(segment, word_prefix))
        merges: list[tuple[str, str]] = []

        while len(vocab) < target_vocab_size and len(merges) < MAX_TRAINED_PAIR_MERGES:
            pair_counts: Counter[tuple[str, str]] = Counter()
            for segment, frequency in word_counts.items():
                symbols = sequences[segment]
                for index in range(len(symbols) - 1):
                    pair_counts[(symbols[index], symbols[index + 1])] += frequency

            if not pair_counts:
                break

            best_pair, best_count = min(
                pair_counts.items(),
                key=lambda item: (-item[1], item[0][0], item[0][1]),
            )
            if best_count < min_pair_frequency:
                break

            merged_symbol = _merge_symbol(best_pair[0], best_pair[1], word_prefix)
            merges.append(best_pair)
            vocab.add(merged_symbol)
            for segment in sequences:
                sequences[segment] = _merge_sequence(sequences[segment], best_pair, merged_symbol)

        return cls(
            merges=merges,
            vocab=sorted(vocab),
            base_symbols=sorted(base_symbols),
            lowercase=lowercase,
            word_prefix=word_prefix,
        )
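A short usage sketch for `NativeTokenizer` (the training string is invented for illustration; the classes, functions, and default special tokens are the ones defined in `reframr/tokenizer.py` above):

```python
from reframr.tokenizer import NativeTokenizer, recommend_vocab_size

corpus = "Reframr encodes words, punctuation, and emoji 🙂 with byte fallback."
tokenizer = NativeTokenizer.train(corpus, vocab_size=recommend_vocab_size(corpus))

# Round trip: specials are added on encode and skipped on decode.
tokens = tokenizer.encode("Reframr encodes emoji 🙂", add_special_tokens=True)
assert tokens[0] == "<bos>" and tokens[-1] == "<eos>"
print(tokenizer.decode(tokens))  # expected: "Reframr encodes emoji 🙂"

# Characters outside the symbol inventory fall back to <byte:XX> tokens,
# which decode() reassembles into UTF-8 text.
print(tokenizer.encode("Ω"))  # expected: ['▁<byte:CE>', '<byte:A9>']
```

Note that words, emoji clusters, and newlines are segmented before any pair merging, so merges never cross a word boundary, and the byte-fallback inventory guarantees that any UTF-8 input can be encoded without hitting `<unk>`.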
requirements.txt
ADDED
@@ -0,0 +1,3 @@
numpy>=2.1,<3
scipy>=1.14,<2
datasets>=4.1,<5
sample_prompts.jsonl
ADDED
@@ -0,0 +1,5 @@
{"prompt":"Who are you, and what makes Reframr different from Transformer models?","max_tokens":90,"temperature":0.92}
{"system":"Answer with calm confidence and no hype.","prompt":"Explain why computed weights are different from memorized template responses.","max_tokens":100,"temperature":0.9}
{"prompt":"Tell a compact story about a city that stores its memories in rainwater.","max_tokens":120,"temperature":1.05,"decode_top_k":90}
{"system":"Use exactly one fitting emoji.","prompt":"Write a warm note to a teammate who fixed a hard bug.","max_tokens":70,"temperature":0.95}
{"prompt":"Give safe, defensive guidance for recognizing a phishing email without helping an attacker.","max_tokens":100,"temperature":0.88}
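Each line is a standalone JSON request with a required `prompt`, an optional `system` instruction, and per-request sampling settings. A minimal reader sketch (how the parsed request is dispatched is left abstract here; `examples/jsonl_serve.ps1` and `examples/python_inference.py` show the actual entry points):

```python
import json
from pathlib import Path

for line in Path("sample_prompts.jsonl").read_text(encoding="utf-8").splitlines():
    request = json.loads(line)
    system = request.get("system")          # optional per-request system text
    prompt = request["prompt"]              # required
    sampling = {
        key: value
        for key, value in request.items()
        if key not in {"system", "prompt"}  # max_tokens, temperature, decode_top_k, ...
    }
    print(system, prompt, sampling)
```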
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.