PeytonT's picture
Update AgentKernel Lite browser BitNet to v11 recommendation answer model
fe55445 verified
const PARAM_U32_COUNT = 12;
const PARAM_BUFFER_BYTES = PARAM_U32_COUNT * 4;
const shaderTextCache = new Map();
const pipelineCache = new WeakMap();
function align4(value) {
return (value + 3) & ~3;
}
function packedWeightToWords(packedWeight) {
const bytes = packedWeight instanceof Uint8Array ? packedWeight : new Uint8Array(packedWeight);
const padded = new Uint8Array(align4(bytes.byteLength));
padded.set(bytes);
return new Uint32Array(padded.buffer);
}
function createStorageBuffer(device, data, usage = GPUBufferUsage.STORAGE) {
const source = ArrayBuffer.isView(data) ? data : new Uint8Array(data);
const buffer = device.createBuffer({
size: align4(source.byteLength),
usage: usage | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(buffer, 0, source.buffer, source.byteOffset, source.byteLength);
return buffer;
}
function createOutputBuffer(device, byteLength) {
return device.createBuffer({
size: align4(byteLength),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
}
function createReadbackBuffer(device, byteLength) {
return device.createBuffer({
size: align4(byteLength),
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
});
}
function normalizeLayout(layoutHeader) {
if (!layoutHeader || layoutHeader.length < 13) {
throw new Error("BitNet layout_header must contain at least 13 entries");
}
const header = Array.from(layoutHeader, Number);
if (header[0] !== 1 || header[1] !== 16 || header[2] !== 32 || header[9] !== 1) {
throw new Error("Unsupported BitNet browser layout; expected v1 16x32 interleave mode 1");
}
return {
logicalOut: header[3],
logicalIn: header[4],
paddedOut: header[5],
paddedIn: header[6],
scaleGranularity: header[7],
scaleGroupSize: header[8],
segmentCount: header[11],
};
}
function resolveUrl(path, baseUrl) {
return new URL(path, baseUrl).toString();
}
function sleep(ms) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
async function fetchWithRetry(url, options = {}) {
const attempts = Math.max(1, Number(options.attempts || 5));
let lastError = null;
for (let attempt = 0; attempt < attempts; attempt += 1) {
try {
const response = await fetch(url);
if (response.ok) return response;
if (response.status < 500 && response.status !== 408 && response.status !== 429) {
throw new Error(`failed to fetch ${url}: ${response.status}`);
}
lastError = new Error(`failed to fetch ${url}: ${response.status}`);
} catch (error) {
lastError = error;
}
if (attempt < attempts - 1) {
await sleep(Math.min(2000, 150 * 2 ** attempt));
}
}
throw lastError || new Error(`failed to fetch ${url}`);
}
async function fetchJson(url) {
const response = await fetchWithRetry(url);
if (!response.ok) {
throw new Error(`failed to fetch ${url}: ${response.status}`);
}
return response.json();
}
async function fetchText(url) {
const response = await fetchWithRetry(url);
if (!response.ok) {
throw new Error(`failed to fetch ${url}: ${response.status}`);
}
return response.text();
}
async function fetchTextCached(url) {
if (!shaderTextCache.has(url)) {
shaderTextCache.set(url, fetchText(url));
}
return shaderTextCache.get(url);
}
async function getBitNetPipeline(device, shaderCode, cacheKey) {
let deviceCache = pipelineCache.get(device);
if (!deviceCache) {
deviceCache = new Map();
pipelineCache.set(device, deviceCache);
}
if (!deviceCache.has(cacheKey)) {
deviceCache.set(cacheKey, (async () => {
const module = device.createShaderModule({ code: shaderCode });
const descriptor = {
layout: "auto",
compute: { module, entryPoint: "bitnet_linear_main" },
};
const pipeline = typeof device.createComputePipelineAsync === "function"
? await device.createComputePipelineAsync(descriptor)
: device.createComputePipeline(descriptor);
return { module, pipeline };
})());
}
return deviceCache.get(cacheKey);
}
async function fetchTensor(entry, baseUrl, TypedArray) {
const url = resolveUrl(entry.path, baseUrl);
const response = await fetchWithRetry(url);
if (!response.ok) {
throw new Error(`failed to fetch ${entry.path}: ${response.status}`);
}
return new TypedArray(await response.arrayBuffer());
}
function tensorType(entry) {
if (entry.dtype === "uint8") {
return Uint8Array;
}
if (entry.dtype === "int32") {
return Int32Array;
}
if (entry.dtype === "float32") {
return Float32Array;
}
throw new Error(`unsupported tensor dtype: ${entry.dtype}`);
}
export async function createBitNetWebGPUDevice() {
if (!globalThis.navigator?.gpu) {
throw new Error("WebGPU is not available in this browser");
}
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error("WebGPU adapter request failed");
}
const device = await adapter.requestDevice();
return { adapter, device };
}
export class BitNetLinearWebGPU {
constructor(device, bundle) {
this.device = device;
this.layout = normalizeLayout(bundle.layoutHeader);
this.hasBias = bundle.bias != null;
this.inputQuantMode = bundle.inputQuantMode ?? 0;
this.inputQuantBits = bundle.inputQuantBits ?? 8;
this.inputScaleRows = bundle.inputScaleRows ?? 1;
if (!bundle.shaderCode && !bundle.pipeline) {
throw new Error("BitNetLinearWebGPU requires shaderCode or pipeline; use fromManifestLayer() or fromManifestUrl()");
}
if (bundle.pipeline) {
this.module = bundle.module || null;
this.pipeline = bundle.pipeline;
} else {
this.module = device.createShaderModule({ code: bundle.shaderCode });
this.pipeline = device.createComputePipeline({
layout: "auto",
compute: { module: this.module, entryPoint: "bitnet_linear_main" },
});
}
this.packedWeightBuffer = createStorageBuffer(device, packedWeightToWords(bundle.packedWeight));
this.scaleBuffer = createStorageBuffer(device, new Float32Array(bundle.scaleValues));
this.segmentOffsetBuffer = createStorageBuffer(device, new Uint32Array(bundle.segmentOffsets));
this.biasBuffer = createStorageBuffer(
device,
this.hasBias ? new Float32Array(bundle.bias) : new Float32Array([0]),
);
this.inputScaleBuffer = createStorageBuffer(
device,
bundle.inputScales ? new Float32Array(bundle.inputScales) : new Float32Array([1]),
);
this.paramsBuffer = device.createBuffer({
size: PARAM_BUFFER_BYTES,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});
this.runCache = new Map();
}
static async fromManifestLayer(device, manifest, layer, manifestUrl, options = {}) {
const progress = typeof options.progress === "function" ? options.progress : () => {};
const index = Number(options.index || 0);
const total = Number(options.total || 0);
const name = String(options.name || layer.name || "layer");
const label = total ? `${index}/${total}: ${name}` : name;
const baseUrl = new URL(".", manifestUrl).toString();
const shaderUrl = resolveUrl(manifest.runtime.files.wgsl, baseUrl);
const runtimeBaseUrl = resolveUrl(".", shaderUrl);
progress({ phase: "layer_shader", index, total, name, message: `Loading shader for BitNet layer ${label}` });
const shaderCode = options.shaderCode || await fetchTextCached(shaderUrl);
progress({ phase: "layer_pipeline", index, total, name, message: `Preparing WebGPU pipeline for BitNet layer ${label}` });
const pipelineBundle = options.pipeline
? { module: options.module || null, pipeline: options.pipeline }
: await getBitNetPipeline(device, shaderCode, shaderUrl);
const tensors = layer.tensors;
const layersBaseUrl = resolveUrl("layers/", baseUrl);
progress({ phase: "layer_tensors", index, total, name, message: `Loading tensors for BitNet layer ${label}` });
const [packedWeight, scaleValues, segmentOffsets, bias, inputScales] = await Promise.all([
fetchTensor(tensors.packed_weight, layersBaseUrl, Uint8Array),
fetchTensor(tensors.scale_values, layersBaseUrl, Float32Array),
fetchTensor(tensors.segment_offsets, layersBaseUrl, Int32Array),
tensors.bias ? fetchTensor(tensors.bias, layersBaseUrl, Float32Array) : Promise.resolve(null),
fetchTensor(tensors.act_scale, layersBaseUrl, tensorType(tensors.act_scale)),
]);
progress({ phase: "layer_upload", index, total, name, message: `Uploading BitNet layer ${label}` });
return new BitNetLinearWebGPU(device, {
shaderCode,
module: pipelineBundle.module,
pipeline: pipelineBundle.pipeline,
layoutHeader: layer.layout_header,
packedWeight,
scaleValues,
segmentOffsets,
bias,
inputScales,
inputQuantMode: layer.act_quant_mode === "none" ? 0 : 1,
inputQuantBits: layer.act_quant_bits,
inputScaleRows: layer.act_quant_mode === "static_int8" ? 1 : 1,
runtimeBaseUrl,
});
}
static async fromManifestUrl(device, manifestUrl, layerName) {
const manifest = await fetchJson(manifestUrl);
const layer = manifest.layers.find((candidate) => candidate.name === layerName);
if (!layer) {
throw new Error(`BitNet layer not found in manifest: ${layerName}`);
}
return BitNetLinearWebGPU.fromManifestLayer(device, manifest, layer, manifestUrl);
}
async run(input, rows = 1) {
const x = input instanceof Float32Array ? input : new Float32Array(input);
if (x.length !== rows * this.layout.logicalIn) {
throw new Error(`BitNet input length mismatch: got ${x.length}, expected ${rows * this.layout.logicalIn}`);
}
const outputLength = rows * this.layout.logicalOut;
const inputBytes = x.byteLength;
const outputBytes = outputLength * Float32Array.BYTES_PER_ELEMENT;
const cacheKey = `${rows}:${this.layout.logicalIn}:${this.layout.logicalOut}`;
let cache = this.runCache.get(cacheKey);
if (!cache) {
const inputBuffer = this.device.createBuffer({
size: align4(inputBytes),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const outputBuffer = createOutputBuffer(this.device, outputBytes);
const readbackBuffer = createReadbackBuffer(this.device, outputBytes);
const bindGroup = this.device.createBindGroup({
layout: this.pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: inputBuffer } },
{ binding: 1, resource: { buffer: this.packedWeightBuffer } },
{ binding: 2, resource: { buffer: this.scaleBuffer } },
{ binding: 3, resource: { buffer: this.segmentOffsetBuffer } },
{ binding: 4, resource: { buffer: this.biasBuffer } },
{ binding: 5, resource: { buffer: this.inputScaleBuffer } },
{ binding: 6, resource: { buffer: outputBuffer } },
{ binding: 7, resource: { buffer: this.paramsBuffer } },
],
});
cache = { inputBuffer, outputBuffer, readbackBuffer, bindGroup };
this.runCache.set(cacheKey, cache);
}
this.device.queue.writeBuffer(cache.inputBuffer, 0, x.buffer, x.byteOffset, x.byteLength);
const params = new Uint32Array([
rows,
this.layout.logicalIn,
this.layout.logicalOut,
this.layout.paddedIn,
this.layout.scaleGranularity,
this.layout.scaleGroupSize,
this.layout.segmentCount,
this.hasBias ? 1 : 0,
this.inputQuantMode,
this.inputQuantBits,
this.inputScaleRows,
0,
]);
this.device.queue.writeBuffer(this.paramsBuffer, 0, params);
const encoder = this.device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(this.pipeline);
pass.setBindGroup(0, cache.bindGroup);
pass.dispatchWorkgroups(Math.ceil(this.layout.logicalOut / 8), Math.ceil(rows / 8), 1);
pass.end();
encoder.copyBufferToBuffer(cache.outputBuffer, 0, cache.readbackBuffer, 0, outputBytes);
this.device.queue.submit([encoder.finish()]);
await cache.readbackBuffer.mapAsync(GPUMapMode.READ);
const mapped = cache.readbackBuffer.getMappedRange();
const result = new Float32Array(mapped.slice(0));
cache.readbackBuffer.unmap();
return result;
}
}