| | |
| | """Multi-stage curriculum SFT for advancing the conjecture math model.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import gc |
| | import json |
| | import os |
| | import subprocess |
| | import sys |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | import torch |
| | import yaml |
| | from datasets import Dataset, DatasetDict, load_dataset |
| | from huggingface_hub import HfApi |
| | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training |
| | from torch.utils.data import WeightedRandomSampler |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | BitsAndBytesConfig, |
| | DataCollatorForSeq2Seq, |
| | Trainer, |
| | TrainingArguments, |
| | set_seed, |
| | ) |
| |
|
| | SCRIPT_ROOT = Path(__file__).resolve().parents[1] |
| | DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml" |
| | DEFAULT_EVAL_SCRIPT = Path(__file__).resolve().with_name("eval_sota.py") |
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for the multi-stage curriculum trainer."""
    parser = argparse.ArgumentParser(
        description="Train DeepSeek-Math with a multi-stage SOTA curriculum recipe."
    )
    parser.add_argument(
        "--config",
        type=Path,
        default=DEFAULT_CONFIG_PATH,
        help="Path to multi-stage YAML config.",
    )
    parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
    # Paired on/off toggles; apply_overrides() rejects setting both of a pair.
    parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
    parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
    parser.add_argument(
        "--run-post-eval", action="store_true", help="Force post-training evaluation enabled."
    )
    parser.add_argument(
        "--no-post-eval", action="store_true", help="Force post-training evaluation disabled."
    )
    parser.add_argument(
        "--skip-quality-gate",
        action="store_true",
        help="Disable quality gate checks for this run.",
    )
    parser.add_argument(
        "--start-stage", type=int, default=1, help="1-based stage index to start from."
    )
    parser.add_argument(
        "--max-stages",
        type=int,
        default=None,
        help="Optional number of stages to run from --start-stage.",
    )
    parser.add_argument(
        "--credentials-path", type=Path, default=None, help="Override credentials.path."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Validate data/filter/tokenization stages without running training or pushing.",
    )
    return parser.parse_args()
| |
|
| |
|
def as_text(value: Any) -> str:
    """Coerce *value* to a stripped string; ``None`` becomes the empty string."""
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
| |
|
| |
|
def as_float(value: Any, default: float) -> float:
    """Best-effort float conversion; fall back to *default* on None or bad input."""
    if value is None:
        return default
    try:
        converted = float(value)
    except (TypeError, ValueError):
        return default
    return converted
| |
|
| |
|
def as_int(value: Any, default: int) -> int:
    """Best-effort int conversion; fall back to *default* on None or bad input."""
    if value is None:
        return default
    try:
        converted = int(value)
    except (TypeError, ValueError):
        return default
    return converted
| |
|
| |
|
def as_bool(value: Any, default: bool = False) -> bool:
    """Parse common truthy/falsy spellings; fall back to *default* otherwise."""
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    # Same normalization as as_text(): stringify, strip, lowercase.
    token = str(value).strip().lower()
    if token in {"1", "true", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "no", "n", "off"}:
        return False
    return default
| |
|
| |
|
def load_config(path: Path) -> Dict[str, Any]:
    """Load and minimally validate the multi-stage YAML config.

    Raises FileNotFoundError when the file is missing and ValueError when the
    YAML is not a mapping, a required section is absent, or stages[] is empty.
    Optional sections are defaulted to empty dicts so callers can index freely.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise ValueError(f"Invalid config format: {path}")
    for key in ("model", "data", "stages"):
        if key not in cfg:
            raise ValueError(f"Missing config section: {key}")
    stages = cfg["stages"]
    if not isinstance(stages, list) or not stages:
        raise ValueError("Config must contain at least one stage in stages[].")
    optional_sections = (
        "global",
        "training_defaults",
        "hub",
        "credentials",
        "post_eval",
        "quality_gate",
    )
    for section in optional_sections:
        cfg.setdefault(section, {})
    return cfg
| |
|
| |
|
def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
    """Fold CLI flags into *cfg* in place.

    Raises ValueError when a mutually exclusive flag pair is set together.
    Sections are only created lazily when a relevant flag was actually given.
    """
    if args.repo_id:
        cfg.setdefault("hub", {})["repo_id"] = args.repo_id
    if args.credentials_path is not None:
        cfg.setdefault("credentials", {})["path"] = str(args.credentials_path)

    if args.push_to_hub and args.no_push_to_hub:
        raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")
    if args.push_to_hub or args.no_push_to_hub:
        # Exactly one of the pair is set; True iff it was --push-to-hub.
        cfg.setdefault("hub", {})["push_to_hub"] = bool(args.push_to_hub)

    if args.run_post_eval and args.no_post_eval:
        raise ValueError("Cannot set both --run-post-eval and --no-post-eval.")
    if args.run_post_eval or args.no_post_eval:
        cfg.setdefault("post_eval", {})["enabled"] = bool(args.run_post_eval)

    if args.skip_quality_gate:
        cfg.setdefault("quality_gate", {})["enabled"] = False
| |
|
| |
|
def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Resolve an HF token and username.

    Environment variables win; a credentials JSON file (credentials.path) only
    fills in whichever of the two is still missing. Returns (token, username),
    either of which may be None.
    """
    env_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    token = as_text(env_token) or None
    username = as_text(os.environ.get("HF_USERNAME")) or None

    cred_path = as_text(cfg.get("credentials", {}).get("path"))
    if not cred_path:
        return token, username
    path = Path(cred_path)
    if not path.exists():
        return token, username

    data = json.loads(path.read_text(encoding="utf-8"))
    if token is None:
        for key in ("token", "key", "api_key", "hf_token"):
            candidate = as_text(data.get(key))
            if candidate:
                token = candidate
                break
    if username is None:
        for key in ("username", "user", "owner"):
            candidate = as_text(data.get(key))
            if candidate:
                username = candidate
                break
    return token, username
