needle-playground / src /runtime.ts
shreyask's picture
Upload folder using huggingface_hub
814c07e verified
import * as ort from 'onnxruntime-web/wasm';
import { ENCODER_URL, DECODER_URL } from './config';
// Module-level ORT configuration: these assignments run as a side effect of
// importing this module.
//
// Point onnxruntime-web at the WASM files Vite copies into /ort/ via vite.config.ts.
ort.env.wasm.wasmPaths = '/ort/';
// Force a single thread so ORT doesn't try to dynamically import the .mjs
// worker shim (which Vite's dev server blocks for files served from /public).
ort.env.wasm.numThreads = 1;
/**
 * The pair of ONNX Runtime inference sessions this module works with,
 * as produced by `loadSessions`.
 */
export interface NeedleSessions {
  /** Session for the encoder model. */
  encoder: ort.InferenceSession;
  /** Session for the decoder model. */
  decoder: ort.InferenceSession;
}
/**
 * Creates the encoder and decoder inference sessions, one after the other,
 * both on the `wasm` execution provider.
 *
 * @param onProgress - Optional callback invoked with a short status message
 *   immediately before each session starts loading.
 * @returns The two ready-to-run sessions.
 * @throws Rethrows whatever `InferenceSession.create` rejects with. If the
 *   decoder fails to load after the encoder was created, the encoder session
 *   is released first so it is not leaked.
 */
export async function loadSessions(onProgress?: (m: string) => void): Promise<NeedleSessions> {
  const sessionOptions: ort.InferenceSession.SessionOptions = { executionProviders: ['wasm'] };

  onProgress?.('downloading encoder…');
  const encoder = await ort.InferenceSession.create(ENCODER_URL, sessionOptions);

  onProgress?.('downloading decoder…');
  let decoder: ort.InferenceSession;
  try {
    decoder = await ort.InferenceSession.create(DECODER_URL, sessionOptions);
  } catch (err) {
    // On this path the caller never receives the encoder handle, so nobody
    // else can free it. Release it here before propagating the failure.
    try {
      await encoder.release();
    } catch {
      // A failed release must not mask the original load error.
    }
    throw err;
  }

  return { encoder, decoder };
}
/**
 * Runs the encoder session on one sequence of token ids.
 *
 * The ids are packed into an `int64` tensor of dims `[1, inputIds.length]`
 * and fed under the input name `input_ids`.
 *
 * @param sess - Encoder inference session.
 * @param inputIds - Token ids; each is converted with `BigInt`, so every
 *   entry must be an integer.
 * @returns The session output named `encoder_out`.
 */
export async function runEncoder(sess: ort.InferenceSession, inputIds: number[]): Promise<ort.Tensor> {
  const seqLen = inputIds.length;
  const idData = BigInt64Array.from(inputIds, (id) => BigInt(id));
  const feeds = { input_ids: new ort.Tensor('int64', idData, [1, seqLen]) };

  const { encoder_out: encoderOut } = await sess.run(feeds);
  return encoderOut;
}
/** Result of a single `stepDecoder` call. */
export interface DecoderStepOut {
  /** The decoder session's `logits` output. */
  logits: ort.Tensor;
  /**
   * The decoder session's `present_self_kv` output.
   * NOTE(review): presumably passed back as `pastSelfKv` on the next step;
   * confirm against the caller.
   */
  presentSelfKv: ort.Tensor;
}
/**
 * Runs the decoder session for exactly one input token.
 *
 * @param sess - Decoder inference session.
 * @param decoderInputId - Token id for this step; sent as an `int64` tensor
 *   of dims `[1, 1]` under the input name `decoder_input_ids`.
 * @param encoderOut - Tensor fed under the input name `encoder_out`.
 * @param pastSelfKv - Tensor fed under the input name `past_self_kv`.
 * @returns The session's `logits` and `present_self_kv` outputs.
 */
export async function stepDecoder(
  sess: ort.InferenceSession,
  decoderInputId: number,
  encoderOut: ort.Tensor,
  pastSelfKv: ort.Tensor,
): Promise<DecoderStepOut> {
  const tokenData = new BigInt64Array([BigInt(decoderInputId)]);
  const feeds = {
    decoder_input_ids: new ort.Tensor('int64', tokenData, [1, 1]),
    encoder_out: encoderOut,
    past_self_kv: pastSelfKv,
  };

  const results = await sess.run(feeds);
  return {
    logits: results.logits,
    presentSelfKv: results.present_self_kv,
  };
}
/**
 * Builds the zero-length `past_self_kv` tensor used to seed decoding at step 0.
 *
 * Dims are (8, 2, 1, 4, 0, 64), following the ONNX export's dynamic axes:
 * (layers, k|v, batch, kv_heads, past_seq=0, head_dim). Because `past_seq`
 * is 0, the tensor holds no elements.
 */
export function initialPastKv(): ort.Tensor {
  const dims = [8, 2, 1, 4, 0, 64];
  const noData = new Float32Array(0);
  return new ort.Tensor('float32', noData, dims);
}