# OkeyMeta — Reframr-RFM-v1-Base public checkpoint release (commit 2147ce8, verified).
import argparse
import json
import sys
from pathlib import Path
from .checkpoint import inspect_checkpoint
from .config import ReframrConfig
from .corpus_recipes import (
build_foundation_corpus,
build_generalization_corpus,
write_corpus_package,
)
from .curriculum import CurriculumConfig, write_curriculum_package
from .datasets import load_prompt_suite, load_text_corpus
from .evaluation import benchmark_open_prompts, evaluate_manifest, load_manifest
from .hf_import import import_hf_dataset
from .model import ReframrModel
from .reasoning import REASONING_PROFILES, TOKENIZER_NAME, reasoning_prefix
from .streaming import fit_model_from_corpus_plan, load_corpus_plan
from .tokenizer import MAX_TOKENIZER_VOCAB_SIZE, clamp_vocab_size, recommend_vocab_size
def configure_stdio() -> None:
    """Force UTF-8 encoding on stdout/stderr when the stream supports it.

    Streams without a ``reconfigure`` method (e.g. some redirected or wrapped
    streams) are silently left untouched.
    """
    for channel in (sys.stdout, sys.stderr):
        if hasattr(channel, "reconfigure"):
            channel.reconfigure(encoding="utf-8")
def _add_checkpoint_arguments(command: argparse.ArgumentParser) -> None:
    """Add the model/tokenizer options shared by ``compute`` and ``recompute``.

    Factoring these out keeps the two subcommands from drifting apart and gives
    ``recompute`` the same help text ``compute`` already had.
    """
    command.add_argument("--embedding-dim", type=int, default=16)
    command.add_argument("--state-dim", type=int, default=32)
    command.add_argument("--timescales", default="1.0,0.5,0.25,0.125")
    command.add_argument("--window-size", type=int, default=2)
    command.add_argument("--regularization", type=float, default=1e-3)
    command.add_argument("--min-frequency", type=int, default=1)
    command.add_argument(
        "--max-vocab",
        type=int,
        default=256,
        help="Cap analytical embedding vocabulary to keep weight computation fast on CPU.",
    )
    command.add_argument("--tokenizer-vocab-size", type=int, default=0)
    command.add_argument("--tokenizer-min-pair-frequency", type=int, default=2)
    command.add_argument(
        "--max-training-examples",
        type=int,
        default=60000,
        help="Cap sampled recurrent training states while still reading the full corpus for tokenizer, embeddings, and transitions.",
    )
    command.add_argument(
        "--max-transition-contexts",
        type=int,
        default=4096,
        help="Keep only the strongest learned transition contexts per order. Use 0 to disable the cap.",
    )
    command.add_argument(
        "--max-transition-next-tokens",
        type=int,
        default=4,
        help="Keep this many learned next-token choices per transition context.",
    )
    case_group = command.add_mutually_exclusive_group()
    case_group.add_argument(
        "--lowercase",
        action="store_true",
        help="Normalize corpus text to lowercase before tokenization.",
    )
    # Hidden from --help; mutually exclusive with --lowercase. Presumably a
    # legacy spelling of the default behavior -- confirm before removing.
    case_group.add_argument("--preserve-case", action="store_true", help=argparse.SUPPRESS)
    command.add_argument(
        "--reasoning-profile",
        choices=sorted(REASONING_PROFILES),
        default="none",
        help="Default reasoning-control profile baked into the checkpoint.",
    )


def _add_decoding_arguments(command: argparse.ArgumentParser) -> None:
    """Add the sampling knobs shared by the generation-style subcommands."""
    command.add_argument("--max-tokens", type=int, default=64)
    command.add_argument("--temperature", type=float, default=0.82)
    command.add_argument("--decode-top-k", type=int, default=24)
    command.add_argument("--decode-top-p", type=float, default=0.92)
    command.add_argument("--repetition-penalty", type=float, default=1.18)


def _add_reasoning_override(command: argparse.ArgumentParser, suffix: str = "") -> None:
    """Add the optional per-invocation reasoning-profile override.

    ``suffix`` customizes the trailing context of the help sentence
    (e.g. " during evaluation").
    """
    command.add_argument(
        "--reasoning-mode",
        choices=sorted(REASONING_PROFILES),
        default=None,
        help=f"Override the checkpoint's default reasoning-control profile{suffix}.",
    )


