# aws_rl_env / data / build_sft_dataset.py
# (Hugging Face upload-page header retained as a comment: uploaded by
#  "Sizzing" via huggingface_hub, commit e56d042 verified.)
"""Generate SFT dataset for the AWS RL environment.
Produces command-only assistant targets in chat-messages JSONL format,
ready for trl.SFTTrainer + peft LoRA.
Composition (by source):
55% success_first_step canonical command at step 1
20% multi_step_continuation step N>1 with prior command history
15% failure_recovery step 2 after a plausible typo/error
5% verification read-only check after task completion
5% hint_usage step 1 = aws help --task-hint
Tier sampling weights (warmup/beginner/intermediate/advanced/expert):
0.50 / 0.30 / 0.15 / 0.05 / 0.00
(expert skipped in SFT; GRPO will handle those with env reward)
Usage:
python data/build_sft_dataset.py --train 1500 --val 150 --reserve 200 --seed 42
"""
from __future__ import annotations
import argparse
import ast
import json
import random
import re
import textwrap
from collections import Counter
from pathlib import Path
from typing import Any, Callable
import yaml
def _find_repo_root(start: Path) -> Path:
"""Walk up from `start` looking for the dir that contains server/services/tasks/.
Makes the script location-independent: works whether it lives at repo root,
in data/, scripts/, or anywhere else in the tree.
"""
for p in [start, *start.parents]:
if (p / "server" / "services" / "tasks").is_dir():
return p
return start
# Resolved once at import time; main() may re-point these via --repo-root.
REPO_ROOT = _find_repo_root(Path(__file__).resolve().parent)
TASKS_DIR = REPO_ROOT / "server" / "services" / "tasks"
TESTS_DIR = REPO_ROOT / "tests_tasks"
DEFAULT_OUT_DIR = REPO_ROOT / "data" / "sft"

# Difficulty tiers, easiest first.
TIERS = ["warmup", "beginner", "intermediate", "advanced", "expert"]

# Sampling probability per tier. Expert is deliberately 0.0 here: per the
# module docstring, expert tasks are left for GRPO with the env reward.
TIER_WEIGHTS = {
    "warmup": 0.50,
    "beginner": 0.30,
    "intermediate": 0.15,
    "advanced": 0.05,
    "expert": 0.00,
}

# Target fraction of rows per producer source (see module docstring);
# build_dataset() turns these into integer quotas.
SOURCE_MIX = {
    "success_first_step": 0.55,
    "multi_step_continuation": 0.20,
    "failure_recovery": 0.15,
    "verification": 0.05,
    "hint_usage": 0.05,
}

# tier -> (test file under tests_tasks/, name of its module-level command dict).
TESTS_FILES = {
    "warmup": ("test_warmup_tasks.py", "WARMUP_COMMANDS"),
    "beginner": ("test_beginner_tasks.py", "BEGINNER_COMMANDS"),
    "intermediate": ("test_intermediate_tasks.py", "INTERMEDIATE_COMMANDS"),
    "advanced": ("test_advanced_tasks.py", "ADVANCED_COMMANDS"),
    "expert": ("test_expert_tasks.py", "EXPERT_COMMANDS"),
}
SYSTEM_PROMPT = textwrap.dedent(
"""
You are an AWS cloud engineer interacting with a real AWS environment via CLI.
Each turn you must send exactly ONE valid AWS CLI command (starting with 'aws').
You will be given a task to accomplish. Read the task description carefully.
Use the command output and error messages to guide your next action.
Rules:
- Only send AWS CLI commands (e.g. 'aws s3 ls', 'aws dynamodb create-table ...')
- One command per turn — no pipes, no shell syntax, no chaining
- Reply with ONLY the command, nothing else — no explanations, no quotes
- If unsure, use 'aws help' to get unstuck, but try to be specific to the service if possible (e.g. 'aws s3 help')
- When ever you need a hint, use 'aws help --task-hint' to get a task-specific hint (you can use this multiple times for more hints, but hints reduce your reward)
"""
).strip()
# Plausible reset-state outputs the env might return. Weighted toward "" since
# that's the most likely real value; others add mild prompt variance so the
# SFT trainer sees diverse surface forms of the same logical state.
_INITIAL_OUTPUTS: list[tuple[str, float]] = [
("", 0.7),
("Environment reset. Infra state wiped.", 0.2),
("Environment ready.", 0.1),
]
def _sample_initial_output(rng: random.Random) -> str:
    """Draw one reset-state output string, weighted per _INITIAL_OUTPUTS."""
    outputs = [text for text, _ in _INITIAL_OUTPUTS]
    probs = [weight for _, weight in _INITIAL_OUTPUTS]
    return rng.choices(outputs, weights=probs, k=1)[0]
