/** Number of 32-bit words in the parameter block uploaded before each dispatch. */
const PARAM_U32_COUNT = 12;

/** Size in bytes of the parameter buffer (one u32 per parameter). */
const PARAM_BUFFER_BYTES = PARAM_U32_COUNT * Uint32Array.BYTES_PER_ELEMENT;

/** In-flight or completed shader source downloads, keyed by URL. */
const shaderTextCache = new Map();

/** Per-device pipeline caches; weakly keyed so a dropped device releases its entries. */
const pipelineCache = new WeakMap();
/**
 * Rounds a byte count up to the next multiple of four.
 *
 * Plain arithmetic is used instead of bitwise operators because JavaScript
 * bitwise operators truncate their operands to signed 32-bit integers, which
 * turns sizes of 2 GiB and above into negative numbers.
 *
 * @param {number} value - Non-negative byte count.
 * @returns {number} The smallest multiple of 4 that is >= value.
 */
function align4(value) {
  return Math.ceil(value / 4) * 4;
}
/**
 * Copies packed weight bytes into a zero-padded buffer whose length is a
 * multiple of four and returns that buffer viewed as 32-bit words.
 *
 * Typed-array views other than Uint8Array are reinterpreted byte-for-byte;
 * constructing a Uint8Array from them would convert each element to a single
 * byte and silently drop data.
 *
 * @param {Uint8Array|ArrayBufferView|ArrayBuffer|ArrayLike<number>} packedWeight
 *   Packed weight bytes.
 * @returns {Uint32Array} Freshly allocated words that do not alias the input.
 */
function packedWeightToWords(packedWeight) {
  let bytes;
  if (packedWeight instanceof Uint8Array) {
    bytes = packedWeight;
  } else if (ArrayBuffer.isView(packedWeight)) {
    bytes = new Uint8Array(packedWeight.buffer, packedWeight.byteOffset, packedWeight.byteLength);
  } else {
    bytes = new Uint8Array(packedWeight);
  }

  const padded = new Uint8Array(align4(bytes.byteLength));
  padded.set(bytes);
  return new Uint32Array(padded.buffer);
}
/**
 * Creates a GPU buffer sized to hold `data` (rounded up to four bytes) and
 * queues an upload of `data` into it.
 *
 * @param {GPUDevice} device - Device that owns the buffer and the queue.
 * @param {ArrayBufferView|ArrayBuffer|ArrayLike<number>} data - Bytes to upload;
 *   non-view inputs are wrapped in a Uint8Array.
 * @param {number} [usage=GPUBufferUsage.STORAGE] - Usage flags; COPY_DST is always added.
 * @returns {GPUBuffer} The newly created buffer.
 */
function createStorageBuffer(device, data, usage = GPUBufferUsage.STORAGE) {
  const source = ArrayBuffer.isView(data) ? data : new Uint8Array(data);
  const byteLength = source.byteLength;
  const size = align4(byteLength);
  const buffer = device.createBuffer({
    size,
    usage: usage | GPUBufferUsage.COPY_DST,
  });

  if (size === byteLength) {
    device.queue.writeBuffer(buffer, 0, source.buffer, source.byteOffset, byteLength);
  } else {
    // writeBuffer rejects sizes that are not a multiple of four bytes, so
    // upload a zero-padded copy that matches the aligned buffer size.
    const padded = new Uint8Array(size);
    padded.set(new Uint8Array(source.buffer, source.byteOffset, byteLength));
    device.queue.writeBuffer(buffer, 0, padded);
  }
  return buffer;
}
/**
 * Creates a buffer the compute shader can write to and that can be used as
 * the source of a buffer-to-buffer copy.
 *
 * @param {GPUDevice} device - Device that owns the buffer.
 * @param {number} byteLength - Requested size; rounded up to a multiple of four.
 * @returns {GPUBuffer} Buffer with STORAGE | COPY_SRC usage.
 */