def build_parser() -> argparse.ArgumentParser:
    """Build the ``reframr`` argument parser with all subcommands.

    The ``compute``/``recompute`` pair and the generation-style commands share
    large option groups; those are added through private helpers above so the
    duplicated definitions stay consistent.
    """
    parser = argparse.ArgumentParser(
        prog="reframr",
        description="Compute and query REFRAMR analytical language model checkpoints.",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    compute = subparsers.add_parser(
        "compute",
        aliases=["train"],
        help="Compute a REFRAMR checkpoint from a text corpus with no epoch loop.",
    )
    compute.add_argument(
        "--input",
        required=True,
        help="Path to a text, JSON, or JSONL corpus file, or a directory of such files.",
    )
    compute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    _add_checkpoint_arguments(compute)

    recompute = subparsers.add_parser(
        "recompute",
        help="Compute a REFRAMR checkpoint from a streaming corpus plan with no raw-text cache.",
    )
    recompute.add_argument("--plan", required=True, help="Path to a streaming corpus plan JSON file.")
    recompute.add_argument("--output", required=True, help="Path to write the .safetensors checkpoint.")
    recompute.add_argument("--log-every", type=int, default=0)
    _add_checkpoint_arguments(recompute)

    predict = subparsers.add_parser("predict", help="Predict the next-token distribution from a saved model.")
    predict.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    predict.add_argument("--context", required=True, help="Input context text.")
    predict.add_argument("--top-k", type=int, default=5)
    _add_reasoning_override(predict)

    generate = subparsers.add_parser("generate", help="Generate long-form text from a saved model.")
    generate.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate.add_argument("--context", required=True, help="Prompt or starting context text.")
    generate.add_argument("--system", default="", help="Optional system instruction to prepend as learned context.")
    _add_decoding_arguments(generate)
    _add_reasoning_override(generate)

    generate_batch = subparsers.add_parser(
        "generate-batch",
        help="Generate answers for a prompt file while keeping one checkpoint loaded.",
    )
    generate_batch.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    generate_batch.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    generate_batch.add_argument("--output", required=True, help="Path to write JSONL generations.")
    _add_decoding_arguments(generate_batch)
    _add_reasoning_override(generate_batch)

    serve = subparsers.add_parser(
        "serve",
        help="Keep one checkpoint loaded and answer JSONL generation requests from stdin.",
    )
    serve.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    _add_decoding_arguments(serve)
    _add_reasoning_override(serve)

    trace = subparsers.add_parser("trace", help="Trace REFRAMR reasoning components through generation steps.")
    trace.add_argument("--model", required=True, help="Path to a serialized REFRAMR model.")
    trace.add_argument("--context", required=True, help="Prompt or starting context text.")
    # trace uses a smaller token budget and has no --decode-top-k, so it does
    # not share _add_decoding_arguments.
    trace.add_argument("--max-tokens", type=int, default=8)
    trace.add_argument("--top-k", type=int, default=5)
    trace.add_argument("--temperature", type=float, default=0.82)
    trace.add_argument("--decode-top-p", type=float, default=0.92)
    trace.add_argument("--repetition-penalty", type=float, default=1.18)
    _add_reasoning_override(trace)

    inspect = subparsers.add_parser("inspect", help="Inspect a REFRAMR safetensors checkpoint.")
    inspect.add_argument("--model", required=True, help="Path to a .safetensors checkpoint.")

    craft = subparsers.add_parser(
        "craft-corpus",
        help="Generate a JSON-first bootstrap corpus, manifest, and generalization prompt suite.",
    )
    craft.add_argument("--output-dir", required=True, help="Directory to write corpus and manifest files.")
    craft.add_argument(
        "--variant",
        choices=("foundation", "generalization"),
        default="foundation",
        help="Choose between the mixed foundation corpus and the language-first generalization corpus.",
    )

    craft_curriculum = subparsers.add_parser(
        "craft-curriculum",
        help="Generate the OkeyMeta JSON curriculum shard, manifest, holdout prompts, and recompute plan.",
    )
    craft_curriculum.add_argument("--output-dir", required=True, help="Directory to write curriculum files.")
    craft_curriculum.add_argument(
        "--records-per-category",
        type=int,
        default=1000,
        help="How many JSON records to generate for each curriculum category.",
    )
    craft_curriculum.add_argument("--seed", type=int, default=7)
    craft_curriculum.add_argument("--train-ratio", type=float, default=0.92)
    craft_curriculum.add_argument(
        "--effective-token-target",
        type=int,
        default=0,
        help="Set plan weighting so compact curriculum statistics represent this many effective tokens.",
    )

    evaluate = subparsers.add_parser(
        "evaluate",
        help="Evaluate memorization and held-out generalization from a benchmark manifest.",
    )
    evaluate.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    evaluate.add_argument("--manifest", required=True, help="Path to a corpus benchmark manifest JSON file.")
    _add_reasoning_override(evaluate, " during evaluation")
    evaluate.add_argument("--top-k", type=int, default=5)

    benchmark_open = subparsers.add_parser(
        "benchmark-open",
        help="Run arbitrary prompt files through a checkpoint with open-ended output metrics.",
    )
    benchmark_open.add_argument("--model", required=True, help="Path to a REFRAMR .safetensors checkpoint.")
    benchmark_open.add_argument("--prompts", required=True, help="Path to a TXT, JSON, or JSONL prompt suite.")
    _add_decoding_arguments(benchmark_open)
    _add_reasoning_override(benchmark_open, " during benchmarking")

    import_hf = subparsers.add_parser(
        "import-hf",
        help="Import Hugging Face dataset text into the REFRAMR JSON record standard.",
    )
    import_hf.add_argument("--dataset", required=True, help="Hugging Face dataset id.")
    import_hf.add_argument("--output", required=True, help="Path to write the JSONL corpus.")
    import_hf.add_argument("--config", default=None, help="Optional dataset config/subset.")
    import_hf.add_argument("--split", default="train", help="Dataset split to import.")
    import_hf.add_argument("--text-field", default=None, help="Explicit text column name.")
    import_hf.add_argument("--limit", type=int, default=1000, help="Maximum records to import.")
    import_hf.add_argument(
        "--min-words",
        type=int,
        default=0,
        help="Drop imported records shorter than this many words.",
    )
    import_hf.add_argument(
        "--max-words",
        type=int,
        default=0,
        help="Drop imported records longer than this many words. Use 0 to disable.",
    )
    import_hf.add_argument(
        "--min-alpha-ratio",
        type=float,
        default=0.0,
        help="Drop imported records whose alphabetic-character ratio falls below this threshold.",
    )
    import_hf.add_argument(
        "--allowed-languages",
        default="",
        help="Optional comma-separated language codes to keep, such as en,yo,ig,ha.",
    )
    import_hf.add_argument(
        "--preference-target",
        choices=("both", "chosen", "rejected"),
        default="chosen",
        help="When importing preference datasets, keep both sides or only the chosen/rejected side.",
    )
    import_hf.add_argument(
        "--no-streaming",
        action="store_true",
        help="Disable streaming dataset reads.",
    )
    return parser
def parse_timescales(raw_timescales: str) -> tuple[float, ...]:
    """Parse a comma-separated timescale string into a tuple of floats.

    Blank segments are ignored; an input with no usable segment raises
    ``ValueError``.
    """
    parsed: list[float] = []
    for piece in raw_timescales.split(","):
        piece = piece.strip()
        if piece:
            parsed.append(float(piece))
    if not parsed:
        raise ValueError("At least one timescale is required.")
    return tuple(parsed)
def command_compute(args: argparse.Namespace) -> int:
    """Fit a REFRAMR checkpoint from a local corpus and print a JSON summary."""
    corpus_text = load_text_corpus(args.input)
    # An explicit --tokenizer-vocab-size wins; otherwise derive one from the corpus.
    requested_budget = args.tokenizer_vocab_size or recommend_vocab_size(
        corpus_text,
        lowercase=args.lowercase,
    )
    vocab_budget = clamp_vocab_size(requested_budget)
    # A non-positive CLI value disables the per-order transition cap.
    transition_cap = args.max_transition_contexts if args.max_transition_contexts > 0 else None
    run_config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=vocab_budget,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        max_transition_contexts_per_order=transition_cap,
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
    )
    model = ReframrModel(run_config).fit(corpus_text)
    model.save(args.output)
    # After fit() both components must exist; guard before dereferencing them.
    assert model.tokenizer is not None
    assert model.embedding_model is not None
    summary = {
        "status": "computed",
        "format": "safetensors",
        "model_path": str(Path(args.output).resolve()),
        "tokenizer_name": TOKENIZER_NAME,
        "vocab_size": len(model.embedding_model.id_to_token),
        "tokenizer_vocab_budget": run_config.tokenizer_vocab_size,
        "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
        "tokenizer_vocab_size": model.tokenizer.vocab_size,
        "reasoning_profile": run_config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(run_config.default_reasoning_profile),
        "lowercase": run_config.lowercase,
        "max_training_examples": run_config.max_training_examples,
        "max_transition_contexts_per_order": run_config.max_transition_contexts_per_order,
        "max_transition_next_tokens": run_config.max_transition_next_tokens,
        "embedding_dim": run_config.embedding_dim,
        "state_dim": run_config.state_dim,
        "timescales": list(run_config.timescales),
    }
    print(json.dumps(summary))
    return 0