def load_tasks() -> dict[str, list[dict]]:
    """Load per-tier task lists from TASKS_DIR/<tier>.yaml.

    Missing files map to an empty list; entries lacking either 'task_id'
    or 'description' are dropped.
    """
    result: dict[str, list[dict]] = {}
    for tier in TIERS:
        tier_path = TASKS_DIR / f"{tier}.yaml"
        if not tier_path.exists():
            result[tier] = []
            continue
        with open(tier_path) as fh:
            loaded = yaml.safe_load(fh) or []
        result[tier] = [
            entry
            for entry in loaded
            if isinstance(entry, dict) and "task_id" in entry and "description" in entry
        ]
    return result
def load_canonical_commands() -> dict[int, list[str]]:
    """Parse command dicts from tests_tasks/*.py without executing the files.

    Tests import pytest and set up fixtures, so importlib.exec fails outside
    the venv. AST-based literal extraction avoids that: we find the
    module-level `X_COMMANDS = {...}` assignment and literal-eval its value.
    Entries with non-literal values (f-strings etc.) are skipped silently.

    Returns:
        Mapping of task_id -> list of canonical CLI commands. A plain string
        value in the source dict is wrapped into a one-element list.
    """
    out: dict[int, list[str]] = {}
    for _tier, (fname, var) in TESTS_FILES.items():
        path = TESTS_DIR / fname
        if not path.exists():
            continue
        try:
            tree = ast.parse(path.read_text())
        except SyntaxError:
            continue  # unparseable test file: skip this tier entirely
        for node in ast.iter_child_nodes(tree):
            # Accept both `X: T = {...}` (AnnAssign) and `X = {...}` (Assign).
            if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
                target_id, value = node.target.id, node.value
            elif isinstance(node, ast.Assign) and len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
                target_id, value = node.targets[0].id, node.value
            else:
                continue
            if target_id != var or value is None:
                continue
            try:
                d = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                continue  # value contains non-literals (f-strings etc.)
            if not isinstance(d, dict):
                continue
            for tid, cmd in d.items():
                if not isinstance(tid, int):
                    continue
                if isinstance(cmd, str):
                    out[tid] = [cmd]
                elif isinstance(cmd, (list, tuple)):
                    # Keep only string entries; drop the task if none survive.
                    seq = [c for c in cmd if isinstance(c, str)]
                    if seq:
                        out[tid] = seq
    return out
def task_has_dynamic_ids(cmd_seq: list[str]) -> bool:
    """Detect commands that reference runtime-resolved IDs (sg-, subnet-, etc.).

    Such commands cannot be replayed verbatim in a fresh environment, so
    their tasks are excluded from SFT sampling.
    """
    resource_id = re.compile(
        r"\b(sg|subnet|vpc|ami|rtb|eni|igw|nat|eip|snap|vol)-[a-f0-9]{8,}\b"
    )
    instance_id = re.compile(r"\bi-[a-f0-9]{8,}\b")
    return any(
        resource_id.search(cmd) or instance_id.search(cmd) for cmd in cmd_seq
    )