function createOutputBuffer(device, byteLength) {
  const descriptor = {
    size: align4(byteLength),
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  };
  return device.createBuffer(descriptor);
}
/**
 * Creates a buffer that can receive a buffer-to-buffer copy and then be
 * mapped for reading on the CPU.
 *
 * @param {GPUDevice} device - Device that owns the buffer.
 * @param {number} byteLength - Requested size; rounded up to a multiple of four.
 * @returns {GPUBuffer} Buffer with MAP_READ | COPY_DST usage.
 */
function createReadbackBuffer(device, byteLength) {
  const descriptor = {
    size: align4(byteLength),
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  };
  return device.createBuffer(descriptor);
}
/**
 * Validates a BitNet layout header and extracts the fields this runtime uses.
 *
 * Entries are coerced with Number(). Per the error message below, the only
 * accepted combination is header[0] = 1 (version), header[1] = 16 and
 * header[2] = 32 (the "16x32" pair) and header[9] = 1 (interleave mode).
 *
 * @param {ArrayLike<number>} layoutHeader - Header with at least 13 entries.
 * @returns {{logicalOut: number, logicalIn: number, paddedOut: number,
 *   paddedIn: number, scaleGranularity: number, scaleGroupSize: number,
 *   segmentCount: number}} Fields taken from indices 3-8 and 11.
 * @throws {Error} If the header is missing, too short, or not the supported layout.
 */
function normalizeLayout(layoutHeader) {
  if (!layoutHeader || layoutHeader.length < 13) {
    throw new Error("BitNet layout_header must contain at least 13 entries");
  }

  const [
    version,
    tileRows,
    tileCols,
    logicalOut,
    logicalIn,
    paddedOut,
    paddedIn,
    scaleGranularity,
    scaleGroupSize,
    interleaveMode,
    , // index 10 is not read by this runtime
    segmentCount,
  ] = Array.from(layoutHeader, Number);

  const isSupported = version === 1 && tileRows === 16 && tileCols === 32 && interleaveMode === 1;
  if (!isSupported) {
    throw new Error("Unsupported BitNet browser layout; expected v1 16x32 interleave mode 1");
  }

  return {
    logicalOut,
    logicalIn,
    paddedOut,
    paddedIn,
    scaleGranularity,
    scaleGroupSize,
    segmentCount,
  };
}
/**
 * Resolves a possibly relative path against a base URL.
 *
 * @param {string} path - Relative or absolute URL.
 * @param {string|URL} baseUrl - Base used when `path` is relative.
 * @returns {string} The absolute URL as a string.
 */
function resolveUrl(path, baseUrl) {
  const resolved = new URL(path, baseUrl);
  return resolved.toString();
}
/**
 * Returns a promise that resolves after roughly `ms` milliseconds.
 *
 * @param {number} ms - Delay passed to setTimeout.
 * @returns {Promise<void>} Resolves once the timer fires.
 */
function sleep(ms) {
  return new Promise((resolve) => {
    setTimeout(resolve, ms);
  });
}
/**
 * Fetches a URL, retrying transient failures with exponential backoff.
 *
 * Network errors and the statuses 408, 429 and 5xx are retried. Any other
 * non-ok status fails immediately: the request is not retried and no backoff
 * delay is spent on it.
 *
 * @param {string} url - Resource to fetch.
 * @param {{attempts?: number}} [options] - `attempts` is the maximum number of
 *   tries (default 5, minimum 1).
 * @returns {Promise<Response>} A response whose `ok` flag is true.
 * @throws {Error} The last network error, or an Error naming the final status.
 */