def command_recompute(args: argparse.Namespace) -> int:
    """Fit a checkpoint from a streaming corpus plan and print a JSON summary.

    Unlike ``compute``, input arrives through a corpus-plan file consumed by
    ``fit_model_from_corpus_plan``, so no raw-text corpus is loaded whole.
    """
    plan = load_corpus_plan(args.plan)
    # Streaming cannot sample the corpus up front, so fall back to a fixed
    # budget of 1024 (instead of recommend_vocab_size) when none is requested.
    requested_vocab_size = args.tokenizer_vocab_size or 1024
    tokenizer_vocab_size = clamp_vocab_size(requested_vocab_size)
    config = ReframrConfig(
        embedding_dim=args.embedding_dim,
        state_dim=args.state_dim,
        timescales=parse_timescales(args.timescales),
        window_size=args.window_size,
        regularization=args.regularization,
        min_frequency=args.min_frequency,
        max_vocab=args.max_vocab,
        tokenizer_vocab_size=tokenizer_vocab_size,
        tokenizer_min_pair_frequency=args.tokenizer_min_pair_frequency,
        max_training_examples=args.max_training_examples,
        # A non-positive CLI value disables the per-order transition cap.
        max_transition_contexts_per_order=(
            args.max_transition_contexts if args.max_transition_contexts > 0 else None
        ),
        max_transition_next_tokens=args.max_transition_next_tokens,
        lowercase=args.lowercase,
        default_reasoning_profile=args.reasoning_profile,
    )
    model, payload = fit_model_from_corpus_plan(
        plan,
        config,
        log_every=args.log_every,
    )
    model.save(args.output)
    # The summary merges run provenance with the statistics the streaming
    # trainer reported in `payload`; optional counters default to 0 / {} / None
    # when the trainer omits them.
    summary = {
        "status": "recomputed",
        "format": "safetensors",
        "streaming": True,
        "plan_path": str(Path(args.plan).resolve()),
        "model_path": str(Path(args.output).resolve()),
        "tokenizer_name": TOKENIZER_NAME,
        "tokenizer_vocab_budget": config.tokenizer_vocab_size,
        "tokenizer_vocab_budget_max": MAX_TOKENIZER_VOCAB_SIZE,
        "tokenizer_vocab_size": payload["tokenizer_vocab_size"],
        "vocab_size": payload["embedding_vocab_size"],
        "documents_processed": payload["documents_processed"],
        "source_counts": payload["source_counts"],
        "examples_processed": payload["examples_processed"],
        "associative_examples": payload["associative_examples"],
        "answer_associative_examples": payload.get("answer_associative_examples", 0),
        "general_associative_examples": payload.get("general_associative_examples", 0),
        "answer_intent_examples": payload.get("answer_intent_examples", 0),
        "answer_start_examples": payload.get("answer_start_examples", 0),
        "answer_sequence_examples": payload.get("answer_sequence_examples", 0),
        "prompt_answer_readout_examples": payload.get("prompt_answer_readout_examples", 0),
        "prompt_answer_start_readout_examples": payload.get("prompt_answer_start_readout_examples", 0),
        "preference_pairs": payload.get("preference_pairs", 0),
        "preference_state_pairs": payload.get("preference_state_pairs", 0),
        "stage_seconds": payload.get("stage_seconds", {}),
        "readout_solver": payload.get("readout_solver"),
        "reasoning_profile": config.default_reasoning_profile,
        "reasoning_tokens": reasoning_prefix(config.default_reasoning_profile),
        "lowercase": config.lowercase,
        "max_training_examples": config.max_training_examples,
        "max_transition_contexts_per_order": config.max_transition_contexts_per_order,
        "max_transition_next_tokens": config.max_transition_next_tokens,
        "embedding_dim": config.embedding_dim,
        "state_dim": config.state_dim,
        "timescales": list(config.timescales),
    }
    print(json.dumps(summary))
    return 0