# Canned responses keyed by CLI operation (the third token of the command).
# "<RESOURCE>" is substituted with the value of the first resource flag found
# in the command (see simulate_output). Unknown operations map to "".
OP_OUTPUTS: dict[str, str] = {
    "ls": "",
    "list-buckets": '{"Buckets":[]}',
    "list-tables": '{"TableNames":[]}',
    "list-functions": '{"Functions":[]}',
    "list-queues": "{}",
    "list-topics": '{"Topics":[]}',
    "list-users": '{"Users":[]}',
    "list-secrets": '{"SecretList":[]}',
    "list-clusters": '{"clusterArns":[]}',
    "list-named-queries": '{"NamedQueryIds":[]}',
    "describe-instances": '{"Reservations":[]}',
    "describe-db-instances": '{"DBInstances":[]}',
    "describe-cache-clusters": '{"CacheClusters":[]}',
    "get-databases": '{"DatabaseList":[]}',
    "create-bucket": '{"Location":"/<RESOURCE>"}',
    "put-object": '{"ETag":"\\"d41d8cd98f00b204e9800998ecf8427e\\""}',
    "put-bucket-versioning": "",
    "put-bucket-policy": "",
    "create-table": '{"TableDescription":{"TableName":"<RESOURCE>","TableStatus":"ACTIVE"}}',
    "put-item": "{}",
    "create-topic": '{"TopicArn":"arn:aws:sns:us-east-1:000000000000:<RESOURCE>"}',
    "create-queue": '{"QueueUrl":"https://sqs.us-east-1.amazonaws.com/000000000000/<RESOURCE>"}',
    "subscribe": '{"SubscriptionArn":"arn:aws:sns:us-east-1:000000000000:<RESOURCE>:abc123"}',
    "create-role": '{"Role":{"RoleName":"<RESOURCE>","Arn":"arn:aws:iam::000000000000:role/<RESOURCE>"}}',
    "attach-role-policy": "",
    "create-function": '{"FunctionName":"<RESOURCE>","FunctionArn":"arn:aws:lambda:us-east-1:000000000000:function:<RESOURCE>"}',
    "create-policy": '{"Policy":{"PolicyName":"<RESOURCE>","Arn":"arn:aws:iam::000000000000:policy/<RESOURCE>"}}',
    "create-secret": '{"ARN":"arn:aws:secretsmanager:us-east-1:000000000000:secret:<RESOURCE>"}',
    "get-bucket-policy": '{"Policy":"{\\"Version\\":\\"2012-10-17\\",\\"Statement\\":[{\\"Effect\\":\\"Allow\\",\\"Principal\\":\\"*\\",\\"Action\\":\\"s3:*\\",\\"Resource\\":\\"*\\"}]}"}',
}

# Flags whose value names the command's primary resource; simulate_output
# scans the command's tokens left-to-right and uses the first of these it
# encounters.
_RESOURCE_FLAGS = (
    "--bucket",
    "--table-name",
    "--role-name",
    "--function-name",
    "--queue-name",
    "--topic-arn",
    "--policy-name",
    "--name",
    "--secret-id",
    "--cluster",
    "--key",
)
def simulate_output(command: str) -> str:
    """Fabricate a plausible CLI response for a canonical command.

    Looks up the operation (third token) in OP_OUTPUTS and substitutes the
    value of the first recognized resource flag for "<RESOURCE>". Returns ""
    for short commands and unknown operations.
    """
    tokens = command.split()
    if len(tokens) < 3:
        return ""
    resource = ""
    for idx, token in enumerate(tokens):
        if token in _RESOURCE_FLAGS and idx + 1 < len(tokens):
            resource = tokens[idx + 1]
            break
    template = OP_OUTPUTS.get(tokens[2], "")
    return template.replace("<RESOURCE>", resource)
