import { env, AutoTokenizer, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';
import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";

// Browser demo: run Qwen2-VL-2B-Instruct, exported as five ONNX graphs
// (QwenVL_A .. QwenVL_E), on an example or user-uploaded image and show the
// generated text in the status element.

const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const INPUT_IMAGE_SIZE = [960, 960];
const HEIGHT_FACTOR = 10;
const WIDTH_FACTOR = 10;
const IMAGE_EMBED_SIZE = WIDTH_FACTOR * HEIGHT_FACTOR;
const MAX_SEQ_LENGTH = 1024;
const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
const QUANT = "q4f16";
const MAX_SINGLE_CHAT_LENGTH = 10;

// Token ids that stop generation.
// NOTE(review): assumed to be Qwen2's <|endoftext|> and <|im_end|> — confirm in tokenizer.json.
const STOP_TOKEN_IDS = new Set([151643, 151645]);
// Number of prompt tokens that precede the image embedding (fed to graph D).
// NOTE(review): presumably the token count of "\n<|im_start|>user\n<|vision_start|>".
const PROMPT_HEAD_LEN = 5;
// Added to the full prompt length to get the first decode-step pos_factor when an
// image is present. Presumably the IMAGE_EMBED_SIZE image tokens advance the
// position only by WIDTH_FACTOR — TODO confirm against the ONNX export.
const VISION_POS_FACTOR_OFFSET = BigInt(1 - IMAGE_EMBED_SIZE + WIDTH_FACTOR);
// float16 bit pattern of -65504 (most negative finite half): initial attention_mask.
const FLOAT16_LOWEST_BITS = 0xfbff;

// UI Elements
const status = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');

// Scratch buffers for reinterpreting a float32 as its raw bits.
const f32Scratch = new Float32Array(1);
const u32Scratch = new Uint32Array(f32Scratch.buffer);

/**
 * Encode a number/bigint as IEEE-754 binary16 bits (round to nearest, ties to even).
 * Integers with magnitude <= 2048 are represented exactly.
 * @param {bigint|number} value
 * @returns {number} 16-bit pattern suitable for a Uint16Array element.
 */
function int64ToFloat16(value) {
  f32Scratch[0] = Number(value);
  const bits = u32Scratch[0];
  const sign = (bits >>> 16) & 0x8000;
  const exponent = (bits >>> 23) & 0xff;
  let mantissa = bits & 0x7fffff;

  if (exponent === 0xff) {
    return sign | 0x7c00 | (mantissa ? 0x200 : 0); // Infinity / NaN
  }
  const halfExponent = exponent - 127 + 15;
  if (halfExponent >= 0x1f) {
    return sign | 0x7c00; // overflow -> Infinity
  }
  if (halfExponent <= 0) {
    if (halfExponent < -10) {
      return sign; // too small even for a subnormal -> signed zero
    }
    mantissa |= 0x800000; // make the implicit leading 1 explicit
    const shift = 14 - halfExponent;
    let half = mantissa >>> shift;
    const remainder = mantissa & ((1 << shift) - 1);
    const halfway = 1 << (shift - 1);
    if (remainder > halfway || (remainder === halfway && (half & 1))) half += 1;
    return sign | half;
  }
  let half = sign | (halfExponent << 10) | (mantissa >>> 13);
  const remainder = mantissa & 0x1fff;
  // A carry out of the mantissa correctly bumps the exponent (up to Infinity).
  if (remainder > 0x1000 || (remainder === 0x1000 && (half & 1))) half += 1;
  return half;
}

const int64Scalar = (value) => new ort.Tensor("int64", BigInt64Array.of(BigInt(value)), [1]);
const int32Scalar = (value) => new ort.Tensor("int32", Int32Array.of(value), [1]);
const float16Scalar = (value) => new ort.Tensor("float16", Uint16Array.of(int64ToFloat16(value)), [1]);

/**
 * Memoize an async factory: concurrent callers share one pending promise, a
 * failed attempt is forgotten so the next get() retries, and reset() drops the
 * cached promise (returning it) so the next get() creates a fresh value.
 * @template T
 * @param {() => Promise<T>} factory
 * @returns {{ get: () => Promise<T>, reset: () => (Promise<T> | null) }}
 */
function lazy(factory) {
  let pending = null;
  return {
    get() {
      if (!pending) {
        const attempt = factory().catch((err) => {
          if (pending === attempt) pending = null;
          throw err;
        });
        pending = attempt;
      }
      return pending;
    },
    reset() {
      const current = pending;
      pending = null;
      return current;
    },
  };
}

/** Download one of the exported graphs and create an ORT session for it. */
async function createSession(part, executionProvider = "webgpu") {
  const model = await getModelFile(ONNX_MODEL, `onnx/QwenVL_${part}_${QUANT}.onnx`);
  return ort.InferenceSession.create(model, { executionProviders: [executionProvider] });
}

/** Release the session held by a lazy() holder (if any) so its memory is freed. */
async function releaseSession(holder) {
  const pending = holder.reset();
  if (pending) {
    const session = await pending;
    await session.release();
  }
}

const sessionA = lazy(() => createSession("A"));
const sessionB = lazy(() => createSession("B"));
const sessionC = lazy(() => createSession("C"));
// Graph E was run on the wasm execution provider in the original demo.
const sessionE = lazy(() => createSession("E", "wasm"));
const tokenizerLoader = lazy(() => AutoTokenizer.from_pretrained(BASE_MODEL));
const configLoader = lazy(() => getModelJSON(BASE_MODEL, "config.json"));

/** Eagerly load graphs A, B and C, reporting progress in the status element. */
async function initializeSessions() {
  status.textContent = 'Loading model...';
  const ortSessionA = await sessionA.get();
  console.log({ ortSessionA });
  const ortSessionB = await sessionB.get();
  console.log({ ortSessionB });
  const ortSessionC = await sessionC.get();
  console.log({ ortSessionC });
  status.textContent = 'Ready';
}

// UI Event Handlers
example.addEventListener('click', (e) => {
  e.preventDefault();
  parse(EXAMPLE_URL, 'Describe this image.');
});

fileUpload.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;
  const reader = new FileReader();
  reader.onload = (e2) => parse(e2.target.result, '');
  reader.readAsDataURL(file);
});