| |
|
| |
|
def resolve_repo_id(cfg: Dict[str, Any], username: Optional[str], output_root: Path) -> Optional[str]:
    """Pick hub.repo_id from config; else derive ``<username>/<output dir name>``.

    Returns None when neither a configured repo id nor a username is available.
    """
    raw = cfg.get("hub", {}).get("repo_id")
    configured = "" if raw is None else str(raw).strip()
    if configured:
        return configured
    return f"{username}/{output_root.name}" if username else None
| |
|
| |
|
def stringify_structured(value: Any) -> str:
    """Normalize structured values to canonical JSON text.

    None/blank → "". Strings that parse as JSON are re-serialized with sorted
    keys; other strings pass through stripped. Non-string values are JSON-dumped
    with sorted keys.
    """
    if value is None:
        return ""
    if not isinstance(value, str):
        return json.dumps(value, ensure_ascii=False, sort_keys=True)
    text = value.strip()
    if not text:
        return ""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Not JSON: keep the raw (stripped) string.
        return text
    return json.dumps(parsed, ensure_ascii=False, sort_keys=True)
| |
|
| |
|
def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Compose the user message: prompt text plus an optional metadata footer."""
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(prompt_field)) or "Solve the math task."

    labelled_fields = (
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    )
    lines = [
        f"{label}: {as_text(row.get(key))}"
        for key, label in labelled_fields
        if as_text(row.get(key))
    ]
    tags = row.get("topic_tags")
    if isinstance(tags, list) and tags:
        joined = ", ".join(text for text in (as_text(tag) for tag in tags) if text)
        if joined:
            lines.append(f"Tags: {joined}")
    if not lines:
        return prompt
    return f"{prompt}\n\nMetadata:\n" + "\n".join(lines)
| |
|
| |
|
def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Compose the supervised answer text from target / final-answer / proof fields."""
    field_specs = (
        (as_text(data_cfg.get("target_field")) or "target", "Structured target"),
        (as_text(data_cfg.get("final_answer_field")) or "final_answer", "Final answer"),
        (as_text(data_cfg.get("proof_field")) or "proof_formal", "Formal proof snippet"),
    )
    sections = []
    for field, heading in field_specs:
        text = stringify_structured(row.get(field))
        if text:
            sections.append(f"{heading}:\n{text}")
    if not sections:
        # Keep supervision non-empty even for rows without structured fields.
        sections.append("No structured target provided.")
    return "\n\n".join(sections).strip()
| |
|
| |
|
def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
    """Render the (system, user) prompt, preferring the tokenizer's chat template."""
    default_system = (
        "You are a rigorous mathematical reasoning assistant focused on unsolved "
        "conjectures. Produce checkable reasoning."
    )
    system_prompt = as_text(data_cfg.get("system_prompt")) or default_system
    user_block = build_user_block(row, data_cfg)
    if not getattr(tokenizer, "chat_template", None):
        # Plain-text fallback when the tokenizer ships no chat template.
        return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_block},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