def _mistake_wrong_operation(cmd: str) -> tuple[str, str] | None:
swaps = [
("list-tables", "ls"),
("list-buckets", "ls-all"),
("describe-instances", "list-instances"),
("list-functions", "list-lambdas"),
("create-bucket", "make-bucket"),
("put-item", "insert-item"),
("create-table", "make-table"),
("get-bucket-policy", "show-bucket-policy"),
("attach-role-policy", "attach-policy"),
("create-role", "new-role"),
]
for good, bad in swaps:
if good in cmd:
return cmd.replace(good, bad, 1), (
f"aws: error: argument operation: Invalid choice: '{bad}'"
)
return None
def _mistake_missing_arg(cmd: str) -> tuple[str, str] | None:
m = re.search(r" (--[a-z-]+) (\S+)", cmd)
if not m:
return None
flag, value = m.group(1), m.group(2)
wrong = cmd.replace(f"{flag} {value}", "", 1).rstrip()
return wrong, f"aws: error: the following arguments are required: {flag}"
def _mistake_wrong_service(cmd: str) -> tuple[str, str] | None:
swaps = [
("aws dynamodb", "aws dynamo"),
("aws secretsmanager", "aws secrets"),
("aws cloudformation", "aws cfn"),
("aws elasticache", "aws elastic"),
("aws apigateway", "aws apigw"),
]
for good, bad in swaps:
if cmd.startswith(good):
wrong = cmd.replace(good, bad, 1)
return wrong, (
f"aws: error: argument command: Invalid choice: '{bad.split()[-1]}'"
)
return None
def _mistake_s3_vs_s3api(cmd: str) -> tuple[str, str] | None:
api_ops = (
"create-bucket",
"put-object",
"put-bucket-versioning",
"put-bucket-policy",
"get-bucket-policy",
"head-bucket",
)
for op in api_ops:
if f"aws s3api {op}" in cmd:
wrong = cmd.replace("aws s3api", "aws s3", 1)
return wrong, (
f"aws: error: argument operation: Invalid choice: '{op}'"
)
return None
def _mistake_typo_in_resource(cmd: str) -> tuple[str, str] | None:
m = re.search(
r"--(bucket|table-name|role-name|function-name|queue-name|name)\s+(\S+)", cmd
)
if not m:
return None
key, val = m.group(1), m.group(2)
if len(val) < 3:
return None
typo = val[0] + val[2] + val[1] + val[3:]
wrong = cmd.replace(f"--{key} {val}", f"--{key} {typo}", 1)
return wrong, (
f"An error occurred (NoSuchEntity): The resource '{typo}' was not found"
)
# Mistake generators tried (in random order) by perturb_command. Each takes
# the correct command and returns (wrong_command, error_message), or None
# when it does not apply to that command.
_MISTAKES: list[Callable[[str], tuple[str, str] | None]] = [
    _mistake_wrong_operation,
    _mistake_missing_arg,
    _mistake_wrong_service,
    _mistake_s3_vs_s3api,
    _mistake_typo_in_resource,
]
def perturb_command(correct: str, rng: random.Random) -> tuple[str, str]:
    """Derive a (wrong_command, error_message) pair from a correct command.

    Tries the mistake generators in a random order and returns the first
    result that applies; falls back to appending an unknown option.
    """
    generators = list(_MISTAKES)
    rng.shuffle(generators)
    for make_mistake in generators:
        result = make_mistake(correct)
        if result is not None:
            return result
    return correct + " --foo bar", "aws: error: unknown option: --foo"
def _jitter_reward(base: float, rng: random.Random) -> float:
val = base + rng.uniform(-0.1, 0.1)
return max(0.0, min(1.0, round(val, 2)))
def _trim_history(history: list[str], rng: random.Random) -> list[str]:
if not history:
return history
options = [len(history), min(len(history), 4), min(len(history), 2)]
k = rng.choice(options)
return history[-k:]
def render_user_prompt(
    task_description: str,
    step: int,
    last_output: str,
    last_error: str,
    last_reward: float,
    history: list[str],
) -> str:
    """Render the per-step user message shown to the model.

    Bug fix: the previous textwrap.dedent(f-string) approach only strips the
    template's indentation when every line shares it — but the interpolated
    `history_block` lines start at column 0, which leaves dedent with no
    common prefix and therefore no stripping at all whenever history is
    non-empty. Building the message line-by-line sidesteps that interaction
    and always yields an unindented prompt.
    """
    history_block = "\n".join(history) if history else "None"
    lines = [
        f"TASK: {task_description}",
        f"Step: {step}",
        f"Last command output: {last_output!r}",
        f"Last error: {last_error!r}",
        f"Last reward: {last_reward:.2f}",
        "Previous steps:",
        history_block,
        "Send your next AWS CLI command.",
    ]
    return "\n".join(lines)