async function fetchWithRetry(url, options = {}) {
  const requested = Number(options.attempts || 5);
  // A non-numeric `attempts` would otherwise skip every attempt.
  const attempts = Number.isNaN(requested) ? 5 : Math.max(1, requested);
  let lastError = null;

  for (let attempt = 0; attempt < attempts; attempt += 1) {
    let response = null;
    try {
      response = await fetch(url);
    } catch (error) {
      // Only failures of fetch() itself are caught here, so the fail-fast
      // throw below cannot be swallowed and retried.
      lastError = error;
    }

    if (response) {
      if (response.ok) {
        return response;
      }
      const statusError = new Error(`failed to fetch ${url}: ${response.status}`);
      const isRetriable = response.status >= 500 || response.status === 408 || response.status === 429;
      if (!isRetriable) {
        throw statusError;
      }
      lastError = statusError;
    }

    if (attempt < attempts - 1) {
      // 150 ms, 300 ms, 600 ms, ... capped at 2 s.
      await sleep(Math.min(2000, 150 * 2 ** attempt));
    }
  }

  throw lastError || new Error(`failed to fetch ${url}`);
}
/**
 * Fetches a URL with retries and parses the body as JSON.
 *
 * fetchWithRetry() only ever returns responses whose `ok` flag is true and
 * throws otherwise, so no further status check is needed here.
 *
 * @param {string} url - Resource to fetch.
 * @returns {Promise<any>} The parsed JSON body.
 * @throws {Error} Whatever fetchWithRetry() or response.json() throws.
 */
async function fetchJson(url) {
  const response = await fetchWithRetry(url);
  return response.json();
}
/**
 * Fetches a URL with retries and returns the body as text.
 *
 * fetchWithRetry() only ever returns responses whose `ok` flag is true and
 * throws otherwise, so no further status check is needed here.
 *
 * @param {string} url - Resource to fetch.
 * @returns {Promise<string>} The response body.
 * @throws {Error} Whatever fetchWithRetry() or response.text() throws.
 */
async function fetchText(url) {
  const response = await fetchWithRetry(url);
  return response.text();
}
/**
 * Fetches text through shaderTextCache so concurrent and repeated requests
 * for the same URL share one download.
 *
 * The cache stores the promise itself. A failed download is evicted so that a
 * later call starts a fresh request instead of replaying the old rejection.
 *
 * @param {string} url - Resource to fetch; also the cache key.
 * @returns {Promise<string>} The response body.
 * @throws {Error} Whatever fetchText() throws for this request.
 */
async function fetchTextCached(url) {
  let pending = shaderTextCache.get(url);
  if (pending === undefined) {
    pending = fetchText(url);
    shaderTextCache.set(url, pending);
    pending.catch(() => {
      // Only evict our own entry; a newer request may already have replaced it.
      if (shaderTextCache.get(url) === pending) {
        shaderTextCache.delete(url);
      }
    });
  }
  return pending;
}
/**
 * Returns the compute pipeline for `cacheKey` on `device`, compiling it on
 * first use and sharing the in-flight compilation between callers.
 *
 * A failed compilation is evicted from the cache so a later call can retry
 * instead of receiving the same rejection forever.
 *
 * @param {GPUDevice} device - Device that compiles and owns the pipeline.
 * @param {string} shaderCode - WGSL source exposing the "bitnet_linear_main" entry point.
 * @param {string} cacheKey - Key identifying this shader within the device's cache.
 * @returns {Promise<{module: GPUShaderModule, pipeline: GPUComputePipeline}>}
 * @throws {Error} Whatever shader-module or pipeline creation throws.
 */
async function getBitNetPipeline(device, shaderCode, cacheKey) {
  let deviceCache = pipelineCache.get(device);
  if (!deviceCache) {
    deviceCache = new Map();
    pipelineCache.set(device, deviceCache);
  }

  const compile = async () => {
    const module = device.createShaderModule({ code: shaderCode });
    const descriptor = {
      layout: "auto",
      compute: { module, entryPoint: "bitnet_linear_main" },
    };
    // Use the asynchronous variant when the device offers it.
    const pipeline = typeof device.createComputePipelineAsync === "function"
      ? await device.createComputePipelineAsync(descriptor)
      : device.createComputePipeline(descriptor);
    return { module, pipeline };
  };

  let pending = deviceCache.get(cacheKey);
  if (pending === undefined) {
    pending = compile();
    deviceCache.set(cacheKey, pending);
    pending.catch(() => {
      // Only evict our own entry; a newer attempt may already have replaced it.
      if (deviceCache.get(cacheKey) === pending) {
        deviceCache.delete(cacheKey);
      }
    });
  }
  return pending;
}
/**
 * Downloads a tensor file and wraps its bytes in the requested typed array.
 *
 * fetchWithRetry() only ever returns responses whose `ok` flag is true, so no
 * further status check is needed here.
 *
 * @param {{path: string}} entry - Manifest tensor entry; `path` is resolved against `baseUrl`.
 * @param {string} baseUrl - Base URL for relative tensor paths.
 * @param {Function} TypedArray - Typed-array constructor such as Float32Array.
 * @returns {Promise<ArrayBufferView>} A `TypedArray` view over the downloaded bytes.
 * @throws {RangeError} If the payload size is not a multiple of the element size.
 * @throws {Error} Whatever fetchWithRetry() throws.
 */