// Requests are chained so two clicks never run the shared sessions concurrently.
let parseQueue = Promise.resolve();

/**
 * Show `img` in the container and replace the status text with the model's answer.
 * @param {string} img - image URL or data URL.
 * @param {string} txt - user query.
 * @returns {Promise<void>} settles once this request has been processed.
 */
function parse(img, txt) {
  const run = async () => {
    imageContainer.innerHTML = '';
    imageContainer.style.backgroundImage = `url(${img})`;
    status.textContent = 'Analysing...';
    try {
      status.textContent = await imageTextToText(img, txt);
    } catch (err) {
      console.error(err);
      status.textContent = `Error: ${err.message}`;
    }
  };
  parseQueue = parseQueue.then(run);
  return parseQueue;
}

/** Load, resize and normalise an image into a [1, 3, H, W] float32 ORT tensor in [0, 1]. */
async function loadPixelValues(imagePath) {
  const image = await RawImage.fromURL(imagePath);
  const resized = await image.resize(INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1]);
  const pixels = resized.rgb().toTensor("CHW").to("float32").div_(255.0).unsqueeze(0);
  return new ort.Tensor("float32", pixels.data, pixels.dims);
}

/**
 * Generate text for an image and a query with the split Qwen2-VL ONNX graphs:
 * B turns input_ids into hidden_states, C yields initial position_ids, A encodes
 * pixel_values into image_embed, D merges image_embed into hidden_states, and
 * E returns the next token (max_logit_ids) plus the updated key/value cache.
 * @param {string} imagePath - URL or data URL of the image (used when `vision`).
 * @param {string} query - user text inserted into the chat prompt.
 * @param {boolean} [vision=true] - when false, graphs A and D are skipped.
 * @returns {Promise<string>} decoded text of the generated tokens (stop tokens excluded).
 * @throws {Error} if the prompt does not fit in MAX_SEQ_LENGTH or a token id is not an integer.
 */