def make_row(
    task: dict,
    tier: str,
    source: str,
    step_idx: int,
    user_prompt: str,
    assistant_command: str,
) -> dict[str, Any]:
    """Assemble one SFT training row in chat-messages JSONL format."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_command},
    ]
    return {
        "task_id": task["task_id"],
        "difficulty": tier,
        "source": source,
        "step_idx": step_idx,
        "messages": messages,
    }
def produce_success_first_step(
    task: dict, tier: str, commands: dict[int, list[str]], rng: random.Random
) -> dict | None:
    """Row: the canonical first command at step 0 of a fresh environment."""
    seq = commands.get(task["task_id"])
    if not seq:
        return None
    prompt = render_user_prompt(
        task["description"],
        step=0,
        last_output=_sample_initial_output(rng),
        last_error="",
        last_reward=_jitter_reward(0.0, rng),
        history=[],
    )
    return make_row(task, tier, "success_first_step", 0, prompt, seq[0])
def produce_multi_step_continuation(
    task: dict, tier: str, commands: dict[int, list[str]], rng: random.Random
) -> dict | None:
    """Row: step N>0 of a multi-command task, with prior commands in history."""
    seq = commands.get(task["task_id"])
    if not seq or len(seq) < 2:
        return None
    idx = rng.randint(1, len(seq) - 1)
    done = seq[:idx]
    numbered = [f"{pos + 1}. {c}" for pos, c in enumerate(done)]
    trimmed = _trim_history(numbered, rng)
    # Reward grows linearly with progress through the command sequence.
    progress_reward = 0.2 + 0.6 * (idx / len(seq))
    prompt = render_user_prompt(
        task["description"],
        step=idx,
        last_output=simulate_output(done[-1]),
        last_error="",
        last_reward=_jitter_reward(progress_reward, rng),
        history=trimmed,
    )
    return make_row(task, tier, "multi_step_continuation", idx, prompt, seq[idx])
def produce_failure_recovery(
    task: dict, tier: str, commands: dict[int, list[str]], rng: random.Random
) -> dict | None:
    """Row: the model just issued a broken command; the target corrects it."""
    seq = commands.get(task["task_id"])
    if not seq:
        return None
    idx = rng.randint(0, len(seq) - 1)
    correct = seq[idx]
    wrong, err = perturb_command(correct, rng)
    numbered = [f"{pos + 1}. {c}" for pos, c in enumerate(seq[:idx])]
    numbered.append(f"{idx + 1}. {wrong}")
    trimmed = _trim_history(numbered, rng)
    step_now = idx + 1
    prompt = render_user_prompt(
        task["description"],
        step=step_now,
        last_output="",
        last_error=err,
        last_reward=_jitter_reward(0.0 if idx == 0 else 0.3, rng),
        history=trimmed,
    )
    return make_row(task, tier, "failure_recovery", step_now, prompt, correct)
# (command prefix, builder) pairs: given a completed task's final canonical
# command, the builder returns a read-only verification command, or None when
# the flag it needs is absent (then the next matching prefix is tried).
_VERIFY_MAP: list[tuple[str, Callable[[str], str | None]]] = [
    (
        "aws s3api create-bucket",
        lambda c: (
            f"aws s3api head-bucket{_flag_passthrough(c, '--bucket')}"
            if _flag_passthrough(c, "--bucket")
            else None
        ),
    ),
    (
        "aws s3api put-object",
        lambda c: (
            f"aws s3api list-objects-v2{_flag_passthrough(c, '--bucket')}"
            if _flag_passthrough(c, "--bucket")
            else None
        ),
    ),
    (
        "aws s3api put-bucket-versioning",
        lambda c: (
            f"aws s3api get-bucket-versioning{_flag_passthrough(c, '--bucket')}"
            if _flag_passthrough(c, "--bucket")
            else None
        ),
    ),
    (
        "aws s3api put-bucket-policy",
        lambda c: (
            f"aws s3api get-bucket-policy{_flag_passthrough(c, '--bucket')}"
            if _flag_passthrough(c, "--bucket")
            else None
        ),
    ),
    (
        "aws dynamodb create-table",
        lambda c: (
            f"aws dynamodb describe-table{_flag_passthrough(c, '--table-name')}"
            if _flag_passthrough(c, "--table-name")
            else None
        ),
    ),
    (
        "aws dynamodb put-item",
        lambda c: (
            f"aws dynamodb scan{_flag_passthrough(c, '--table-name')}"
            if _flag_passthrough(c, "--table-name")
            else None
        ),
    ),
    (
        "aws iam create-role",
        lambda c: (
            f"aws iam get-role{_flag_passthrough(c, '--role-name')}"
            if _flag_passthrough(c, "--role-name")
            else None
        ),
    ),
    (
        "aws iam attach-role-policy",
        lambda c: (
            f"aws iam list-attached-role-policies{_flag_passthrough(c, '--role-name')}"
            if _flag_passthrough(c, "--role-name")
            else None
        ),
    ),
    (
        "aws iam create-policy",
        # Policy names aren't passed through; listing local policies suffices.
        lambda c: (
            f"aws iam list-policies --scope Local"
            if "--policy-name" in c
            else None
        ),
    ),
    (
        "aws lambda create-function",
        lambda c: (
            f"aws lambda get-function{_flag_passthrough(c, '--function-name')}"
            if _flag_passthrough(c, "--function-name")
            else None
        ),
    ),
    # Topic/queue creation has no per-resource read; list commands suffice.
    ("aws sns create-topic", lambda c: "aws sns list-topics"),
    ("aws sqs create-queue", lambda c: "aws sqs list-queues"),
    (
        "aws secretsmanager create-secret",
        # create-secret takes --name but describe-secret expects --secret-id.
        lambda c: (
            f"aws secretsmanager describe-secret{_flag_passthrough(c, '--name', flag_out='--secret-id')}"
            if _flag_passthrough(c, "--name")
            else None
        ),
    ),
]
def _flag_passthrough(cmd: str, flag: str, flag_out: str | None = None) -> str:
m = re.search(rf"{re.escape(flag)}\s+(\S+)", cmd)
if not m:
return ""
out_flag = flag_out or flag
return f" {out_flag} {m.group(1)}"
def produce_verification(
    task: dict, tier: str, commands: dict[int, list[str]], rng: random.Random
) -> dict | None:
    """Row: task just completed; the target is a read-only verification check."""
    seq = commands.get(task["task_id"])
    if not seq or len(seq) < 2:
        return None
    final_cmd = seq[-1]
    verify_cmd: str | None = None
    for prefix, build in _VERIFY_MAP:
        if final_cmd.startswith(prefix):
            candidate = build(final_cmd)
            if candidate:
                verify_cmd = candidate
                break
    if not verify_cmd:
        return None
    numbered = [f"{pos + 1}. {c}" for pos, c in enumerate(seq)]
    trimmed = _trim_history(numbered, rng)
    step_now = len(seq)
    prompt = render_user_prompt(
        task["description"],
        step=step_now,
        last_output=simulate_output(final_cmd),
        last_error="",
        last_reward=_jitter_reward(0.85, rng),
        history=trimmed,
    )
    return make_row(task, tier, "verification", step_now, prompt, verify_cmd)
def produce_hint_usage(
    task: dict, tier: str, commands: dict[int, list[str]], rng: random.Random
) -> dict | None:
    """Row: step 0 where the assistant asks for a task-specific hint."""
    prompt = render_user_prompt(
        task["description"],
        step=0,
        last_output=_sample_initial_output(rng),
        last_error="",
        last_reward=_jitter_reward(0.0, rng),
        history=[],
    )
    return make_row(task, tier, "hint_usage", 0, prompt, "aws help --task-hint")
# source name -> (producer function, tiers eligible for that source).
# NOTE(review): multi_step/verification/hint sources draw only from
# intermediate/advanced — presumably because lower tiers are single-command
# tasks; confirm against the tasks YAML.
PRODUCERS: dict[
    str, tuple[Callable[[dict, str, dict, random.Random], dict | None], list[str]]
] = {
    "success_first_step": (
        produce_success_first_step,
        ["warmup", "beginner", "intermediate", "advanced"],
    ),
    "multi_step_continuation": (
        produce_multi_step_continuation,
        ["intermediate", "advanced"],
    ),
    "failure_recovery": (
        produce_failure_recovery,
        ["warmup", "beginner", "intermediate", "advanced"],
    ),
    "verification": (produce_verification, ["intermediate", "advanced"]),
    "hint_usage": (produce_hint_usage, ["intermediate", "advanced"]),
}
def pick_tier(eligible_tiers: list[str], rng: random.Random) -> str | None:
    """Weighted tier draw; uniform fallback when all weights are zero."""
    weights = [TIER_WEIGHTS[t] for t in eligible_tiers]
    if sum(weights) != 0:
        return rng.choices(eligible_tiers, weights=weights, k=1)[0]
    if not eligible_tiers:
        return None
    return rng.choice(eligible_tiers)
def build_dataset(
    n: int,
    tasks: dict[str, list[dict]],
    commands: dict[int, list[str]],
    rng: random.Random,
) -> list[dict]:
    """Build `n` SFT rows honoring SOURCE_MIX and TIER_WEIGHTS.

    Rows are deduplicated on (user_text, assistant_text). Each source uses a
    bounded rejection-sampling loop; if the attempt budget runs out before
    the quota is met, a WARN is printed and the shortfall stands.
    """
    # Integer quotas per source; rounding drift is absorbed by the largest
    # bucket (success_first_step) so quotas still sum to n.
    counts = {src: round(n * pct) for src, pct in SOURCE_MIX.items()}
    diff = n - sum(counts.values())
    counts["success_first_step"] += diff
    rows: list[dict] = []
    seen: set[tuple[str, str]] = set()  # (user_text, assistant_text) dedupe keys
    for source, want in counts.items():
        producer, eligible_tiers = PRODUCERS[source]
        made = 0
        attempts = 0
        max_attempts = max(want * 20, 100)  # generous retry budget before giving up
        while made < want and attempts < max_attempts:
            attempts += 1
            tier = pick_tier(eligible_tiers, rng)
            if tier is None:
                break
            # Only tasks with canonical commands that replay deterministically
            # (no runtime-resolved AWS resource IDs).
            tier_tasks = [
                t
                for t in tasks.get(tier, [])
                if t["task_id"] in commands
                and not task_has_dynamic_ids(commands[t["task_id"]])
            ]
            if not tier_tasks:
                continue
            task = rng.choice(tier_tasks)
            row = producer(task, tier, commands, rng)
            if row is None:
                continue
            user_text = row["messages"][1]["content"]
            asst_text = row["messages"][2]["content"]
            key = (user_text, asst_text)
            if key in seen:
                continue
            seen.add(key)
            rows.append(row)
            made += 1
        if made < want:
            print(
                f" WARN: source '{source}' only produced {made}/{want} unique rows "
                f"after {attempts} attempts"
            )
    rng.shuffle(rows)
    return rows
def compute_stats(rows: list[dict]) -> dict:
    """Summarize a split: totals, source/tier distributions, task coverage."""
    if not rows:
        return {"total": 0}
    total = len(rows)

    def pct(counter: Counter) -> dict:
        # Fraction of the split per key, rounded to 3 decimal places.
        return {k: round(v / total, 3) for k, v in counter.items()}

    source_counts = Counter(r["source"] for r in rows)
    tier_counts = Counter(r["difficulty"] for r in rows)
    task_counts = Counter(r["task_id"] for r in rows)
    return {
        "total": total,
        "by_source": dict(source_counts),
        "by_source_pct": pct(source_counts),
        "by_tier": dict(tier_counts),
        "by_tier_pct": pct(tier_counts),
        "unique_tasks": len(task_counts),
        "top_tasks": task_counts.most_common(10),
    }
def write_jsonl(rows: list[dict], path: Path) -> None:
    """Write rows to `path` as JSON Lines, creating parent dirs as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = (json.dumps(row, ensure_ascii=False) for row in rows)
    with open(path, "w") as fh:
        fh.writelines(line + "\n" for line in serialized)