async function fetchTensor(entry, baseUrl, TypedArray) {
  const url = resolveUrl(entry.path, baseUrl);
  const response = await fetchWithRetry(url);
  const payload = await response.arrayBuffer();

  // The typed-array constructor would throw an anonymous RangeError here;
  // raise the same error type with the tensor path attached.
  const elementSize = TypedArray.BYTES_PER_ELEMENT;
  if (elementSize > 1 && payload.byteLength % elementSize !== 0) {
    throw new RangeError(
      `tensor ${entry.path} has ${payload.byteLength} bytes, which is not a multiple of ${elementSize}`,
    );
  }
  return new TypedArray(payload);
}
/**
 * Maps a manifest tensor entry's dtype string to a typed-array constructor.
 *
 * @param {{dtype: string}} entry - Manifest tensor entry.
 * @returns {Function} Uint8Array, Int32Array or Float32Array.
 * @throws {Error} If the dtype is not "uint8", "int32" or "float32".
 */
function tensorType(entry) {
  switch (entry.dtype) {
    case "uint8":
      return Uint8Array;
    case "int32":
      return Int32Array;
    case "float32":
      return Float32Array;
    default:
      throw new Error(`unsupported tensor dtype: ${entry.dtype}`);
  }
}
/**
 * Requests a WebGPU adapter and a device with default options.
 *
 * @returns {Promise<{adapter: GPUAdapter, device: GPUDevice}>}
 * @throws {Error} If `navigator.gpu` is missing or no adapter is returned.
 */
export async function createBitNetWebGPUDevice() {
  // Read the entry point once so the availability check and the request use
  // the same object.
  const gpu = globalThis.navigator?.gpu;
  if (!gpu) {
    throw new Error("WebGPU is not available in this browser");
  }

  const adapter = await gpu.requestAdapter();
  if (!adapter) {
    throw new Error("WebGPU adapter request failed");
  }

  const device = await adapter.requestDevice();
  return { adapter, device };
}
/**
 * One BitNet linear layer evaluated with a WebGPU compute pipeline.
 *
 * The constructor uploads the layer's tensors into GPU buffers. run() uploads
 * an input, dispatches the "bitnet_linear_main" entry point and reads the
 * result back as a Float32Array.
 *
 * Every run() on an instance writes the same parameter buffer and reuses the
 * cached readback buffer for its row count, so await each call before
 * starting the next one on the same instance.
 */
