// model_config.h — Qwen3 hparams loaded from HF config.json, plus TP-derived per-rank sizes.
/// Qwen3 model hyperparameters plus tensor-parallel (TP) per-rank sizes.
///
/// The struct is a passive aggregate: the first group of fields mirrors the
/// hparams read from config.json, the second holds the TP layout, and the
/// third holds the per-rank sizes derived from the first two. All fields
/// carry in-class defaults, so a default-constructed ModelConfig is fully
/// initialized.
struct ModelConfig {
    // ---- Raw hparams from config.json ----
    int64_t vocab_size = 0;
    int64_t hidden_size = 0;              // D
    int64_t intermediate_size = 0;        // dense FFN (not used for MoE layers; kept for completeness)
    int64_t moe_intermediate_size = 0;    // I per expert
    int64_t num_hidden_layers = 0;        // = 94 for Qwen3-235B
    int64_t num_attention_heads = 0;      // = 64
    int64_t num_key_value_heads = 0;      // = 4 (GQA)
    int64_t head_dim = 0;                 // = 128
    int64_t num_experts = 0;              // = 128
    int64_t num_experts_per_tok = 0;      // top_k = 8
    int64_t max_position_embeddings = 0;
    float rope_theta = 0.0f;
    float rms_norm_eps = 1e-6f;
    bool norm_topk_prob = true;
    bool tie_word_embeddings = false;
    int64_t bos_token_id = 0;
    int64_t eos_token_id = 0;

    // ---- TP configuration ----
    int tp_size = 1;
    int tp_rank = 0;

    // ---- Derived per-rank sizes ----
    // Attention Q: split along num_heads (head-parallel)
    //   n_heads_per_rank = num_attention_heads / tp_size
    //   q_dim_per_rank   = n_heads_per_rank * head_dim
    int64_t n_heads_per_rank = 0;
    int64_t q_dim_per_rank = 0;

    // Attention KV: GQA with num_kv_heads < tp_size needs special handling.
    // For Qwen3-235B: num_kv_heads = 4, tp_size = 16 → each KV head is replicated 4× across ranks.
    // Simple scheme: each rank computes ALL kv heads (small, 4 × 128 = 512 features)
    // then slices attention output for its own q heads.
    // Alternative: split KV heads if tp_size <= num_kv_heads.
    int64_t n_kv_heads_per_rank = 0;
    int64_t kv_dim_per_rank = 0;

    // MoE: intermediate dim split. Each rank holds 1/tp_size of experts' intermediate_size.
    //   i_per_rank = moe_intermediate_size / tp_size
    int64_t i_per_rank = 0;

    // Defined elsewhere. Presumably fills the raw hparams from the JSON file
    // at `path` and returns whether loading succeeded — confirm against the
    // implementation.
    bool load_from_json(const std::string& path);

    // Defined elsewhere. Presumably stores the TP layout and fills the
    // derived per-rank fields using the formulas noted above — confirm
    // against the implementation. Note that the parameter names shadow the
    // tp_size / tp_rank members.
    void compute_derived(int tp_size, int tp_rank);

    // Defined elsewhere. Returns a string describing this config; the exact
    // format is not visible from this header.
    std::string describe() const;
};