| |
|
| |
|
def compute_loss_weight(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> float:
    """Per-sample loss weight: base sample weight times family boost, clamped."""
    weight_field = as_text(data_cfg.get("sample_weight_field")) or "sample_weight"
    weight = as_float(row.get(weight_field), 1.0)
    boosts = data_cfg.get("family_boost", {})
    if isinstance(boosts, dict):
        weight *= as_float(boosts.get(as_text(row.get("family"))), 1.0)
    low = as_float(data_cfg.get("min_loss_weight"), 0.1)
    high = as_float(data_cfg.get("max_loss_weight"), 8.0)
    if low > high:
        # Tolerate swapped bounds in the config.
        low, high = high, low
    return min(high, max(low, weight))
| |
|
| |
|
def stage_split_files(stage_cfg: Dict[str, Any], data_cfg: Dict[str, Any]) -> Dict[str, str]:
    """Resolve the train/validation parquet paths for a stage.

    Stage-level ``train_file``/``validation_file`` take precedence over the
    data-level defaults. Raises ValueError when neither level names a file —
    previously an unset path degenerated to ``Path("")``, which compares as
    ``"."`` and exists, so it silently slipped past the existence checks —
    and FileNotFoundError when a named file is absent on disk.
    """

    def _clean(value: Any) -> str:
        # Same coercion as as_text(): None -> "", else stringify and strip.
        return "" if value is None else str(value).strip()

    train_file = _clean(stage_cfg.get("train_file")) or _clean(data_cfg.get("default_train_file"))
    valid_file = _clean(stage_cfg.get("validation_file")) or _clean(data_cfg.get("default_validation_file"))
    if not train_file:
        raise ValueError("No train file configured for stage (train_file / data.default_train_file).")
    if not valid_file:
        raise ValueError("No validation file configured for stage (validation_file / data.default_validation_file).")
    train_path = Path(train_file)
    valid_path = Path(valid_file)
    if not train_path.exists():
        raise FileNotFoundError(f"Missing train split for stage: {train_path}")
    if not valid_path.exists():
        raise FileNotFoundError(f"Missing validation split for stage: {valid_path}")
    return {"train": str(train_path), "validation": str(valid_path)}
| |
|
| |
|
def apply_filters(dataset: Dataset, filter_cfg: Dict[str, Any]) -> Dataset:
    """Apply stage-level row filters (families, task types, sources, weights).

    An empty/missing filter config returns the dataset untouched. Each filter
    only applies when its key is configured with a non-empty value.
    """
    if not filter_cfg:
        return dataset

    def _as_set(key: str) -> set:
        return set(filter_cfg.get(key, []) or [])

    include_families = _as_set("include_families")
    exclude_families = _as_set("exclude_families")
    include_task_types = _as_set("include_task_types")
    source_datasets = _as_set("source_datasets")
    require_conjecture_id = bool(filter_cfg.get("require_conjecture_id", False))
    raw_min_weight = filter_cfg.get("min_sample_weight")
    min_sample_weight = None if raw_min_weight is None else as_float(raw_min_weight, 0.0)

    def _keep(row: Dict[str, Any]) -> bool:
        family = as_text(row.get("family"))
        if include_families and family not in include_families:
            return False
        if family in exclude_families:
            return False
        if include_task_types and as_text(row.get("task_type")) not in include_task_types:
            return False
        if source_datasets and as_text(row.get("source_dataset")) not in source_datasets:
            return False
        if require_conjecture_id:
            conjecture_id = as_text(row.get("conjecture_id"))
            # Treat the literal string "null" as missing too (parquet export artifact).
            if not conjecture_id or conjecture_id.lower() == "null":
                return False
        if min_sample_weight is not None and as_float(row.get("sample_weight"), 0.0) < min_sample_weight:
            return False
        return True

    return dataset.filter(_keep, desc="Applying stage filters")
| |
|
| |
|
def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
    """Optionally cap a dataset to its first *max_samples* rows.

    ``None`` means no cap; non-positive values raise ValueError.
    """
    if max_samples is None:
        return dataset
    if max_samples <= 0:
        raise ValueError("max_samples must be positive.")
    if len(dataset) <= max_samples:
        return dataset
    return dataset.select(range(max_samples))
| |
|
| |
|
def tokenize_datasets(raw: DatasetDict, tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> DatasetDict:
    """Tokenize prompt/answer pairs with the prompt span masked out of the loss.

    Each row becomes ``input_ids``/``attention_mask``/``labels``/``loss_weight``;
    prompt tokens are labelled -100 so only the answer is supervised. Rows whose
    labels end up fully masked are dropped. Raises ValueError for an implausibly
    small ``data.max_seq_length``.
    """
    max_len = as_int(data_cfg.get("max_seq_length"), 2048)
    if max_len < 64:
        raise ValueError("data.max_seq_length must be >= 64")
    eos = tokenizer.eos_token or ""
    remove_columns = raw["train"].column_names

    def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
        # Full sequence = rendered prompt + answer + EOS. The prompt is encoded
        # separately only to measure how many leading tokens to mask.
        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        answer_text = build_answer_block(row, data_cfg)
        full_text = f"{prompt_text}{answer_text}{eos}"
        prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
        full_enc = tokenizer(
            full_text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len,
        )
        input_ids = full_enc["input_ids"]
        attention_mask = full_enc["attention_mask"]
        if not input_ids:
            # Degenerate empty encoding: substitute a single EOS/pad/0 token so
            # the collator never sees an empty sequence.
            fallback = tokenizer.eos_token_id
            if fallback is None:
                fallback = tokenizer.pad_token_id
            if fallback is None:
                fallback = 0
            input_ids = [fallback]
            attention_mask = [1]
            labels = [fallback]
        else:
            # Mask the prompt prefix; supervise only the answer tokens.
            prompt_len = min(len(prompt_ids), len(input_ids))
            labels = [-100] * prompt_len + input_ids[prompt_len:]
            if prompt_len >= len(input_ids):
                # Truncation consumed the whole answer: keep at least one
                # supervised token so the row survives the filter below.
                labels[-1] = input_ids[-1]
        loss_weight = compute_loss_weight(row, data_cfg)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "loss_weight": float(loss_weight),
        }

    tokenized = raw.map(
        _tokenize,
        remove_columns=remove_columns,
        desc="Tokenizing prompt/answer pairs",
    )
    # Drop rows where every label is masked (prompt-only after truncation).
    tokenized = tokenized.filter(
        lambda row: any(token != -100 for token in row["labels"]),
        desc="Dropping prompt-only rows",
    )
    return tokenized
| |
|
| |
|
def build_tokenizer(model_cfg: Dict[str, Any]) -> AutoTokenizer:
    """Load the fast tokenizer for ``model.base_model``, ensuring a pad token."""
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        # Reuse EOS/UNK as pad when available; otherwise register a new token.
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    return tokenizer
| |
|
| |
|
def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict[str, Any]) -> Tuple[Any, AutoTokenizer]:
    """Load the base causal LM (optionally 4-bit quantized) and attach LoRA adapters.

    Returns ``(peft_model, tokenizer)``. CUDA availability decides both the
    compute dtype and whether 4-bit loading is honoured; without CUDA the model
    falls back to full-precision CPU. Raises ValueError when
    ``model.base_model`` is unset.
    """
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")

    use_cuda = torch.cuda.is_available()
    requested_bf16 = bool(model_cfg.get("use_bf16", True))
    # Half precision only on GPU; CPU fallback stays float32.
    if use_cuda:
        dtype = torch.bfloat16 if requested_bf16 else torch.float16
    else:
        dtype = torch.float32

    tokenizer = build_tokenizer(model_cfg)

    model_kwargs: Dict[str, Any] = {
        "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
        "torch_dtype": dtype,
    }
    attn_impl = as_text(model_cfg.get("attn_implementation"))
    if attn_impl:
        model_kwargs["attn_implementation"] = attn_impl

    # 4-bit quantization is only attempted on CUDA; otherwise downgrade loudly.
    requested_load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
    load_in_4bit = requested_load_in_4bit and use_cuda
    if requested_load_in_4bit and not load_in_4bit:
        print("CUDA unavailable. Disabling 4-bit loading and using full-precision CPU fallback.")
    if load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
            bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
            bnb_4bit_compute_dtype=dtype,
        )
        model_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    if tokenizer.pad_token_id is not None:
        # Keep the model config in sync with the tokenizer's pad token.
        model.config.pad_token_id = tokenizer.pad_token_id
    # KV-cache is useless during training and conflicts with checkpointing.
    model.config.use_cache = False

    if load_in_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=bool(training_defaults.get("gradient_checkpointing", True)),
        )

    lora_cfg = model_cfg.get("lora", {})
    peft_cfg = LoraConfig(
        r=as_int(lora_cfg.get("r"), 64),
        lora_alpha=as_int(lora_cfg.get("alpha"), 128),
        lora_dropout=as_float(lora_cfg.get("dropout"), 0.05),
        bias=as_text(lora_cfg.get("bias")) or "none",
        task_type="CAUSAL_LM",
        target_modules=lora_cfg.get("target_modules"),
    )
    model = get_peft_model(model, peft_cfg)
    model.print_trainable_parameters()
    return model, tokenizer