def command_predict(args: argparse.Namespace) -> int:
    """Print the top-k next-token predictions for a context as one JSON line."""
    model = ReframrModel.load(args.model)
    distribution = model.predict_next_distribution(
        args.context,
        reasoning_mode=args.reasoning_mode,
    )
    # Rank tokens by probability, highest first, then keep the top-k slice.
    ranked = sorted(distribution.items(), key=lambda pair: pair[1], reverse=True)
    top_predictions = ranked[: args.top_k]
    active_mode = args.reasoning_mode or model.config.default_reasoning_profile
    payload = {
        "context": args.context,
        "reasoning_mode": active_mode,
        "reasoning_tokens": reasoning_prefix(active_mode),
        "predictions": [
            {"token": token, "probability": score}
            for token, score in top_predictions
        ],
    }
    print(json.dumps(payload))
    return 0
def command_generate(args: argparse.Namespace) -> int:
    """Generate long-form text for one prompt and print a JSON payload."""
    model = ReframrModel.load(args.model)
    full_context = compose_generation_context(args.context, system=args.system)
    active_mode = args.reasoning_mode or model.config.default_reasoning_profile
    completion = model.generate_text(
        full_context,
        max_tokens=args.max_tokens,
        reasoning_mode=args.reasoning_mode,
        temperature=args.temperature,
        top_k=args.decode_top_k,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
    )
    payload = {
        "context": full_context,
        "reasoning_mode": active_mode,
        "reasoning_tokens": reasoning_prefix(active_mode),
        # Whitespace-token count of the output, not tokenizer tokens.
        "generated_token_count": len(completion.split()),
        "generated_text": completion,
    }
    print(json.dumps(payload))
    return 0