export class BitNetLinearWebGPU {
  /**
   * @param {GPUDevice} device - Device used for all buffers and submissions.
   * @param {object} bundle - Layer description.
   * @param {ArrayLike<number>} bundle.layoutHeader - Validated by normalizeLayout().
   * @param {string} [bundle.shaderCode] - WGSL source; required unless `pipeline` is given.
   * @param {GPUComputePipeline} [bundle.pipeline] - Prebuilt pipeline; takes precedence over `shaderCode`.
   * @param {GPUShaderModule} [bundle.module] - Module that belongs to `pipeline`, if known.
   * @param {Uint8Array|ArrayBuffer|ArrayLike<number>} bundle.packedWeight - Packed weight bytes.
   * @param {ArrayLike<number>} bundle.scaleValues - Uploaded as float32.
   * @param {ArrayLike<number>} bundle.segmentOffsets - Uploaded as uint32.
   * @param {ArrayLike<number>|null} [bundle.bias] - Optional bias, uploaded as float32.
   * @param {ArrayLike<number>|null} [bundle.inputScales] - Optional input scales; defaults to [1].
   * @param {number} [bundle.inputQuantMode=0]
   * @param {number} [bundle.inputQuantBits=8]
   * @param {number} [bundle.inputScaleRows=1]
   * @throws {Error} If neither `shaderCode` nor `pipeline` is provided, or the layout is rejected.
   */
  constructor(device, bundle) {
    this.device = device;
    this.layout = normalizeLayout(bundle.layoutHeader);
    this.hasBias = bundle.bias != null;
    this.inputQuantMode = bundle.inputQuantMode ?? 0;
    this.inputQuantBits = bundle.inputQuantBits ?? 8;
    this.inputScaleRows = bundle.inputScaleRows ?? 1;

    if (!bundle.shaderCode && !bundle.pipeline) {
      throw new Error("BitNetLinearWebGPU requires shaderCode or pipeline; use fromManifestLayer() or fromManifestUrl()");
    }
    if (bundle.pipeline) {
      this.module = bundle.module || null;
      this.pipeline = bundle.pipeline;
    } else {
      this.module = device.createShaderModule({ code: bundle.shaderCode });
      this.pipeline = device.createComputePipeline({
        layout: "auto",
        compute: { module: this.module, entryPoint: "bitnet_linear_main" },
      });
    }

    this.packedWeightBuffer = createStorageBuffer(device, packedWeightToWords(bundle.packedWeight));
    this.scaleBuffer = createStorageBuffer(device, new Float32Array(bundle.scaleValues));
    this.segmentOffsetBuffer = createStorageBuffer(device, new Uint32Array(bundle.segmentOffsets));
    // Bias and input scales are always bound, so absent tensors get a
    // one-element placeholder.
    this.biasBuffer = createStorageBuffer(
      device,
      this.hasBias ? new Float32Array(bundle.bias) : new Float32Array([0]),
    );
    this.inputScaleBuffer = createStorageBuffer(
      device,
      bundle.inputScales ? new Float32Array(bundle.inputScales) : new Float32Array([1]),
    );
    this.paramsBuffer = device.createBuffer({
      size: PARAM_BUFFER_BYTES,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });
    // Input/output/readback buffers and bind groups, keyed by row count.
    this.runCache = new Map();
  }