| |
|
| |
|
class WeightedLossCollator:
    """Seq2seq collator that carries per-sample ``loss_weight`` through batching."""

    def __init__(self, tokenizer: AutoTokenizer, model: Any) -> None:
        # Delegate padding of input_ids/attention_mask/labels to the HF collator.
        self.base = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            label_pad_token_id=-100,
            pad_to_multiple_of=8,
        )

    def __call__(self, features: list[Dict[str, Any]]) -> Dict[str, Any]:
        # Pop weights first: the base collator cannot pad a scalar column.
        weights = [float(item.pop("loss_weight", 1.0)) for item in features]
        batch = self.base(features)
        batch["loss_weight"] = torch.tensor(weights, dtype=torch.float32)
        return batch
| |
|
| |
|
class WeightedLossTrainer(Trainer):
    """Trainer that honours a per-sample ``loss_weight`` column.

    Sampling: rows are drawn with probability proportional to their weight
    (with replacement). Loss: per-sequence mean token loss, combined into a
    weighted batch average.
    """

    def _get_train_sampler(self):
        # Defer to the stock sampler when no usable weights are present.
        if self.train_dataset is None:
            return None
        if "loss_weight" not in self.train_dataset.column_names:
            return super()._get_train_sampler()
        weights = self.train_dataset["loss_weight"]
        if not weights:
            return super()._get_train_sampler()
        weight_tensor = torch.tensor(weights, dtype=torch.double)
        # Replacement sampling lets high-weight rows repeat within an epoch.
        return WeightedRandomSampler(
            weights=weight_tensor,
            num_samples=len(weight_tensor),
            replacement=True,
        )

    def compute_loss(
        self,
        model: Any,
        inputs: Dict[str, Any],
        return_outputs: bool = False,
        num_items_in_batch: Optional[torch.Tensor] = None,
    ):
        """Weighted causal-LM loss; falls back to the base loss without labels."""
        loss_weight = inputs.pop("loss_weight", None)
        labels = inputs.get("labels")
        if labels is None:
            return super().compute_loss(
                model=model,
                inputs=inputs,
                return_outputs=return_outputs,
                num_items_in_batch=num_items_in_batch,
            )

        # Run the model without labels so it does not compute its own loss.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.logits

        # Standard next-token shift: position t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        token_losses = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view(shift_labels.size())
        # Per-sequence mean over unmasked tokens; clamp avoids division by zero.
        token_mask = shift_labels.ne(-100).float()
        seq_den = token_mask.sum(dim=1).clamp(min=1.0)
        seq_loss = (token_losses * token_mask).sum(dim=1) / seq_den

        if loss_weight is not None:
            # Weighted batch average; clamp guards tiny/zero weights.
            normalized = loss_weight.to(seq_loss.device).float().clamp(min=0.05)
            loss = (seq_loss * normalized).sum() / normalized.sum()
        else:
            loss = seq_loss.mean()

        if return_outputs:
            return loss, outputs
        return loss