def compose_generation_context(prompt: str, *, system: str = "") -> str:
    """Combine an optional system instruction with a user prompt.

    Both pieces are stripped; when the system text is empty the prompt is
    returned alone, otherwise a two-line "System instruction / User" context
    is produced.
    """
    user_text = prompt.strip()
    system_text = system.strip()
    if system_text:
        return f"System instruction: {system_text}\nUser: {user_text}"
    return user_text
def command_generate_batch(args: argparse.Namespace) -> int:
    """Answer every prompt in a suite with a single model load, writing JSONL."""
    model = ReframrModel.load(args.model)
    prompt_records = load_prompt_suite(args.prompts)
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    active_mode = args.reasoning_mode or model.config.default_reasoning_profile
    rows: list[dict[str, object]] = []
    with output_path.open("w", encoding="utf-8") as sink:
        for position, entry in enumerate(prompt_records):
            prompt_text = str(entry["prompt"])
            full_context = compose_generation_context(
                prompt_text,
                system=str(entry.get("system", "")),
            )
            # A record may carry its own token budget; otherwise the CLI default applies.
            budget = int(entry.get("max_tokens", args.max_tokens))
            completion = model.generate_text(
                full_context,
                max_tokens=budget,
                reasoning_mode=args.reasoning_mode,
                temperature=args.temperature,
                top_k=args.decode_top_k,
                top_p=args.decode_top_p,
                repetition_penalty=args.repetition_penalty,
            )
            row = {
                "index": position,
                "prompt": prompt_text,
                "context": full_context,
                "system": entry.get("system", ""),
                "tags": entry.get("tags", []),
                "reasoning_mode": active_mode,
                "reasoning_tokens": reasoning_prefix(active_mode),
                "generated_token_count": len(completion.split()),
                "generated_text": completion,
            }
            rows.append(row)
            sink.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")) + "\n")
    report = {
        "status": "generated",
        "sample_count": len(rows),
        "model_path": str(Path(args.model).resolve()),
        "prompts_path": str(Path(args.prompts).resolve()),
        "output_path": str(output_path.resolve()),
        "model_loads": 1,
    }
    print(json.dumps(report))
    return 0
