| |
| |
| |
| |
| |
| #include "runner.h" |
|
|
| #include <cstdio> |
| #include <cstring> |
| #include <vector> |
| #include <cmath> |
|
|
// Widens a bfloat16 bit pattern to a float: the 16 input bits become the upper
// half of a 32-bit word (lower half zero) and that word's bytes are copied into
// a float, avoiding pointer-cast type punning.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float result;
    std::memcpy(&result, &widened, sizeof(widened));
    return result;
}
|
|
| int main() { |
| const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16"; |
| int tp_rank = 0, tp_size = 1; |
| if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v); |
| if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v); |
| bool is_master = tp_rank == 0; |
|
|
| Runner r; |
| if (!r.init(model_dir, tp_size, tp_rank, 94, 512)) return 1; |
| const int64_t V = r.cfg().vocab_size; |
|
|
| |
| std::vector<int32_t> prompt = {785, 6722, 315, 9625, 374}; |
| DeviceBuffer logits0; |
| r.prefill(prompt.data(), prompt.size(), logits0); |
| std::vector<uint16_t> h_last0(V); |
| if (is_master) ACL_CHECK(aclrtMemcpy(h_last0.data(), V*2, logits0.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| int next0 = 0; |
| if (is_master) { |
| float best = -1e30; for (int i = 0; i < V; i++) { float v = bf16_to_float(h_last0[i]); if (v > best) { best = v; next0 = i; } } |
| } |
| |
| int32_t token_seq[4]; |
| if (is_master) token_seq[0] = next0; |
|
|
| |
| std::vector<std::vector<uint16_t>> seq_logits(4); |
| for (int i = 0; i < 4; i++) seq_logits[i].resize(V); |
|
|
| |
| |
| std::vector<int32_t> seq_tokens = {next0, 100, 200, 300}; |
|
|
| for (int i = 0; i < 4; i++) { |
| DeviceBuffer out; |
| r.decode(seq_tokens[i], out); |
| if (is_master) ACL_CHECK(aclrtMemcpy(seq_logits[i].data(), V*2, out.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| } |
| int64_t past_after_seq = r.past_len(); |
|
|
| |
| r.reset_cache(); |
| DeviceBuffer logits_reprefill; |
| r.prefill(prompt.data(), prompt.size(), logits_reprefill); |
|
|
| DeviceBuffer batch_logits; |
| r.prefill(seq_tokens.data(), 4, batch_logits); |
| |
| |
| |
|
|
| std::vector<uint16_t> batch_last(V); |
| if (is_master) ACL_CHECK(aclrtMemcpy(batch_last.data(), V*2, batch_logits.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| if (is_master) { |
| printf("\n=== Batched vs Sequential Decode Correctness ===\n"); |
| double l2d=0, l2r=0, maxd=0; |
| for (int i = 0; i < V; i++) { |
| float a = bf16_to_float(batch_last[i]), b = bf16_to_float(seq_logits[3][i]); |
| l2d += (a-b)*(a-b); l2r += b*b; |
| if (std::abs(a-b) > maxd) maxd = std::abs(a-b); |
| } |
| double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); |
| printf("Last-position logits:\n"); |
| printf(" seq[3] argmax = "); { |
| int b = 0; float bv = bf16_to_float(seq_logits[3][0]); |
| for (int i = 1; i < V; i++) if (bf16_to_float(seq_logits[3][i]) > bv) { bv = bf16_to_float(seq_logits[3][i]); b = i; } |
| printf("%d (%.3f)\n", b, bv); |
| } |
| printf(" batch argmax = "); { |
| int b = 0; float bv = bf16_to_float(batch_last[0]); |
| for (int i = 1; i < V; i++) if (bf16_to_float(batch_last[i]) > bv) { bv = bf16_to_float(batch_last[i]); b = i; } |
| printf("%d (%.3f)\n", b, bv); |
| } |
| printf(" rel=%.4e max=%.4f\n", rel, maxd); |
| printf(" %s\n", rel < 5e-2 ? "PASS" : "FAIL (batch forward diverges from sequential)"); |
| printf("\nNote: current Runner.prefill() returns ONLY last-position logits. For PLD\n"); |
| printf("we need all-position logits: requires extending prefill to optionally output\n"); |
| printf("[S, V] logits tensor.\n"); |
| } |
| return 0; |
| } |
|
|