// llm_mutil_npu/tests/test_batch_correctness.cpp
// Author: xianglarry — commit 4b9fefd: initial C++ aclnn EAGER inference for
// Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU.
// test_batch_correctness.cpp — verify that forward with S>1 at past_len>0 produces the
// same logits at each position as sequential S=1 decodes.
//
// This is the foundation for speculative decoding / PLD: the main model must predict logits
// for each of K candidate positions in one batched forward pass matching sequential behavior.
#include "runner.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
// Widen a bfloat16 bit pattern to an IEEE-754 binary32 value.
// bf16 is exactly the top 16 bits of a float, so shift left by 16 and
// reinterpret the bits via memcpy (bit-cast without aliasing UB).
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float out = 0.0f;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
// Returns {argmax index, max value} over a bf16 logits vector.
// Precondition: logits is non-empty.
static std::pair<int, float> argmax_bf16(const std::vector<uint16_t>& logits) {
    int best_idx = 0;
    float best_val = bf16_to_float(logits[0]);
    for (size_t i = 1; i < logits.size(); i++) {
        const float v = bf16_to_float(logits[i]);
        if (v > best_val) { best_val = v; best_idx = (int)i; }
    }
    return {best_idx, best_val};
}
int main() {
    // Model path and tensor-parallel topology (rank/size come from env vars,
    // defaulting to a single-rank run).
    const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    int tp_rank = 0, tp_size = 1;
    if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v);
    if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v);
    const bool is_master = tp_rank == 0;

    Runner r;
    if (!r.init(model_dir, tp_size, tp_rank, 94, 512)) return 1;
    const int64_t V = r.cfg().vocab_size;

    // Prefix prompt; prefill leaves last-position logits in device memory.
    std::vector<int32_t> prompt = {785, 6722, 315, 9625, 374};
    DeviceBuffer logits0;
    r.prefill(prompt.data(), prompt.size(), logits0);
    std::vector<uint16_t> h_last0(V);
    if (is_master) ACL_CHECK(aclrtMemcpy(h_last0.data(), V*2, logits0.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));

    // Greedy next token from the prefix. Only the master rank computes it;
    // non-master ranks keep next0 == 0 and so feed a different first token to
    // decode. NOTE(review): for tp_size > 1 this makes ranks diverge — confirm
    // whether multi-rank runs need a real broadcast of next0 here.
    int next0 = 0;
    if (is_master) next0 = argmax_bf16(h_last0).first;

    // --- Path A: sequential S=1 decode × 4 times ---
    std::vector<std::vector<uint16_t>> seq_logits(4);
    for (int i = 0; i < 4; i++) seq_logits[i].resize(V);
    // Tokens after the first are fixed ids so the test stays deterministic.
    std::vector<int32_t> seq_tokens = {next0, 100, 200, 300};
    for (int i = 0; i < 4; i++) {
        DeviceBuffer out;
        r.decode(seq_tokens[i], out);
        if (is_master) ACL_CHECK(aclrtMemcpy(seq_logits[i].data(), V*2, out.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
    }

    // --- Path B: reset, re-prefill, then ONE batch forward with S=4 ---
    r.reset_cache();
    DeviceBuffer logits_reprefill;
    r.prefill(prompt.data(), prompt.size(), logits_reprefill);
    DeviceBuffer batch_logits;
    r.prefill(seq_tokens.data(), 4, batch_logits);
    // prefill returns logits for the LAST position only (S=4 gives [1, V], not
    // [4, V]), so we can only compare the final position against seq_logits[3].
    // For true PLD/speculative decoding, prefill must be extended to output
    // all-position [S, V] logits.
    std::vector<uint16_t> batch_last(V);
    if (is_master) ACL_CHECK(aclrtMemcpy(batch_last.data(), V*2, batch_logits.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));

    if (is_master) {
        printf("\n=== Batched vs Sequential Decode Correctness ===\n");
        // Relative L2 distance and max absolute element difference between the
        // batched last-position logits and the 4th sequential decode's logits.
        double l2d = 0, l2r = 0, maxd = 0;
        for (int i = 0; i < V; i++) {
            float a = bf16_to_float(batch_last[i]), b = bf16_to_float(seq_logits[3][i]);
            l2d += (a-b)*(a-b); l2r += b*b;
            if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
        }
        const double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
        printf("Last-position logits:\n");
        const std::pair<int, float> sm = argmax_bf16(seq_logits[3]);
        printf(" seq[3] argmax = ");
        printf("%d (%.3f)\n", sm.first, sm.second);
        const std::pair<int, float> bm = argmax_bf16(batch_last);
        printf(" batch argmax = ");
        printf("%d (%.3f)\n", bm.first, bm.second);
        printf(" rel=%.4e max=%.4f\n", rel, maxd);
        printf(" %s\n", rel < 5e-2 ? "PASS" : "FAIL (batch forward diverges from sequential)");
        printf("\nNote: current Runner.prefill() returns ONLY last-position logits. For PLD\n");
        printf("we need all-position logits: requires extending prefill to optionally output\n");
        printf("[S, V] logits tensor.\n");
    }
    return 0;
}