def command_serve(args: argparse.Namespace) -> int:
    """Serve generation requests read as JSONL (or bare JSON strings) from stdin.

    Loads the checkpoint once, then answers each non-empty input line with one
    compact JSON response line on stdout. Malformed JSON or a non-object,
    non-string request yields an error response instead of aborting the loop.
    """
    model = ReframrModel.load(args.model)
    default_mode = args.reasoning_mode or model.config.default_reasoning_profile
    for index, raw_line in enumerate(sys.stdin):
        line = raw_line.strip()
        if not line:
            continue
        try:
            request = json.loads(line)
        except json.JSONDecodeError as exc:
            # Report parse failures in-band so the client can correlate by index.
            response = {
                "index": index,
                "error": "invalid_json",
                "message": str(exc),
                "model_loads": 1,
            }
            sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
            sys.stdout.flush()
            continue
        if isinstance(request, str):
            # A bare JSON string is treated as the whole prompt.
            context = request
            request_payload: dict[str, object] = {}
        elif isinstance(request, dict):
            request_payload = request
            # "prompt" wins over "context" when both keys are present.
            raw_context = str(request_payload.get("prompt", request_payload.get("context", "")))
            context = compose_generation_context(
                raw_context,
                system=str(request_payload.get("system", "")),
            )
        else:
            response = {
                "index": index,
                "error": "invalid_request",
                "message": "request must be a JSON object or string",
                "model_loads": 1,
            }
            sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
            sys.stdout.flush()
            continue
        # Per-request overrides fall back to the CLI defaults.
        active_mode = str(request_payload.get("reasoning_mode", default_mode))
        max_tokens = int(request_payload.get("max_tokens", args.max_tokens))
        temperature = float(request_payload.get("temperature", args.temperature))
        top_k = int(request_payload.get("decode_top_k", args.decode_top_k))
        top_p = float(request_payload.get("decode_top_p", args.decode_top_p))
        repetition_penalty = float(
            request_payload.get("repetition_penalty", args.repetition_penalty)
        )
        generated_text = model.generate_text(
            context,
            max_tokens=max_tokens,
            reasoning_mode=active_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        response = {
            "index": index,
            "context": context,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "generated_token_count": len(generated_text.split()),
            "generated_text": generated_text,
            "model_loads": 1,
        }
        # Flush after every response so a pipe-connected client sees it immediately.
        sys.stdout.write(json.dumps(response, ensure_ascii=False, separators=(",", ":")) + "\n")
        sys.stdout.flush()
    return 0
def command_trace(args: argparse.Namespace) -> int:
    """Print a step-by-step generation trace for a context as one JSON line."""
    model = ReframrModel.load(args.model)
    trace_report = model.trace_generation(
        args.context,
        max_tokens=args.max_tokens,
        reasoning_mode=args.reasoning_mode,
        top_k=args.top_k,
        temperature=args.temperature,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
    )
    print(json.dumps(trace_report))
    return 0
def command_inspect(args: argparse.Namespace) -> int:
    """Print checkpoint metadata for a .safetensors file as one JSON line."""
    report = inspect_checkpoint(args.model)
    print(json.dumps(report))
    return 0
def command_craft_corpus(args: argparse.Namespace) -> int:
    """Write a bootstrap corpus package and print its summary as JSON."""
    if args.variant == "generalization":
        package = build_generalization_corpus()
    else:
        package = build_foundation_corpus()
    written = write_corpus_package(package, args.output_dir)
    summary = {
        "name": package.name,
        "corpus_path": written["corpus_path"],
        "manifest_path": written["manifest_path"],
        "prompt_suite_path": written["prompt_suite_path"],
        # Whitespace-token count of the full corpus text, not tokenizer tokens.
        "token_count_estimate": len(package.text.split()),
        "memorization_samples": len(package.memorization_samples),
        "generalization_samples": len(package.generalization_samples),
        "generalization_prompt_count": len(package.open_ended_samples),
        "variant": args.variant,
        "section_counts": package.section_counts,
    }
    print(json.dumps(summary))
    return 0
def command_craft_curriculum(args: argparse.Namespace) -> int:
    """Write the curriculum package and echo its manifest payload as JSON."""
    curriculum = CurriculumConfig(
        records_per_category=args.records_per_category,
        seed=args.seed,
        train_ratio=args.train_ratio,
    )
    # A target of 0 means "no effective-token weighting".
    token_target = args.effective_token_target or None
    payload = write_curriculum_package(
        args.output_dir,
        curriculum,
        effective_token_target=token_target,
    )
    print(json.dumps(payload))
    return 0
def command_evaluate(args: argparse.Namespace) -> int:
    """Score a checkpoint against a benchmark manifest and print JSON metrics."""
    model = ReframrModel.load(args.model)
    benchmark = load_manifest(args.manifest)
    report = evaluate_manifest(
        model,
        benchmark,
        reasoning_mode=args.reasoning_mode,
        top_k=args.top_k,
    )
    print(json.dumps(report))
    return 0
def command_benchmark_open(args: argparse.Namespace) -> int:
    """Run an open-ended prompt suite through a checkpoint and print metrics."""
    model = ReframrModel.load(args.model)
    prompt_records = load_prompt_suite(args.prompts)
    report = benchmark_open_prompts(
        model,
        prompt_records,
        reasoning_mode=args.reasoning_mode,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.decode_top_k,
        top_p=args.decode_top_p,
        repetition_penalty=args.repetition_penalty,
    )
    print(json.dumps(report))
    return 0
def command_import_hf(args: argparse.Namespace) -> int:
    """Import a Hugging Face dataset into the REFRAMR JSONL record format."""
    # Comma-separated codes such as "en,yo"; blank segments are dropped.
    language_codes = tuple(
        code.strip()
        for code in args.allowed_languages.split(",")
        if code.strip()
    )
    report = import_hf_dataset(
        dataset=args.dataset,
        output_path=args.output,
        config=args.config,
        split=args.split,
        text_field=args.text_field,
        limit=args.limit,
        streaming=not args.no_streaming,
        preference_target=args.preference_target,
        min_words=args.min_words,
        max_words=args.max_words,
        min_alpha_ratio=args.min_alpha_ratio,
        allowed_languages=language_codes,
    )
    print(json.dumps(report))
    return 0
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse arguments and dispatch to the matching handler.

    Returns the handler's exit status (0 on success). An unrecognized command
    cannot occur in practice because the subparsers require a known command,
    and ``parser.error`` raises ``SystemExit`` rather than returning.
    """
    configure_stdio()
    parser = build_parser()
    args = parser.parse_args(argv)
    # Dispatch table instead of a 14-branch if-chain; "train" is an alias of
    # "compute" and maps to the same handler.
    handlers = {
        "compute": command_compute,
        "train": command_compute,
        "recompute": command_recompute,
        "predict": command_predict,
        "generate": command_generate,
        "generate-batch": command_generate_batch,
        "serve": command_serve,
        "trace": command_trace,
        "inspect": command_inspect,
        "craft-corpus": command_craft_corpus,
        "craft-curriculum": command_craft_curriculum,
        "evaluate": command_evaluate,
        "benchmark-open": command_benchmark_open,
        "import-hf": command_import_hf,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.error(f"Unknown command: {args.command}")
        return 2  # defensive: parser.error normally raises SystemExit
    return handler(args)