  /**
   * Builds a layer from a manifest entry, downloading the shader and tensors.
   *
   * @param {GPUDevice} device
   * @param {object} manifest - Manifest; `runtime.files.wgsl` locates the shader.
   * @param {object} layer - Manifest layer entry (`tensors`, `layout_header`, `act_quant_*`).
   * @param {string} manifestUrl - URL the manifest was loaded from; relative paths resolve against it.
   * @param {object} [options]
   * @param {Function} [options.progress] - Called with `{phase, index, total, name, message}`.
   * @param {number} [options.index] - Position reported in progress messages.
   * @param {number} [options.total] - Layer count reported in progress messages.
   * @param {string} [options.name] - Display name; defaults to `layer.name`.
   * @param {string} [options.shaderCode] - Shader source used instead of downloading it.
   * @param {GPUComputePipeline} [options.pipeline] - Prebuilt pipeline used instead of compiling.
   * @param {GPUShaderModule} [options.module] - Module that belongs to `options.pipeline`.
   * @returns {Promise<BitNetLinearWebGPU>}
   */
  static async fromManifestLayer(device, manifest, layer, manifestUrl, options = {}) {
    const progress = typeof options.progress === "function" ? options.progress : () => {};
    const index = Number(options.index || 0);
    const total = Number(options.total || 0);
    const name = String(options.name || layer.name || "layer");
    const label = total ? `${index}/${total}: ${name}` : name;
    const report = (phase, message) => progress({ phase, index, total, name, message });

    const baseUrl = new URL(".", manifestUrl).toString();
    const shaderUrl = resolveUrl(manifest.runtime.files.wgsl, baseUrl);
    const runtimeBaseUrl = resolveUrl(".", shaderUrl);

    report("layer_shader", `Loading shader for BitNet layer ${label}`);
    const shaderCode = options.shaderCode || await fetchTextCached(shaderUrl);

    report("layer_pipeline", `Preparing WebGPU pipeline for BitNet layer ${label}`);
    // Caller-supplied source must not share a cache slot with the shader
    // downloaded from shaderUrl, or the wrong pipeline would be reused.
    const pipelineKey = options.shaderCode ? `inline:${options.shaderCode}` : shaderUrl;
    const pipelineBundle = options.pipeline
      ? { module: options.module || null, pipeline: options.pipeline }
      : await getBitNetPipeline(device, shaderCode, pipelineKey);

    const { tensors } = layer;
    const layersBaseUrl = resolveUrl("layers/", baseUrl);
    report("layer_tensors", `Loading tensors for BitNet layer ${label}`);
    const [packedWeight, scaleValues, segmentOffsets, bias, inputScales] = await Promise.all([
      fetchTensor(tensors.packed_weight, layersBaseUrl, Uint8Array),
      fetchTensor(tensors.scale_values, layersBaseUrl, Float32Array),
      fetchTensor(tensors.segment_offsets, layersBaseUrl, Int32Array),
      tensors.bias ? fetchTensor(tensors.bias, layersBaseUrl, Float32Array) : Promise.resolve(null),
      // Optional like bias: the constructor substitutes [1] when absent.
      tensors.act_scale
        ? fetchTensor(tensors.act_scale, layersBaseUrl, tensorType(tensors.act_scale))
        : Promise.resolve(null),
    ]);

    report("layer_upload", `Uploading BitNet layer ${label}`);
    return new BitNetLinearWebGPU(device, {
      shaderCode,
      module: pipelineBundle.module,
      pipeline: pipelineBundle.pipeline,
      layoutHeader: layer.layout_header,
      packedWeight,
      scaleValues,
      segmentOffsets,
      bias,
      inputScales,
      inputQuantMode: layer.act_quant_mode === "none" ? 0 : 1,
      inputQuantBits: layer.act_quant_bits,
      // NOTE(review): this used to be `act_quant_mode === "static_int8" ? 1 : 1`,
      // which is 1 for every mode; confirm whether other modes need another value.
      inputScaleRows: 1,
      runtimeBaseUrl,
    });
  }

  /**
   * Downloads a manifest and builds the layer named `layerName` from it.
   *
   * @param {GPUDevice} device
   * @param {string} manifestUrl - URL of the manifest JSON.
   * @param {string} layerName - Value matched against each layer's `name`.
   * @returns {Promise<BitNetLinearWebGPU>}
   * @throws {Error} If no layer has that name.
   */
  static async fromManifestUrl(device, manifestUrl, layerName) {
    const manifest = await fetchJson(manifestUrl);
    const layer = manifest.layers.find((candidate) => candidate.name === layerName);
    if (!layer) {
      throw new Error(`BitNet layer not found in manifest: ${layerName}`);
    }
    return BitNetLinearWebGPU.fromManifestLayer(device, manifest, layer, manifestUrl);
  }

