Spaces:
Running
Running
| import * as ort from 'onnxruntime-web/wasm'; | |
| import { ENCODER_URL, DECODER_URL } from './config'; | |
// onnxruntime-web loads its .wasm binaries from this base path; vite.config.ts
// copies those files into /ort/.
// Running single-threaded keeps ORT from dynamically importing its .mjs worker
// shim, which Vite's dev server blocks for files served from /public.
const ortWasmEnv = ort.env.wasm;
ortWasmEnv.wasmPaths = '/ort/';
ortWasmEnv.numThreads = 1;
/** The encoder/decoder session pair produced by `loadSessions`. */
export interface NeedleSessions {
  /** Inference session created from `ENCODER_URL`. */
  encoder: ort.InferenceSession;
  /** Inference session created from `DECODER_URL`. */
  decoder: ort.InferenceSession;
}

/**
 * Create the encoder and decoder inference sessions on ORT's WASM backend.
 *
 * The models are loaded one after the other (encoder first), and `onProgress`
 * is called with a short status message before each one.
 *
 * @param onProgress Optional status callback.
 * @returns Both sessions. Call `release()` on each when they are no longer needed.
 * @throws Rethrows any rejection from `ort.InferenceSession.create`. If the
 *   decoder fails to load, the already-created encoder session is released
 *   first so it is not leaked.
 */
export async function loadSessions(onProgress?: (m: string) => void): Promise<NeedleSessions> {
  onProgress?.('downloading encoder…');
  const encoder = await ort.InferenceSession.create(ENCODER_URL, { executionProviders: ['wasm'] });

  onProgress?.('downloading decoder…');
  let decoder: ort.InferenceSession;
  try {
    decoder = await ort.InferenceSession.create(DECODER_URL, { executionProviders: ['wasm'] });
  } catch (err: unknown) {
    // Free the encoder before propagating; a failure during this best-effort
    // cleanup is ignored so it cannot mask the original load error.
    await encoder.release().catch(() => undefined);
    throw err;
  }
  return { encoder, decoder };
}
/**
 * Run the encoder on a single sequence of token ids (batch size 1).
 *
 * @param sess Encoder inference session.
 * @param inputIds Token ids, fed as the `input_ids` input with shape
 *   `[1, inputIds.length]`. Each id must be an integer; `BigInt` throws a
 *   `RangeError` otherwise.
 * @returns The session's `encoder_out` output tensor.
 * @throws Error if the session result has no `encoder_out` output.
 */
export async function runEncoder(sess: ort.InferenceSession, inputIds: readonly number[]): Promise<ort.Tensor> {
  // int64 tensors are backed by a BigInt64Array; convert in one pass rather
  // than building an intermediate bigint[].
  const ids = new ort.Tensor(
    'int64',
    BigInt64Array.from(inputIds, (id) => BigInt(id)),
    [1, inputIds.length],
  );
  const out = await sess.run({ input_ids: ids });
  const encoderOut = out.encoder_out;
  // `run` returns a string-keyed map; a missing/renamed output would otherwise
  // surface later as an opaque `undefined` failure.
  if (encoderOut === undefined) {
    throw new Error(`encoder produced no "encoder_out" output (got: ${Object.keys(out).join(', ')})`);
  }
  return encoderOut;
}
/** Outputs of a single decoder step. */
export interface DecoderStepOut {
  /** The decoder's `logits` output for this step. */
  logits: ort.Tensor;
  /**
   * The decoder's `present_self_kv` output. Presumably passed back as
   * `pastSelfKv` on the next step — confirm against the decoding loop.
   */
  presentSelfKv: ort.Tensor;
}

/**
 * Run one decoder step for a single token (batch size 1).
 *
 * @param sess Decoder inference session.
 * @param decoderInputId Token id for this step, fed as `decoder_input_ids`
 *   with shape `[1, 1]`. Must be an integer; `BigInt` throws a `RangeError`
 *   otherwise.
 * @param encoderOut Fed unchanged as the `encoder_out` input (presumably the
 *   tensor returned by `runEncoder`).
 * @param pastSelfKv Fed as the `past_self_kv` input, e.g. `initialPastKv()` on step 0.
 * @returns The session's `logits` and `present_self_kv` outputs.
 * @throws Error if either expected output is missing from the session result.
 */
export async function stepDecoder(
  sess: ort.InferenceSession,
  decoderInputId: number,
  encoderOut: ort.Tensor,
  pastSelfKv: ort.Tensor,
): Promise<DecoderStepOut> {
  const dec = new ort.Tensor('int64', BigInt64Array.of(BigInt(decoderInputId)), [1, 1]);
  const out = await sess.run({
    decoder_input_ids: dec,
    encoder_out: encoderOut,
    past_self_kv: pastSelfKv,
  });
  const logits = out.logits;
  const presentSelfKv = out.present_self_kv;
  // `run` returns a string-keyed map; fail loudly instead of handing
  // `undefined` typed as a Tensor back to the caller.
  if (logits === undefined || presentSelfKv === undefined) {
    throw new Error(
      `decoder produced incomplete outputs; expected "logits" and "present_self_kv" (got: ${Object.keys(out).join(', ')})`,
    );
  }
  return { logits, presentSelfKv };
}
/**
 * Build the empty `past_self_kv` tensor fed to the decoder on step 0.
 *
 * Dimensions follow the ONNX export's dynamic axes:
 * (layers, k|v, batch, kv_heads, past_seq, head_dim) = (8, 2, 1, 4, 0, 64).
 * Because past_seq is 0, the backing float32 buffer holds no elements.
 */
export function initialPastKv(): ort.Tensor {
  const layers = 8;
  const keyAndValue = 2;
  const batch = 1;
  const kvHeads = 4;
  const pastSeq = 0;
  const headDim = 64;
  const dims = [layers, keyAndValue, batch, kvHeads, pastSeq, headDim];
  const elementCount = dims.reduce((acc, d) => acc * d, 1);
  return new ort.Tensor('float32', new Float32Array(elementCount), dims);
}