| |
|
| |
|
def build_training_args(
    output_dir: Path,
    training_cfg: Dict[str, Any],
    use_bf16: bool,
    has_eval_split: bool,
) -> TrainingArguments:
    """Translate a stage's training config into HF ``TrainingArguments``.

    bf16/fp16 are enabled only when CUDA is present (bf16 preferred); periodic
    evaluation is enabled only when a validation split exists. The output
    directory is created if needed.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    cuda = torch.cuda.is_available()
    run_bf16 = bool(cuda and use_bf16)
    run_fp16 = bool(cuda and not run_bf16)
    eval_strategy = "steps" if has_eval_split else "no"
    eval_steps = as_int(training_cfg.get("eval_steps"), 500) if has_eval_split else None
    return TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
        per_device_train_batch_size=as_int(training_cfg.get("per_device_train_batch_size"), 1),
        per_device_eval_batch_size=as_int(training_cfg.get("per_device_eval_batch_size"), 1),
        gradient_accumulation_steps=as_int(training_cfg.get("gradient_accumulation_steps"), 1),
        learning_rate=as_float(training_cfg.get("learning_rate"), 2e-5),
        weight_decay=as_float(training_cfg.get("weight_decay"), 0.0),
        warmup_ratio=as_float(training_cfg.get("warmup_ratio"), 0.0),
        lr_scheduler_type=as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
        max_grad_norm=as_float(training_cfg.get("max_grad_norm"), 1.0),
        gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
        logging_steps=as_int(training_cfg.get("logging_steps"), 10),
        save_steps=as_int(training_cfg.get("save_steps"), 500),
        save_total_limit=as_int(training_cfg.get("save_total_limit"), 3),
        dataloader_num_workers=as_int(training_cfg.get("dataloader_num_workers"), 0),
        seed=as_int(training_cfg.get("seed"), 17),
        bf16=run_bf16,
        fp16=run_fp16,
        # Keep loss_weight (and friends) alive through the dataloader.
        remove_unused_columns=False,
        report_to="none",
        evaluation_strategy=eval_strategy,
        eval_steps=eval_steps,
    )
| |
|
| |
|
def push_folder(
    api: HfApi,
    repo_id: str,
    folder_path: Path,
    commit_message: str,
    path_in_repo: Optional[str] = None,
) -> None:
    """Upload *folder_path* to the model repo, optionally under a sub-path."""
    upload_kwargs: Dict[str, Any] = {
        "repo_id": repo_id,
        "repo_type": "model",
        "folder_path": str(folder_path),
        "commit_message": commit_message,
    }
    if path_in_repo:
        upload_kwargs["path_in_repo"] = path_in_repo
    api.upload_folder(**upload_kwargs)
| |
|
| |
|
def extract_final_eval_loss(stage_reports: List[Dict[str, Any]]) -> Optional[float]:
    """Return the eval_loss from the latest stage report that carries one.

    Reports are scanned newest-first; malformed or missing values are skipped.
    Returns None when no report yields a parsable loss.
    """
    for report in reversed(stage_reports):
        metrics = report.get("eval_metrics")
        if not isinstance(metrics, dict):
            continue
        raw = metrics.get("eval_loss")
        if raw is None:
            continue
        try:
            return float(raw)
        except (TypeError, ValueError):
            continue
    return None
| |
|
| |
|
def release_model_memory(model: Any) -> None:
    """Best-effort teardown: move the model off-GPU, flush CUDA cache, run GC."""
    try:
        model.to("cpu")
    except Exception:
        # Best-effort only: quantized/offloaded models may refuse to move.
        pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
| |
|
| |
|
def run_post_eval(
    cfg: Dict[str, Any],
    config_path: Path,
    output_root: Path,
    final_adapter_dir: Path,
) -> Optional[Dict[str, Any]]:
    """Optionally launch the standalone eval script against the final adapter.

    Returns None when ``post_eval.enabled`` is false. Otherwise runs the eval
    script in a subprocess, loads its JSON report, and returns the report plus
    metadata. Raises on a missing script/eval file, a non-zero exit code, or a
    missing report file.
    """
    post_cfg = cfg.get("post_eval", {})
    if not as_bool(post_cfg.get("enabled"), False):
        return None

    eval_script = DEFAULT_EVAL_SCRIPT
    if not eval_script.exists():
        raise FileNotFoundError(f"Post-eval enabled but eval script is missing: {eval_script}")

    data_cfg = cfg.get("data", {})
    eval_file = Path(
        as_text(post_cfg.get("eval_file"))
        or as_text(data_cfg.get("default_validation_file"))
        or "data/releases/v1/test.parquet"
    )
    if not eval_file.exists():
        raise FileNotFoundError(f"Post-eval file not found: {eval_file}")

    output_json = Path(as_text(post_cfg.get("output_json")) or str(output_root / "post_eval_report.json"))
    base_model = as_text(cfg.get("model", {}).get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required for post-eval.")

    seed_default = as_int(cfg.get("global", {}).get("seed"), 17)
    option_pairs = [
        ("--config", str(config_path)),
        ("--base-model", base_model),
        ("--adapter-path", str(final_adapter_dir)),
        ("--eval-file", str(eval_file)),
        ("--max-samples", str(as_int(post_cfg.get("max_samples"), 300))),
        ("--k", str(as_int(post_cfg.get("k"), 4))),
        ("--max-new-tokens", str(as_int(post_cfg.get("max_new_tokens"), 256))),
        ("--temperature", str(as_float(post_cfg.get("temperature"), 0.7))),
        ("--top-p", str(as_float(post_cfg.get("top_p"), 0.95))),
        ("--seed", str(as_int(post_cfg.get("seed"), seed_default))),
        ("--output-json", str(output_json)),
    ]
    cmd = [sys.executable, str(eval_script)]
    for flag, value in option_pairs:
        cmd.extend((flag, value))

    print(f"Running post-training eval: {' '.join(cmd)}")
    # argv list (shell=False); check manually so the error message is ours.
    completed = subprocess.run(cmd, check=False)
    if completed.returncode != 0:
        raise RuntimeError(f"Post-training evaluation failed with exit code {completed.returncode}.")

    if not output_json.exists():
        raise FileNotFoundError(f"Post-eval report was not created: {output_json}")

    report = json.loads(output_json.read_text(encoding="utf-8"))
    return {
        "enabled": True,
        "report_path": str(output_json),
        "report": report,
        "command": cmd,
    }
| |
|
| |
|
def evaluate_quality_gate(
    stage_reports: List[Dict[str, Any]],
    post_eval_result: Optional[Dict[str, Any]],
    gate_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Check training/eval metrics against configured thresholds.

    Returns ``{"enabled", "passed", "violations", "checks"}``. A disabled gate
    passes trivially. The eval-loss check runs from stage reports; all
    post-eval-derived checks (row count, pass@1, pass@k, per-family pass@k)
    run only when a post-eval report dict is available.
    """
    enabled = as_bool(gate_cfg.get("enabled"), False)
    result: Dict[str, Any] = {
        "enabled": enabled,
        "passed": True,
        "violations": [],
        "checks": [],
    }
    if not enabled:
        return result

    violations: List[str] = []
    checks: List[Dict[str, Any]] = []

    def _record(name: str, actual: Any, threshold: Any) -> None:
        # Every gate comparison is logged, whether it passes or fails.
        checks.append({"name": name, "actual": actual, "threshold": threshold})

    final_eval_loss = extract_final_eval_loss(stage_reports)
    max_final_eval_loss = gate_cfg.get("max_final_eval_loss")
    if max_final_eval_loss is not None:
        threshold = as_float(max_final_eval_loss, 0.0)
        _record("max_final_eval_loss", final_eval_loss, threshold)
        if final_eval_loss is None:
            violations.append("Final stage eval_loss is missing for max_final_eval_loss gate.")
        elif final_eval_loss > threshold:
            violations.append(
                f"Final eval_loss {final_eval_loss:.4f} exceeds threshold {threshold:.4f}."
            )

    report: Optional[Dict[str, Any]] = None
    if isinstance(post_eval_result, dict):
        candidate = post_eval_result.get("report")
        if isinstance(candidate, dict):
            report = candidate

    require_post_eval = as_bool(gate_cfg.get("require_post_eval"), False)
    if report is None:
        if require_post_eval:
            violations.append("Quality gate requires post-eval metrics, but post-eval report is missing.")
    else:
        evaluated_rows = as_int(report.get("evaluated_rows"), 0)
        min_rows = as_int(gate_cfg.get("min_evaluated_rows"), 0)
        _record("min_evaluated_rows", evaluated_rows, min_rows)
        if evaluated_rows < min_rows:
            violations.append(
                f"Post-eval rows {evaluated_rows} is below minimum required {min_rows}."
            )

        min_pass_at_1_raw = gate_cfg.get("min_pass_at_1")
        if min_pass_at_1_raw is not None:
            min_pass_at_1 = as_float(min_pass_at_1_raw, 0.0)
            pass_at_1 = as_float(report.get("pass_at_1"), 0.0)
            _record("min_pass_at_1", pass_at_1, min_pass_at_1)
            if pass_at_1 < min_pass_at_1:
                violations.append(
                    f"pass@1 {pass_at_1:.4f} is below threshold {min_pass_at_1:.4f}."
                )

        min_pass_at_k_raw = gate_cfg.get("min_pass_at_k")
        if min_pass_at_k_raw is not None:
            min_pass_at_k = as_float(min_pass_at_k_raw, 0.0)
            pass_at_k = as_float(report.get("pass_at_k"), 0.0)
            _record("min_pass_at_k", pass_at_k, min_pass_at_k)
            if pass_at_k < min_pass_at_k:
                violations.append(
                    f"pass@k {pass_at_k:.4f} is below threshold {min_pass_at_k:.4f}."
                )

        family_requirements = gate_cfg.get("required_family_pass_at_k", {})
        family_metrics = report.get("family_metrics", {})
        if isinstance(family_requirements, dict):
            for family, threshold_raw in family_requirements.items():
                threshold = as_float(threshold_raw, 0.0)
                actual = None
                if isinstance(family_metrics, dict):
                    family_row = family_metrics.get(family)
                    if isinstance(family_row, dict):
                        try:
                            actual = float(family_row.get("pass_at_k"))
                        except (TypeError, ValueError):
                            actual = None
                _record(f"family_pass_at_k:{family}", actual, threshold)
                if actual is None:
                    violations.append(f"Missing pass@k metric for required family '{family}'.")
                elif actual < threshold:
                    violations.append(
                        f"Family '{family}' pass@k {actual:.4f} is below threshold {threshold:.4f}."
                    )

    result["violations"] = violations
    result["checks"] = checks
    result["passed"] = not violations
    return result
