|
import { env, AutoTokenizer, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers'; |
|
import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js"; |
|
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs"; |
|
|
|
// Image used when the user clicks the "example" button.
const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";

// Every input image is resized to this [width, height] before it is encoded.
const INPUT_IMAGE_SIZE = [960, 960];

// Image-embedding grid; their product is the number of positions the image
// adds to `ids_len` during generation.
const [HEIGHT_FACTOR, WIDTH_FACTOR] = [10, 10];
const IMAGE_EMBED_SIZE = HEIGHT_FACTOR * WIDTH_FACTOR;

// Fixed sequence capacity: length of the padded input_ids buffer and of the
// KV-cache sequence axis.
const MAX_SEQ_LENGTH = 1024;

// Tokenizer and config.json come from the base model repo; the ONNX graphs
// come from the quantized export repo.
const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";

// Suffix of the ONNX file names: onnx/QwenVL_<part>_<QUANT>.onnx
const QUANT = "q4f16";

// Upper bound on decode steps (generated tokens) per query.
const MAX_SINGLE_CHAT_LENGTH = 10;
|
|
|
|
|
// Shorthand for the element lookups below.
const byId = (id) => document.getElementById(id);

// Page elements. `exampleButton` and `example` refer to the same element.
const exampleButton = byId('example');
const promptInput = document.querySelector('input[type="text"]');
const status = byId('status');
const imageContainer = byId('container');
const example = byId('example');
const thumb = byId('thumb');
const uploadInput = byId('upload');
const form = byId('form');
const output = byId('llm-output');

// ONNX sessions for the five exported model parts; assigned by initializeSessions().
let ortSessionA;
let ortSessionB;
let ortSessionC;
let ortSessionD;
let ortSessionE;

// Parsed config.json of BASE_MODEL; assigned by initializeSessions().
let config;

// Currently selected image (URL or data URL) and prompt text.
let currentImage = '';
let currentQuery = '';
|
|
|
/**
 * Creates the five ONNX Runtime sessions (parts A–E of the exported model)
 * and loads the base model's config.json into the module-level bindings.
 *
 * Parts A–D run on the "webgpu" execution provider; part E runs on "wasm".
 * Progress is reported in the status element. On failure, the status shows an
 * error and the original error is rethrown.
 *
 * @returns {Promise<void>}
 */
async function initializeSessions() {
  status.textContent = 'Loading model...';

  // Downloads onnx/QwenVL_<part>_<QUANT>.onnx from ONNX_MODEL and opens it
  // with a single execution provider.
  const createSession = async (part, executionProvider) =>
    ort.InferenceSession.create(
      await getModelFile(ONNX_MODEL, `onnx/QwenVL_${part}_${QUANT}.onnx`),
      { executionProviders: [executionProvider] },
    );

  try {
    // Loaded one after another rather than in parallel.
    // NOTE(review): presumably to bound peak memory during load — confirm.
    ortSessionA = await createSession('A', 'webgpu');
    ortSessionB = await createSession('B', 'webgpu');
    ortSessionC = await createSession('C', 'webgpu');
    ortSessionD = await createSession('D', 'webgpu');
    ortSessionE = await createSession('E', 'wasm');

    config = await getModelJSON(BASE_MODEL, "config.json");
  } catch (err) {
    // Without this the page would be stuck on "Loading model..." forever.
    status.textContent = 'Error loading model';
    throw err;
  }

  status.textContent = 'Ready';
}
|
|
|
/**
 * Converts an integer (BigInt or Number) to the bit pattern of the nearest
 * IEEE 754 binary16 (float16) value.
 *
 * Rounds to nearest, ties to even. Magnitudes that round past 65504 become
 * ±Infinity. Non-integer Numbers are also accepted, including values in the
 * subnormal range.
 *
 * @param {bigint|number} int64Value - Value to convert.
 * @returns {number} Unsigned 16-bit float16 bit pattern (0 for zero, 0x7e00 for NaN).
 */
export function int64ToFloat16(int64Value) {
  const value = Number(int64Value);

  if (Number.isNaN(value)) return 0x7e00; // canonical quiet NaN
  if (value === 0) return 0;

  const signBit = value < 0 ? 0x8000 : 0;
  const abs = Math.abs(value);
  if (abs === Infinity) return signBit | 0x7c00;

  // Unbiased binary exponent: floor(log2(abs)), corrected in case Math.log2
  // rounds across a power-of-two boundary.
  let exponent = Math.floor(Math.log2(abs));
  if (Math.pow(2, exponent) > abs) {
    exponent -= 1;
  } else if (Math.pow(2, exponent + 1) <= abs) {
    exponent += 1;
  }

  if (exponent > 15) return signBit | 0x7c00; // too large: ±Infinity

  // Values below 2^-14 are subnormal and share the minimum exponent.
  exponent = Math.max(exponent, -14);

  // Significand in units of 2^(exponent - 10): [1024, 2048) for normal values,
  // [0, 1024) for subnormals. Dividing by a power of two is exact.
  const scaled = abs / Math.pow(2, exponent - 10);
  let significand = Math.floor(scaled);
  const remainder = scaled - significand;
  if (remainder > 0.5 || (remainder === 0.5 && significand % 2 === 1)) {
    significand += 1; // round half to even
  }

  // Rounding up may overflow into the next binade.
  if (significand === 2048) {
    significand = 1024;
    exponent += 1;
    if (exponent > 15) return signBit | 0x7c00;
  }

  if (significand < 1024) {
    // Subnormal (exponent field 0), or underflow to a signed zero.
    return signBit | significand;
  }
  return signBit | ((exponent + 15) << 10) | (significand - 1024);
}
|
|
|
/**
 * Decodes a float16 bit pattern and rounds it (with Math.round) to a BigInt.
 *
 * @param {number} float16Value - Unsigned 16-bit float16 bit pattern.
 * @returns {bigint} The decoded value rounded to an integer.
 * @throws {RangeError} If the pattern encodes ±Infinity or NaN, which BigInt
 *   cannot represent.
 */
export function float16ToInt64(float16Value) {
  const sign = (float16Value & 0x8000) >> 15;
  const exponent = (float16Value & 0x7c00) >> 10;
  const mantissa = float16Value & 0x03ff;

  // ±0
  if (exponent === 0 && mantissa === 0) return 0n;

  // All-ones exponent: Infinity (mantissa 0) or NaN. BigInt has neither, so
  // fail with a clear message instead of BigInt("Infinity")'s SyntaxError.
  if (exponent === 0x1f) {
    const kind = mantissa === 0 ? (sign ? '-Infinity' : 'Infinity') : 'NaN';
    const hex = (float16Value & 0xffff).toString(16).padStart(4, '0');
    throw new RangeError(`Cannot convert float16 ${kind} (0x${hex}) to int64`);
  }

  // Subnormal values have an implicit leading 0 and a fixed exponent of -14;
  // normal values have an implicit leading 1.
  const magnitude = exponent === 0
    ? Math.pow(2, -14) * (mantissa / 1024)
    : Math.pow(2, exponent - 15) * (1 + mantissa / 1024);

  const value = sign ? -magnitude : magnitude;
  return BigInt(Math.round(value));
}
|
|
|
|
|
// True while a generation started by handleQuery() is still running.
let isQueryRunning = false;

/**
 * Runs one image + prompt query through imageTextToText(). The decoded text
 * is streamed into the output element and progress is shown in the status
 * element.
 *
 * A new query is ignored while a previous one is still running, so two
 * generations never write into the output element at the same time. Errors
 * are reported in the status element and logged; they are not rethrown.
 *
 * @param {string} imageUrl - Image URL or data URL to analyze.
 * @param {string} query - Prompt text.
 * @returns {Promise<void>}
 */
async function handleQuery(imageUrl, query) {
  console.log('handleQuery', {imageUrl}, {query});

  if (isQueryRunning) {
    status.textContent = 'Still analyzing the previous request...';
    return;
  }

  isQueryRunning = true;
  output.textContent = '';

  try {
    status.textContent = 'Analyzing...';

    const result = await imageTextToText(imageUrl, query, (text) => {
      output.textContent = text;
    });

    // Show the complete answer and signal that a new query can be submitted.
    output.textContent = result;
    status.textContent = 'Ready';
  } catch (err) {
    status.textContent = 'Error processing request';
    console.error(err);
  } finally {
    isQueryRunning = false;
  }
}
|
|
|
|
|
/**
 * Greedy text generation for one (optional) image plus a text query, using
 * the ONNX sessions created by initializeSessions().
 *
 * Session roles, as used below:
 *   - B: `input_ids` + `ids_len` -> `hidden_states`
 *   - C: `dummy` -> initial `position_ids`
 *   - A: `pixel_values` -> `image_embed` (only when `vision` is true)
 *   - D: combines `hidden_states` with `image_embed` and returns new
 *        `hidden_states` / `position_ids` (only when `vision` is true)
 *   - E: one decode step -> `max_logit_ids` plus updated key/value caches
 *
 * Generation stops after MAX_SINGLE_CHAT_LENGTH tokens, when the history
 * reaches MAX_SEQ_LENGTH, or when token id 151643 or 151645 is produced
 * (presumably Qwen2's <|endoftext|> / <|im_end|>).
 *
 * @param {string} imagePath - Image URL or data URL; unused when `vision` is false.
 * @param {string} query - User prompt text.
 * @param {(text: string) => void} [cb] - Called after each generated token
 *   with the full text decoded so far.
 * @param {boolean} [vision=true] - Whether to encode the image and feed it to the model.
 * @returns {Promise<string>} The decoded generated text.
 * @throws {Error} If the model returns a token id that is not an integer.
 */
export async function imageTextToText(
  imagePath,
  query,
  cb,
  vision = true,
) {
  // NOTE(review): presumably the number of prompt tokens before the image
  // embeddings ("\n<|im_start|>user\n<|vision_start|>") — confirm against the export.
  const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);

  let position_ids;
  let num_decode = 0;
  let history_len = new Tensor("int64", new BigInt64Array([0n]), [1]);

  // Offset used to build pos_factor for the first decode step when an image is present.
  const pos_factor_v = BigInt(1 - IMAGE_EMBED_SIZE + WIDTH_FACTOR);

  // Zero-initialised (Uint16Array starts at 0) float16 KV cache of shape
  // [num_hidden_layers, num_key_value_heads, MAX_SEQ_LENGTH, head_dim].
  const head_dim = config.hidden_size / config.num_attention_heads;
  let past_key_states = new ort.Tensor(
    "float16",
    new Uint16Array(
      config.num_hidden_layers * config.num_key_value_heads * MAX_SEQ_LENGTH * head_dim
    ),
    [config.num_hidden_layers, config.num_key_value_heads, MAX_SEQ_LENGTH, head_dim]
  );
  let past_value_states = past_key_states;

  // 0xfbff is the float16 bit pattern of -65504 (the most negative finite
  // half); the mask is switched to 0 after the first decode step.
  let attention_mask = new ort.Tensor("float16", new Uint16Array([0xfbff]), [1]);
  let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);

  const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);
  const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
  const token = await tokenizer(prompt, {
    return_tensors: "pt",
    add_generation_prompt: false,
    tokenize: true,
  }).input_ids;

  const seq_length = token.dims[1];
  let ids_len = new Tensor("int64", new BigInt64Array([BigInt(seq_length)]), [1]);

  // Fixed-size MAX_SEQ_LENGTH buffer: the leading entries hold the prompt
  // tokens, the rest stays zero. Later steps overwrite only entry 0.
  const input_ids = new ort.Tensor(
    "int32",
    new Int32Array(MAX_SEQ_LENGTH),
    [MAX_SEQ_LENGTH]
  );
  input_ids.data.set(Array.from(token.data.slice(0, seq_length), Number));

  const dummy = new ort.Tensor("int32", new Int32Array([0]), []);

  let { hidden_states } = await ortSessionB.run({ input_ids, ids_len });
  ({ position_ids } = await ortSessionC.run({ dummy }));

  if (vision) {
    let image = await RawImage.fromURL(imagePath);
    image = await image.resize(INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1]);
    image = image.rgb();
    image = image.toTensor("CHW");
    image = image.to("float32");
    image = image.div_(255.0);
    const pixel_values = image.unsqueeze(0);

    const { image_embed } = await ortSessionA.run({ pixel_values });

    ids_len = ids_len.add(BigInt(IMAGE_EMBED_SIZE));

    const split_factor = new Tensor(
      "int32",
      new Int32Array([MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE]),
      [1]
    );

    const ids_len_minus = new Tensor(
      "int32",
      new Int32Array([Number(ids_len.item()) - Number(prompt_head_len.item())]),
      [1]
    );

    // Sessions A and D are intentionally kept alive: releasing and nulling
    // them here would make every later vision query fail on a null session.
    ({ hidden_states, position_ids } = await ortSessionD.run({
      "hidden_states.1": hidden_states,
      image_embed,
      ids_len,
      ids_len_minus,
      split_factor,
    }));
  }

  // Ids are accumulated and decoded together so that multi-token characters
  // are not split into broken fragments by per-token decoding.
  const generatedIds = [];
  let text = '';

  while (
    num_decode < MAX_SINGLE_CHAT_LENGTH &&
    Number(history_len.data[0]) < MAX_SEQ_LENGTH
  ) {
    let token_id;
    ({
      max_logit_ids: token_id,
      past_key_states,
      past_value_states,
    } = await ortSessionE.run({
      hidden_states,
      attention_mask,
      "past_key_states.1": past_key_states,
      "past_value_states.1": past_value_states,
      history_len,
      ids_len,
      position_ids,
      pos_factor,
    }));

    const rawId = token_id.data[0];
    if (typeof rawId !== "bigint" && !Number.isInteger(rawId)) {
      throw new Error(`Token ID is not an integer`);
    }
    const tokenId = Number(rawId);

    // Compare the scalar id: `token_id` is a Tensor object and would never
    // be strictly equal to a number.
    if (tokenId === 151643 || tokenId === 151645) {
      break;
    }

    num_decode++;
    if (num_decode < 2) {
      // After the prompt step, every following step feeds exactly one token.
      history_len = history_len.add(BigInt(ids_len.data[0]));
      ids_len = new ort.Tensor("int64", new BigInt64Array([1n]), [1]);
      attention_mask = new ort.Tensor("float16", new Uint16Array([0]), [1]);

      if (vision) {
        // NOTE(review): ids_len was just reset to 1, so this is always
        // pos_factor_v + 1 — confirm against the reference export.
        pos_factor = new Tensor(
          "float16",
          new Uint16Array([int64ToFloat16(pos_factor_v + ids_len.data[0])]),
          [1]
        );
      } else {
        pos_factor = new Tensor(
          "float16",
          new Uint16Array([int64ToFloat16(history_len.data[0] + BigInt(1))]),
          [1]
        );
      }
    } else {
      history_len = history_len.add(BigInt(1));
      pos_factor = pos_factor.map((v) =>
        int64ToFloat16(float16ToInt64(v) + BigInt(1))
      );
    }

    generatedIds.push(rawId);
    text = tokenizer.decode(generatedIds);
    // Report the text including the token just generated.
    cb?.(text);

    // Embed the new token for the next decode step.
    input_ids.data[0] = tokenId;
    ({ hidden_states } = await ortSessionB.run({ input_ids, ids_len }));
  }

  return text;
}
|
|
|
/**
 * Displays `url` as the thumbnail background. The image's aspect ratio is
 * preserved and its longer side is 320px. Any existing thumbnail content is
 * cleared.
 *
 * @param {string} url - Image URL or data URL.
 * @returns {Promise<void>}
 */
