| """ |
| Build/inspect script for Component 4 model architecture. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
| from typing import Any, Dict |
|
|
| import yaml |
|
|
| |
# Put the directory one level above this script's folder at the front of
# sys.path (once), so the `src.*` package import that follows resolves when
# the script is executed directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from src.model_architecture.code_transformer import ( |
| CodeTransformerLM, |
| ModelConfig, |
| get_model_presets, |
| ) |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the build/inspect script.

    Returns:
        A namespace with two string attributes: ``config`` (path to the
        model YAML config) and ``save_summary`` (path of the summary JSON
        to write).
    """
    cli = argparse.ArgumentParser(description="Build and inspect Component 4 model.")

    # (flag, default value, help text) for every supported option.
    option_table = (
        (
            "--config",
            "configs/component4_model_config.yaml",
            "Path to model YAML config.",
        ),
        (
            "--save_summary",
            "artifacts/model/component4_model_summary.json",
            "Where to save model summary JSON.",
        ),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, default=default_value, help=help_text)

    return cli.parse_args()
|
|
|
|
def load_yaml(path: Path) -> Dict[str, Any]:
    """Read a YAML file and return its top-level mapping.

    Args:
        path: Location of the model config file.

    Returns:
        The parsed document, which must be a dict at the top level.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the parsed document is not a dict (this includes an
            empty file, which ``yaml.safe_load`` parses to ``None``).
    """
    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle)
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("Invalid YAML format in model config.")
    raise FileNotFoundError(f"Model config not found: {path}")
|
|
|
|
def build_config(cfg_data: Dict[str, Any]) -> ModelConfig:
    """Create a ``ModelConfig`` from parsed config data.

    Args:
        cfg_data: Parsed config. The optional ``"preset"`` key names an entry
            of ``get_model_presets()``; the optional ``"model"`` key holds a
            dict of keyword arguments for ``ModelConfig``.

    Returns:
        ``ModelConfig(**model)`` when no preset is selected; otherwise a
        ``ModelConfig`` built from the preset's fields with the ``"model"``
        entries taking precedence.

    Raises:
        ValueError: If ``"model"`` is not a dict, or the preset name is not
            among the available presets.
    """
    preset_name = cfg_data.get("preset")
    overrides = cfg_data.get("model", {})
    if not isinstance(overrides, dict):
        raise ValueError("Config key 'model' must be an object.")

    # No (or falsy) preset: the "model" section alone defines the config.
    if not preset_name:
        return ModelConfig(**overrides)

    available = get_model_presets()
    if preset_name not in available:
        raise ValueError(f"Unknown preset '{preset_name}'. Available: {list(available.keys())}")

    preset_cfg = available[preset_name]
    if preset_cfg is None:
        return ModelConfig(**overrides)

    # Fields copied from the preset before the overrides are applied.
    inherited_fields = (
        "vocab_size",
        "max_seq_len",
        "d_model",
        "n_layers",
        "n_heads",
        "d_ff",
        "dropout",
        "tie_embeddings",
        "gradient_checkpointing",
        "init_std",
        "rms_norm_eps",
    )
    preset_values = {field: getattr(preset_cfg, field) for field in inherited_fields}
    return ModelConfig(**{**preset_values, **overrides})
|
|
|
|
def main() -> None:
    """Build the model described by the YAML config and save its summary.

    Loads the file given by ``--config``, turns it into a ``ModelConfig`` via
    ``build_config``, instantiates ``CodeTransformerLM``, writes the result of
    ``model.summary()`` as JSON to ``--save_summary`` (creating parent
    directories as needed) and prints a short report to stdout.

    Raises:
        SystemExit: With exit code 1 if any step fails. The failure
            diagnostics are written to stderr so they do not mix with the
            normal report on stdout.
    """
    args = parse_args()
    try:
        cfg_data = load_yaml(Path(args.config))
        model_cfg = build_config(cfg_data)
        model = CodeTransformerLM(model_cfg)
        summary = model.summary()

        save_path = Path(args.save_summary)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with save_path.open("w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

        print("Component 4 model build completed.")
        print(f"Preset: {cfg_data.get('preset')}")
        print(f"Parameters: {summary['num_parameters']:,}")
        print(f"Saved summary: {save_path}")
    except Exception as exc:
        # Top-level CLI boundary: report the failure on stderr and exit with
        # a non-zero status instead of showing a traceback.
        print("Component 4 model build failed.", file=sys.stderr)
        print(f"What went wrong: {exc}", file=sys.stderr)
        print(
            "Fix suggestion: check config values (especially d_model and n_heads divisibility).",
            file=sys.stderr,
        )
        # Chain explicitly so the original error stays attached as __cause__.
        raise SystemExit(1) from exc
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|