  /**
   * Runs the layer on `rows` input rows and returns the output values.
   *
   * @param {Float32Array|ArrayLike<number>} input - `rows * logicalIn` values.
   * @param {number} [rows=1] - Number of input rows; a non-negative integer.
   * @returns {Promise<Float32Array>} `rows * logicalOut` values copied out of the GPU.
   * @throws {Error} If `rows` is not a non-negative integer or the input length does not match.
   */
  async run(input, rows = 1) {
    const rowCount = Number(rows);
    if (!Number.isInteger(rowCount) || rowCount < 0) {
      throw new Error(`BitNet rows must be a non-negative integer, got ${String(rows)}`);
    }

    const { logicalIn, logicalOut } = this.layout;
    const x = input instanceof Float32Array ? input : new Float32Array(input);
    const expectedLength = rowCount * logicalIn;
    if (x.length !== expectedLength) {
      throw new Error(`BitNet input length mismatch: got ${x.length}, expected ${expectedLength}`);
    }

    const outputLength = rowCount * logicalOut;
    if (outputLength === 0) {
      // Nothing to compute; also avoids creating zero-sized buffer bindings.
      return new Float32Array(0);
    }

    const outputBytes = outputLength * Float32Array.BYTES_PER_ELEMENT;
    const cacheKey = `${rowCount}:${logicalIn}:${logicalOut}`;
    let cache = this.runCache.get(cacheKey);
    if (!cache) {
      const inputBuffer = this.device.createBuffer({
        size: align4(x.byteLength),
        usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
      });
      const outputBuffer = createOutputBuffer(this.device, outputBytes);
      const readbackBuffer = createReadbackBuffer(this.device, outputBytes);
      const boundBuffers = [
        inputBuffer,
        this.packedWeightBuffer,
        this.scaleBuffer,
        this.segmentOffsetBuffer,
        this.biasBuffer,
        this.inputScaleBuffer,
        outputBuffer,
        this.paramsBuffer,
      ];
      const bindGroup = this.device.createBindGroup({
        layout: this.pipeline.getBindGroupLayout(0),
        // Binding numbers 0-7 follow the order of boundBuffers.
        entries: boundBuffers.map((buffer, binding) => ({ binding, resource: { buffer } })),
      });
      cache = { inputBuffer, outputBuffer, readbackBuffer, bindGroup };
      this.runCache.set(cacheKey, cache);
    }
    this.device.queue.writeBuffer(cache.inputBuffer, 0, x.buffer, x.byteOffset, x.byteLength);

    const params = new Uint32Array([
      rowCount,
      logicalIn,
      logicalOut,
      this.layout.paddedIn,
      this.layout.scaleGranularity,
      this.layout.scaleGroupSize,
      this.layout.segmentCount,
      this.hasBias ? 1 : 0,
      this.inputQuantMode,
      this.inputQuantBits,
      this.inputScaleRows,
      0,
    ]);
    this.device.queue.writeBuffer(this.paramsBuffer, 0, params);

    const encoder = this.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(this.pipeline);
    pass.setBindGroup(0, cache.bindGroup);
    // One workgroup per 8 output columns and per 8 rows.
    pass.dispatchWorkgroups(Math.ceil(logicalOut / 8), Math.ceil(rowCount / 8), 1);
    pass.end();
    encoder.copyBufferToBuffer(cache.outputBuffer, 0, cache.readbackBuffer, 0, outputBytes);
    this.device.queue.submit([encoder.finish()]);

    await cache.readbackBuffer.mapAsync(GPUMapMode.READ);
    try {
      // Copy out of the mapping: the mapped range is unusable after unmap().
      return new Float32Array(cache.readbackBuffer.getMappedRange().slice(0));
    } finally {
      // Always unmap so the cached buffer can be reused by the next run().
      cache.readbackBuffer.unmap();
    }
  }

  /**
   * Destroys every GPU buffer owned by this layer. The instance must not be
   * used afterwards. The pipeline and shader module are left untouched
   * because they may be shared with other layers.
   */
  destroy() {
    for (const { inputBuffer, outputBuffer, readbackBuffer } of this.runCache.values()) {
      inputBuffer.destroy();
      outputBuffer.destroy();
      readbackBuffer.destroy();
    }
    this.runCache.clear();

    const ownedBuffers = [
      this.packedWeightBuffer,
      this.scaleBuffer,
      this.segmentOffsetBuffer,
      this.biasBuffer,
      this.inputScaleBuffer,
      this.paramsBuffer,
    ];
    for (const buffer of ownedBuffers) {
      buffer.destroy();
    }
  }
}