| |
| |
#include "model_config.h"
#include "safetensors_loader.h"

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>
|
|
| int main(int argc, char** argv) { |
| std::string dir = argc > 1 ? argv[1] |
| : "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16"; |
| int tp_size = argc > 2 ? std::atoi(argv[2]) : 16; |
| int tp_rank = argc > 3 ? std::atoi(argv[3]) : 0; |
|
|
| ModelConfig cfg; |
| if (!cfg.load_from_json(dir + "/config.json")) return 1; |
| cfg.compute_derived(tp_size, tp_rank); |
| printf("%s\n", cfg.describe().c_str()); |
|
|
| SafetensorsLoader loader; |
| if (!loader.open(dir)) return 1; |
|
|
| |
| int missing = 0, shape_mismatch = 0; |
| auto check_shape = [&](const std::string& name, const std::vector<int64_t>& expected) { |
| auto* m = loader.get(name); |
| if (!m) { |
| printf(" MISSING: %s\n", name.c_str()); |
| missing++; |
| return; |
| } |
| if (m->shape != expected) { |
| printf(" SHAPE MISMATCH: %s got=[", name.c_str()); |
| for (size_t i = 0; i < m->shape.size(); i++) printf("%s%ld", i ? "," : "", m->shape[i]); |
| printf("] want=["); |
| for (size_t i = 0; i < expected.size(); i++) printf("%s%ld", i ? "," : "", expected[i]); |
| printf("]\n"); |
| shape_mismatch++; |
| } |
| }; |
|
|
| |
| check_shape("model.embed_tokens.weight", {cfg.vocab_size, cfg.hidden_size}); |
| check_shape("lm_head.weight", {cfg.vocab_size, cfg.hidden_size}); |
| check_shape("model.norm.weight", {cfg.hidden_size}); |
|
|
| |
| int64_t q_full = cfg.num_attention_heads * cfg.head_dim; |
| int64_t kv_full = cfg.num_key_value_heads * cfg.head_dim; |
| for (int L = 0; L < cfg.num_hidden_layers; L++) { |
| auto base = "model.layers." + std::to_string(L); |
| check_shape(base + ".input_layernorm.weight", {cfg.hidden_size}); |
| check_shape(base + ".post_attention_layernorm.weight", {cfg.hidden_size}); |
| check_shape(base + ".self_attn.q_proj.weight", {q_full, cfg.hidden_size}); |
| check_shape(base + ".self_attn.k_proj.weight", {kv_full, cfg.hidden_size}); |
| check_shape(base + ".self_attn.v_proj.weight", {kv_full, cfg.hidden_size}); |
| check_shape(base + ".self_attn.o_proj.weight", {cfg.hidden_size, q_full}); |
| |
| check_shape(base + ".self_attn.q_norm.weight", {cfg.head_dim}); |
| check_shape(base + ".self_attn.k_norm.weight", {cfg.head_dim}); |
| |
| check_shape(base + ".mlp.gate.weight", {cfg.num_experts, cfg.hidden_size}); |
| |
| for (int e : {0, 1, 63, 127}) { |
| auto ebase = base + ".mlp.experts." + std::to_string(e); |
| check_shape(ebase + ".gate_proj.weight", {cfg.moe_intermediate_size, cfg.hidden_size}); |
| check_shape(ebase + ".up_proj.weight", {cfg.moe_intermediate_size, cfg.hidden_size}); |
| check_shape(ebase + ".down_proj.weight", {cfg.hidden_size, cfg.moe_intermediate_size}); |
| } |
| } |
|
|
| |
| int64_t attn_bytes_per_rank = 0; |
| attn_bytes_per_rank += cfg.q_dim_per_rank * cfg.hidden_size * 2; |
| attn_bytes_per_rank += cfg.kv_dim_per_rank * cfg.hidden_size * 2; |
| attn_bytes_per_rank += cfg.kv_dim_per_rank * cfg.hidden_size * 2; |
| attn_bytes_per_rank += cfg.hidden_size * cfg.q_dim_per_rank * 2; |
| attn_bytes_per_rank *= cfg.num_hidden_layers; |
|
|
| int64_t moe_bytes_per_rank = 0; |
| |
| moe_bytes_per_rank += 2 * cfg.num_experts * cfg.i_per_rank * cfg.hidden_size * 2; |
| |
| moe_bytes_per_rank += cfg.num_experts * cfg.hidden_size * cfg.i_per_rank * 2; |
| moe_bytes_per_rank *= cfg.num_hidden_layers; |
|
|
| int64_t embed_bytes = cfg.vocab_size * cfg.hidden_size * 2 * 2; |
| int64_t router_bytes = cfg.num_experts * cfg.hidden_size * 2 * cfg.num_hidden_layers; |
| int64_t norm_bytes = cfg.hidden_size * 2 * (2 * cfg.num_hidden_layers + 1); |
| int64_t total_per_rank = attn_bytes_per_rank + moe_bytes_per_rank + embed_bytes + router_bytes + norm_bytes; |
|
|
| printf("\nPer-rank weight memory estimate (BF16, TP=%d):\n", tp_size); |
| printf(" attention: %.2f GB\n", attn_bytes_per_rank / 1e9); |
| printf(" MoE exps: %.2f GB\n", moe_bytes_per_rank / 1e9); |
| printf(" embed+head: %.2f GB (replicated)\n", embed_bytes / 1e9); |
| printf(" router: %.2f MB (replicated)\n", router_bytes / 1e6); |
| printf(" norms: %.2f MB (replicated)\n", norm_bytes / 1e6); |
| printf(" TOTAL: %.2f GB\n", total_per_rank / 1e9); |
|
|
| int errors = missing + shape_mismatch; |
| printf("\nMissing: %d, Shape mismatch: %d\n", missing, shape_mismatch); |
| printf("%s\n", errors == 0 ? "=== test_model_config PASS ===" |
| : "=== test_model_config FAIL ==="); |
| return errors == 0 ? 0 : 1; |
| } |
|
|