def main() -> None:
    """CLI entry point: build train/val/reserve SFT splits plus a stats JSON."""
    ap = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    ap.add_argument("--train", type=int, default=1500)
    ap.add_argument("--val", type=int, default=150)
    ap.add_argument("--reserve", type=int, default=200)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--out-dir", type=Path, default=DEFAULT_OUT_DIR)
    ap.add_argument(
        "--repo-root",
        type=Path,
        default=None,
        help="Override auto-detected repo root (must contain server/services/tasks/)",
    )
    args = ap.parse_args()
    if args.repo_root is not None:
        # Re-point the module-level path constants at the user-supplied root.
        global REPO_ROOT, TASKS_DIR, TESTS_DIR
        REPO_ROOT = args.repo_root.resolve()
        TASKS_DIR = REPO_ROOT / "server" / "services" / "tasks"
        TESTS_DIR = REPO_ROOT / "tests_tasks"
    if not TASKS_DIR.is_dir():
        raise SystemExit(
            f"ERROR: task dir not found at {TASKS_DIR}\n"
            f" Auto-detected repo root: {REPO_ROOT}\n"
            f" Pass --repo-root <path-to-aws-rl-env> to override."
        )
    rng = random.Random(args.seed)
    tasks = load_tasks()
    commands = load_canonical_commands()
    # Per-tier coverage report: how many tasks survive the canonical-command
    # and dynamic-id filters that build_dataset() applies.
    task_count = sum(len(v) for v in tasks.values())
    print(f"Loaded {task_count} tasks across {len(tasks)} tiers:")
    for tier in TIERS:
        tier_tasks = tasks.get(tier, [])
        with_cmd = sum(1 for t in tier_tasks if t["task_id"] in commands)
        with_cmd_no_dyn = sum(
            1
            for t in tier_tasks
            if t["task_id"] in commands
            and not task_has_dynamic_ids(commands[t["task_id"]])
        )
        print(
            f" {tier:<13} {len(tier_tasks):3d} tasks "
            f"({with_cmd} w/ canonical cmds, {with_cmd_no_dyn} after dynamic-id filter)"
        )
    print(f"Canonical commands loaded for {len(commands)} task IDs\n")
    total = args.train + args.val + args.reserve
    print(f"Building {total} rows (train={args.train} val={args.val} reserve={args.reserve})")
    all_rows = build_dataset(total, tasks, commands, rng)
    if len(all_rows) < total:
        # Dedupe/attempt limits can undershoot the target; shrink the splits
        # proportionally so the train/val/reserve ratio is preserved.
        print(f"\n WARN: only built {len(all_rows)}/{total} unique rows")
        print(" Splits will be proportionally shrunk to preserve train/val/reserve ratio.\n")
        ratio = len(all_rows) / total
        train_n = int(args.train * ratio)
        val_n = int(args.val * ratio)
        reserve_n = len(all_rows) - train_n - val_n
    else:
        train_n, val_n, reserve_n = args.train, args.val, args.reserve
    train_rows = all_rows[:train_n]
    val_rows = all_rows[train_n : train_n + val_n]
    reserve_rows = all_rows[train_n + val_n : train_n + val_n + reserve_n]
    out_dir = args.out_dir
    write_jsonl(train_rows, out_dir / "aws_rl_sft.train.jsonl")
    write_jsonl(val_rows, out_dir / "aws_rl_sft.val.jsonl")
    write_jsonl(reserve_rows, out_dir / "aws_rl_sft.reserve.jsonl")
    stats = {
        "train": compute_stats(train_rows),
        "val": compute_stats(val_rows),
        "reserve": compute_stats(reserve_rows),
        "targets": {"source_mix": SOURCE_MIX, "tier_weights": TIER_WEIGHTS},
        "seed": args.seed,
    }
    with open(out_dir / "dataset_stats.json", "w") as f:
        json.dump(stats, f, indent=2)
    print(f"\nWrote:")
    print(f" {out_dir / 'aws_rl_sft.train.jsonl'} ({len(train_rows)} rows)")
    print(f" {out_dir / 'aws_rl_sft.val.jsonl'} ({len(val_rows)} rows)")
    print(f" {out_dir / 'aws_rl_sft.reserve.jsonl'} ({len(reserve_rows)} rows)")
    print(f" {out_dir / 'dataset_stats.json'}")
    print(f"\nTrain source mix: {stats['train'].get('by_source_pct', {})}")
    print(f"Train tier mix: {stats['train'].get('by_tier_pct', {})}")


if __name__ == "__main__":
    main()