async function updatePreview(url) {
  const { width, height } = await RawImage.fromURL(url);
  const aspectRatio = width / height;

  let previewWidth;
  let previewHeight;
  if (aspectRatio > 1) {
    // Landscape: pin the width.
    previewWidth = 320;
    previewHeight = 320 / aspectRatio;
  } else {
    // Portrait or square: pin the height.
    previewWidth = 320 * aspectRatio;
    previewHeight = 320;
  }

  Object.assign(thumb.style, {
    width: `${previewWidth}px`,
    height: `${previewHeight}px`,
    backgroundImage: `url(${url})`,
  });
  thumb.innerHTML = '';
}
|
|
|
// Load the model before wiring up the UI, so queries cannot start without sessions.
await initializeSessions();

// Logs preview failures instead of leaving a rejected promise unhandled.
const showPreview = (url) => {
  updatePreview(url).catch((err) => console.error(err));
};

exampleButton.addEventListener('click', (e) => {
  e.preventDefault();
  e.stopPropagation();
  currentImage = EXAMPLE_URL;
  // Show the example like an uploaded image, so the user sees what will be analyzed.
  showPreview(currentImage);
});

uploadInput.addEventListener('change', (e) => {
  console.log('upload change');
  const file = e.target.files?.[0];
  if (!file) return;

  const reader = new FileReader();
  reader.onload = (e2) => {
    currentImage = e2.target.result;
    showPreview(currentImage);
  };
  reader.onerror = () => {
    status.textContent = 'Could not read the selected file';
    console.error(reader.error);
  };
  reader.readAsDataURL(file);
});

// 'input' fires after the value has changed, including for paste and delete.
// 'keypress' fires before the typed character is inserted and not at all for
// those edits.
promptInput.addEventListener('input', (e) => {
  currentQuery = e.target.value;
});

form.addEventListener('submit', (e) => {
  e.preventDefault();

  // The field's current value is the source of truth at submit time.
  currentQuery = promptInput.value;

  if (!currentImage || !currentQuery) {
    status.textContent = 'Please select an image and type a prompt';
  } else {
    handleQuery(currentImage, currentQuery);
  }
});