| |
|
| |
|
def main() -> None:
    """Drive the full multi-stage curriculum SFT run described by the YAML config.

    High-level flow:
      1. Parse CLI flags and fold them into the loaded config.
      2. Seed RNGs and prepare the output directory tree.
      3. Resolve Hub auth/repo and decide whether a push is requested.
      4. Build the model + tokenizer (tokenizer only in --dry-run mode).
      5. Train each configured stage sequentially, filtering/sub-sampling
         the stage's parquet splits and saving a per-stage checkpoint.
      6. Save the final adapter, run post-training evaluation, evaluate the
         quality gate, write a JSON summary, and push artifacts to the Hub
         when requested and permitted by the gate.

    In --dry-run mode only data loading and prompt construction are
    exercised per stage; no training, saving of weights, or pushing occurs.
    """
    args = parse_args()
    cfg = load_config(args.config)
    apply_overrides(cfg, args)

    # Global seed (default 17) applied via transformers.set_seed.
    seed = as_int(cfg.get("global", {}).get("seed"), 17)
    set_seed(seed)

    output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "runs/math-conjecture-sota")
    output_root.mkdir(parents=True, exist_ok=True)

    # Resolve Hub credentials and target repo up front; the push decision
    # can still be vetoed later by the quality gate.
    token, username = resolve_auth(cfg)
    repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
    push_to_hub_requested = bool(cfg.get("hub", {}).get("push_to_hub", False))
    if args.dry_run and push_to_hub_requested:
        print("Dry-run enabled. Disabling push_to_hub for this run.")
    # A dry run never pushes, regardless of what the config asks for.
    push_to_hub_requested = push_to_hub_requested and not args.dry_run

    # Fail fast before any expensive work if a push was requested but the
    # required credentials/repo are missing.
    if push_to_hub_requested:
        if token is None:
            raise ValueError("Hub push requested but token is missing.")
        if repo_id is None:
            raise ValueError("Hub push requested but repo_id is missing.")

    # Dry runs skip model construction entirely; only the tokenizer is
    # needed to exercise prompt building below.
    if args.dry_run:
        tokenizer = build_tokenizer(cfg["model"])
        model = None
    else:
        model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
        if torch.cuda.is_available():
            print("Compute mode: GPU")
        else:
            print("Compute mode: CPU fallback (no CUDA detected)")

    data_cfg = cfg["data"]
    stage_reports: List[Dict[str, Any]] = []

    # Compute the inclusive 1-based stage window [start_stage, end_stage].
    start_stage = max(1, args.start_stage)
    stages = cfg["stages"]
    end_stage = len(stages)
    if args.max_stages is not None:
        if args.max_stages <= 0:
            raise ValueError("--max-stages must be positive.")
        end_stage = min(end_stage, start_stage + args.max_stages - 1)

    for index in range(start_stage, end_stage + 1):
        # `stages` is 0-based while stage indices are 1-based for humans.
        stage = stages[index - 1]
        stage_name = as_text(stage.get("name")) or f"stage_{index:02d}"
        stage_slug = f"{index:02d}_{stage_name.replace(' ', '_')}"
        stage_output_dir = output_root / stage_slug
        print(f"[stage {index}] Starting: {stage_name}")

        # Load this stage's train/validation parquet splits.
        split_files = stage_split_files(stage, data_cfg)
        raw = load_dataset("parquet", data_files=split_files)
        train_rows_before = len(raw["train"])
        valid_rows_before = len(raw["validation"])

        # Apply the stage's configured row filters to both splits.
        filters = stage.get("filters", {})
        raw["train"] = apply_filters(raw["train"], filters)
        raw["validation"] = apply_filters(raw["validation"], filters)
        train_rows_after_filter = len(raw["train"])
        valid_rows_after_filter = len(raw["validation"])

        # Optionally cap the number of rows per split for this stage.
        raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
        raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
        train_rows_selected = len(raw["train"])
        valid_rows_selected = len(raw["validation"])

        # Report the row-count funnel: raw -> filtered -> selected.
        print(
            f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
            f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
        )
        if len(raw["train"]) == 0:
            raise ValueError(f"Stage {stage_slug} has zero train rows after filtering.")

        # Dry run: verify prompt/answer construction on one row, record
        # the row-count funnel, and skip training for this stage.
        if args.dry_run:
            sample_row = raw["train"][0]
            _ = build_prompt_text(sample_row, tokenizer, data_cfg)
            _ = build_answer_block(sample_row, data_cfg)
            stage_reports.append(
                {
                    "stage_index": index,
                    "stage_name": stage_name,
                    "stage_slug": stage_slug,
                    "mode": "dry_run",
                    "train_rows_before_filter": train_rows_before,
                    "validation_rows_before_filter": valid_rows_before,
                    "train_rows_after_filter": train_rows_after_filter,
                    "validation_rows_after_filter": valid_rows_after_filter,
                    "train_rows_selected": train_rows_selected,
                    "validation_rows_selected": valid_rows_selected,
                }
            )
            print(f"[stage {index}] Dry-run checks passed.")
            continue

        tokenized = tokenize_datasets(raw, tokenizer, data_cfg)
        train_dataset = tokenized["train"]
        # An empty validation split disables evaluation for this stage.
        eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None

        # Stage-level training settings override the shared defaults; the
        # global seed always wins so every stage is reproducible.
        merged_training = dict(cfg.get("training_defaults", {}))
        merged_training.update(stage.get("training", {}))
        merged_training["seed"] = seed
        training_args = build_training_args(
            output_dir=stage_output_dir,
            training_cfg=merged_training,
            use_bf16=bool(cfg["model"].get("use_bf16", True)),
            has_eval_split=eval_dataset is not None,
        )
        # WeightedLossCollator/WeightedLossTrainer are project-defined
        # subclasses (declared elsewhere in this file).
        collator = WeightedLossCollator(tokenizer=tokenizer, model=model)
        trainer = WeightedLossTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            data_collator=collator,
        )

        # Train, then persist metrics and trainer state alongside the run.
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

        eval_metrics = None
        if eval_dataset is not None:
            eval_metrics = trainer.evaluate()
            trainer.log_metrics("eval", eval_metrics)
            trainer.save_metrics("eval", eval_metrics)

        # Per-stage checkpoint: model weights + tokenizer for this stage.
        trainer.save_model(str(stage_output_dir))
        tokenizer.save_pretrained(str(stage_output_dir))

        stage_reports.append(
            {
                "stage_index": index,
                "stage_name": stage_name,
                "output_dir": str(stage_output_dir),
                "train_rows_before_filter": train_rows_before,
                "validation_rows_before_filter": valid_rows_before,
                "train_rows_after_filter": train_rows_after_filter,
                "validation_rows_after_filter": valid_rows_after_filter,
                "train_rows_selected": train_rows_selected,
                "validation_rows_selected": valid_rows_selected,
                "train_rows": len(train_dataset),
                "eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
                "train_metrics": train_result.metrics,
                "eval_metrics": eval_metrics,
            }
        )
        print(
            f"[stage {index}] Completed: train_rows={len(train_dataset)} "
            f"eval_rows={len(eval_dataset) if eval_dataset is not None else 0} output={stage_output_dir}"
        )

    # Dry run ends here: write the summary and exit without touching
    # weights, evaluation, the quality gate, or the Hub.
    if args.dry_run:
        summary = {
            "mode": "dry_run",
            "config_path": str(args.config),
            "seed": seed,
            "start_stage": start_stage,
            "end_stage": end_stage,
            "stages_ran": stage_reports,
        }
        summary_path = output_root / "dry_run_summary.json"
        summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
        print("Dry-run complete. No training or model push was performed.")
        print(f"Dry-run summary: {summary_path}")
        return

    # Persist the final adapter (state after the last trained stage).
    final_dir = output_root / "final_adapter"
    final_dir.mkdir(parents=True, exist_ok=True)
    assert model is not None
    model.save_pretrained(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))

    # Free model memory before the (potentially heavy) post-eval step.
    release_model_memory(model)
    del model

    # Optional post-training evaluation; may return None when disabled.
    post_eval_result = run_post_eval(
        cfg=cfg,
        config_path=args.config,
        output_root=output_root,
        final_adapter_dir=final_dir,
    )

    quality_gate = evaluate_quality_gate(
        stage_reports=stage_reports,
        post_eval_result=post_eval_result,
        gate_cfg=cfg.get("quality_gate", {}),
    )

    # A failed quality gate blocks the push even when it was requested;
    # a missing "passed" key is treated as passing.
    push_to_hub_performed = push_to_hub_requested
    push_block_reason: Optional[str] = None
    if push_to_hub_requested and not quality_gate.get("passed", True):
        push_to_hub_performed = False
        push_block_reason = "quality_gate_failed"
        print("Quality gate failed; skipping hub push for this run.")

    summary: Dict[str, Any] = {
        "config_path": str(args.config),
        "repo_id": repo_id,
        "seed": seed,
        "stages_ran": stage_reports,
        "final_adapter_dir": str(final_dir),
        "quality_gate": quality_gate,
        "push": {
            "requested": bool(push_to_hub_requested),
            "performed": bool(push_to_hub_performed),
            "block_reason": push_block_reason,
        },
    }

    # Fold the headline post-eval metrics into the summary when available.
    if post_eval_result is not None:
        report = post_eval_result.get("report", {})
        summary["post_eval"] = {
            "report_path": post_eval_result.get("report_path"),
            "evaluated_rows": report.get("evaluated_rows"),
            "k": report.get("k"),
            "pass_at_1": report.get("pass_at_1"),
            "pass_at_k": report.get("pass_at_k"),
            "exact_at_k": report.get("exact_at_k"),
            "composite_score": report.get("composite_score"),
        }

    summary_path = output_root / "training_summary.json"
    summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")

    # Hub upload: adapter folder, optional per-stage checkpoints, the
    # training summary, and (when produced) the post-eval report.
    if push_to_hub_performed and repo_id is not None and token is not None:
        api = HfApi(token=token)
        api.create_repo(
            repo_id=repo_id,
            repo_type="model",
            private=bool(cfg.get("hub", {}).get("private", False)),
            exist_ok=True,
        )
        commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
        push_folder(api, repo_id, final_dir, commit_message=commit_message)

        # Stage checkpoints are opt-in (hub.upload_stage_checkpoints);
        # dry-run stage reports have no output_dir and are skipped.
        if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
            for report in stage_reports:
                stage_dir_raw = report.get("output_dir")
                if not stage_dir_raw:
                    continue
                stage_dir = Path(stage_dir_raw)
                path_in_repo = f"checkpoints/{stage_dir.name}"
                push_folder(
                    api,
                    repo_id,
                    stage_dir,
                    commit_message=f"Upload stage checkpoint {report.get('stage_name', stage_dir.name)}",
                    path_in_repo=path_in_repo,
                )

        api.upload_file(
            path_or_fileobj=str(summary_path),
            path_in_repo="training_summary.json",
            repo_id=repo_id,
            repo_type="model",
            commit_message="Upload training summary for SOTA curriculum run.",
        )

        if post_eval_result is not None and post_eval_result.get("report_path"):
            api.upload_file(
                path_or_fileobj=str(post_eval_result["report_path"]),
                path_in_repo="post_eval_report.json",
                repo_id=repo_id,
                repo_type="model",
                commit_message="Upload post-training evaluation report.",
            )

        print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")

    print(f"Training complete. Final adapter: {final_dir}")
    print(f"Training summary: {summary_path}")
| |
|
| |
|
# Script entry point: run the full curriculum pipeline when executed directly.
if __name__ == "__main__":
    main()
| |
|