export async function imageTextToText(imagePath, query, vision = true) {
  const [config, tokenizer, ortSessionB, ortSessionC, ortSessionE] = await Promise.all([
    configLoader.get(),
    tokenizerLoader.get(),
    sessionB.get(),
    sessionC.get(),
    sessionE.get(),
  ]);

  // Zero-initialised key/value cache; keys and values start out sharing one
  // all-zero buffer and are replaced by E's outputs after every step.
  const headDim = config.hidden_size / config.num_attention_heads;
  const kvDims = [config.num_hidden_layers, config.num_key_value_heads, MAX_SEQ_LENGTH, headDim];
  const kvSize = kvDims.reduce((acc, dim) => acc * dim, 1);
  let pastKeyStates = new ort.Tensor("float16", new Uint16Array(kvSize), kvDims);
  let pastValueStates = pastKeyStates;

  console.groupCollapsed("[TOKENIZATION] Processing prompt...");
  const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
  const { input_ids: promptIds } = await tokenizer(prompt, {
    return_tensors: "pt",
    add_generation_prompt: false,
    tokenize: true,
  });
  const seqLength = promptIds.dims[1];
  console.log("prompt tokens:", seqLength);
  console.groupEnd();
  if (seqLength > MAX_SEQ_LENGTH) {
    throw new Error(`Prompt is ${seqLength} tokens; at most ${MAX_SEQ_LENGTH} are supported`);
  }

  // Fixed-size input buffer: the prompt first, then one generated token at index 0 per step.
  const inputIds = new ort.Tensor("int32", new Int32Array(MAX_SEQ_LENGTH), [MAX_SEQ_LENGTH]);
  inputIds.data.set(Array.from(promptIds.data.slice(0, seqLength), Number));

  // Position bookkeeping is kept in BigInt and converted to tensors per run,
  // so it never accumulates float16 rounding error.
  let idsLen = BigInt(seqLength);
  let historyLen = 0n;
  let posFactor = 0n;

  let { hidden_states: hiddenStates } = await ortSessionB.run({
    input_ids: inputIds,
    ids_len: int64Scalar(idsLen),
  });
  let { position_ids: positionIds } = await ortSessionC.run({
    dummy: new ort.Tensor("int32", new Int32Array([0]), []),
  });

  if (vision) {
    const pixelValues = await loadPixelValues(imagePath);
    const ortSessionA = await sessionA.get();
    const { image_embed: imageEmbed } = await ortSessionA.run({ pixel_values: pixelValues });
    // Free graph A before loading D; sessionA.get() recreates it on the next call.
    await releaseSession(sessionA);

    idsLen += BigInt(IMAGE_EMBED_SIZE);
    const splitFactor = int32Scalar(MAX_SEQ_LENGTH - Number(idsLen) - IMAGE_EMBED_SIZE);
    const idsLenMinus = int32Scalar(Number(idsLen) - PROMPT_HEAD_LEN);

    console.log("session d create");
    const ortSessionD = await createSession("D");
    try {
      ({ hidden_states: hiddenStates, position_ids: positionIds } = await ortSessionD.run({
        "hidden_states.1": hiddenStates,
        image_embed: imageEmbed,
        ids_len: int64Scalar(idsLen),
        ids_len_minus: idsLenMinus,
        split_factor: splitFactor,
      }));
    } finally {
      await ortSessionD.release();
    }
  }

  let attentionMask = new ort.Tensor("float16", Uint16Array.of(FLOAT16_LOWEST_BITS), [1]);
  const generated = [];
  let numDecode = 0;

  while (numDecode < MAX_SINGLE_CHAT_LENGTH && historyLen < BigInt(MAX_SEQ_LENGTH)) {
    const step = await ortSessionE.run({
      hidden_states: hiddenStates,
      attention_mask: attentionMask,
      "past_key_states.1": pastKeyStates,
      "past_value_states.1": pastValueStates,
      history_len: int64Scalar(historyLen),
      ids_len: int64Scalar(idsLen),
      position_ids: positionIds,
      pos_factor: float16Scalar(posFactor),
    });
    pastKeyStates = step.past_key_states;
    pastValueStates = step.past_value_states;

    const tokenId = Number(step.max_logit_ids.data[0]);
    if (!Number.isInteger(tokenId)) {
      throw new Error(`Token ID is not an integer`);
    }
    if (STOP_TOKEN_IDS.has(tokenId)) {
      break;
    }
    numDecode += 1;

    if (numDecode === 1) {
      // The first step consumed the whole prompt; afterwards one token is fed per step.
      historyLen += idsLen;
      idsLen = 1n;
      attentionMask = new ort.Tensor("float16", new Uint16Array([0]), [1]);
      posFactor = vision ? VISION_POS_FACTOR_OFFSET + historyLen : historyLen + 1n;
    } else {
      historyLen += 1n;
      posFactor += 1n;
    }

    generated.push(tokenId);
    inputIds.data[0] = tokenId;
    ({ hidden_states: hiddenStates } = await ortSessionB.run({
      input_ids: inputIds,
      ids_len: int64Scalar(idsLen),
    }));
  }

  // Decode all tokens together so multi-byte characters split across tokens survive.
  return tokenizer.decode(generated);
}

try {
  await initializeSessions();
} catch (err) {
  console.error(err);
  status.textContent = `Failed to load model: ${err.message}`;
}