// index.js — Qwen2-VL-2B-Instruct ONNX (q4f16) browser demo, WebGPU via onnxruntime-web.
// Author: pdufour — "Update index.js" (commit 9a35308, verified), 8.01 kB.
import { env, AutoTokenizer, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';
import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";
// Demo image shown when the user clicks the "example" link.
const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
// Every input image is resized to this fixed resolution before encoding.
const INPUT_IMAGE_SIZE = [960, 960];
// The vision encoder emits a HEIGHT_FACTOR x WIDTH_FACTOR grid of patch embeddings.
const HEIGHT_FACTOR = 10;
const WIDTH_FACTOR = 10;
// Number of image-embedding slots spliced into the token sequence.
const IMAGE_EMBED_SIZE = WIDTH_FACTOR * HEIGHT_FACTOR;
// Fixed size of the decoder's token buffer and KV cache.
const MAX_SEQ_LENGTH = 1024;
// Tokenizer comes from the base model; weights from the split ONNX export.
const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
const QUANT = "q4f16";
// Hard cap on generated tokens per reply (demo keeps answers short).
const MAX_SINGLE_CHAT_LENGTH = 10;
// UI Elements
const status = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');
// ONNX Runtime sessions for model parts A (vision), B (embed), C (position ids);
// created once by initializeSessions() below.
let ortSessionA, ortSessionB, ortSessionC;
/**
 * Creates the three always-resident ONNX Runtime WebGPU sessions (parts A, B
 * and C of the split QwenVL model) and updates the status line.
 *
 * The three sub-models are independent of each other, so they are downloaded
 * and instantiated in parallel instead of serially.
 *
 * @returns {Promise<void>}
 * @throws re-throws any download/initialization error after reporting it in
 *   the status element (the original left the UI stuck on "Loading model...").
 */
async function initializeSessions() {
  status.textContent = 'Loading model...';
  try {
    [ortSessionA, ortSessionB, ortSessionC] = await Promise.all(
      ["A", "B", "C"].map(async (part) =>
        ort.InferenceSession.create(
          await getModelFile(ONNX_MODEL, `onnx/QwenVL_${part}_${QUANT}.onnx`),
          { executionProviders: ["webgpu"] },
        ),
      ),
    );
    console.log({ ortSessionA, ortSessionB, ortSessionC });
    status.textContent = 'Ready';
  } catch (err) {
    status.textContent = 'Failed to load model';
    throw err; // surface the original failure for debugging
  }
}
// UI Event Handlers
// Clicking the example link runs the pipeline on the bundled demo image.
example.addEventListener('click', (event) => {
  event.preventDefault();
  parse(EXAMPLE_URL, 'Describe this image.');
});
// An uploaded file is read as a data URL and fed straight to parse().
fileUpload.addEventListener('change', (event) => {
  const [file] = event.target.files;
  if (!file) return;
  const reader = new FileReader();
  reader.onload = (loadEvent) => parse(loadEvent.target.result, '');
  reader.readAsDataURL(file);
});
/**
 * Runs the vision-language pipeline on one image and shows the result.
 *
 * @param {string} img - image URL or data URL to display and analyse.
 * @param {string} txt - user prompt (may be empty).
 * @returns {Promise<void>}
 */
async function parse(img, txt) {
  imageContainer.innerHTML = '';
  imageContainer.style.backgroundImage = `url(${img})`;
  status.textContent = 'Analysing...';
  try {
    const output = await imageTextToText(img, txt);
    status.textContent = output;
  } catch (err) {
    // Without this, a failed inference left the UI stuck on "Analysing...".
    console.error(err);
    status.textContent = `Error: ${err.message}`;
  }
}
/**
 * Generates text for an image (or continues a text-only prompt) with the
 * five-part split Qwen2-VL ONNX model:
 *   A = vision encoder, B = token embedding, C = initial position ids,
 *   D = image-embedding merge, E = autoregressive decoder with KV cache.
 *
 * @param {string} imagePath - URL / data URL of the image.
 * @param {string} query - user prompt inserted into the chat template.
 * @param {boolean} [vision=true] - when true, encode the image and splice its
 *   embedding into the prompt hidden states.
 * @returns {Promise<string>} the decoded model output.
 */
export async function imageTextToText(
  imagePath,
  query,
  vision = true
) {
  // BUG FIX: the original also re-declared ortSessionA/B/C here, which
  // shadowed the initialized module-level sessions and made every run()
  // call below throw on undefined. Only D and E are created in this scope.
  let ortSessionD, ortSessionE;

  // BUG FIX: `config` was used below but never defined; the hyper-parameters
  // (layer/head counts) come from the repo's config.json. This is what the
  // otherwise-unused getModelJSON import is for.
  const config = await getModelJSON(ONNX_MODEL, "config.json");

  // Length of the fixed chat-template prefix before the vision slot.
  const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
  let position_ids;
  let num_decode = 0;
  let history_len = new Tensor("int64", new BigInt64Array([0n]), [1]);
  // Rotary-position offset used once decoding starts after an image prompt.
  const pos_factor_v = BigInt(1 - IMAGE_EMBED_SIZE + WIDTH_FACTOR);

  // Pre-allocated zero KV cache, fp16, [layers, kv_heads, MAX_SEQ_LENGTH, head_dim].
  let past_key_states = new ort.Tensor(
    "float16",
    new Uint16Array(
      config.num_hidden_layers *
        config.num_key_value_heads *
        MAX_SEQ_LENGTH *
        (config.hidden_size / config.num_attention_heads)
    ).fill(0),
    [
      config.num_hidden_layers,
      config.num_key_value_heads,
      MAX_SEQ_LENGTH,
      config.hidden_size / config.num_attention_heads,
    ]
  );
  let past_value_states = past_key_states;

  // 0xfbff is fp16 -65504 (most negative finite value): "mask everything".
  let attention_mask = new ort.Tensor(
    "float16",
    new Uint16Array([0xfbff]),
    [1]
  );
  let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);

  // BUG FIX: `logger` was undefined (and groupCollapsed/groupEnd were
  // unbalanced); plain console logging is used instead.
  console.log("[TOKENIZATION] Processing prompt...");
  const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);
  const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
  const token = await tokenizer(prompt, {
    return_tensors: "pt",
    add_generation_prompt: false,
    tokenize: true,
  }).input_ids;

  const seq_length = token.dims[1];
  let ids_len = new Tensor("int64", new BigInt64Array([BigInt(seq_length)]), [
    1,
  ]);

  // Fixed-size id buffer; the prompt occupies the first seq_length slots.
  let input_ids = new ort.Tensor(
    "int32",
    new Int32Array(MAX_SEQ_LENGTH).fill(0),
    [MAX_SEQ_LENGTH]
  );
  input_ids.data.set(Array.from(token.data.slice(0, seq_length), Number));
  const dummy = new ort.Tensor("int32", new Int32Array([0]), []);

  // B: embed the prompt tokens; C: produce the initial position ids.
  let { hidden_states } = await ortSessionB.run({
    input_ids: input_ids,
    ids_len: ids_len,
  });
  ({ position_ids } = await ortSessionC.run({
    dummy: dummy,
  }));

  // Process image
  if (vision) {
    // Preprocess to a normalized float32 NCHW tensor.
    let image = await RawImage.fromURL(imagePath);
    image = await image.resize(INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1]);
    image = image.rgb();
    image = image.toTensor("CHW");
    image = image.to("float32");
    image = image.div_(255.0);
    const pixel_values = image.unsqueeze(0);

    // A: vision encoder -> image embedding grid.
    const { image_embed } = await ortSessionA.run({
      pixel_values: pixel_values,
    });

    ids_len = ids_len.add(BigInt(IMAGE_EMBED_SIZE));
    const split_factor = new Tensor(
      "int32",
      new Int32Array([
        MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE,
      ]),
      [1]
    );
    const ids_len_minus = new Tensor(
      "int32",
      new Int32Array([Number(ids_len.item()) - Number(prompt_head_len.item())]),
      [1]
    );

    // Release the vision encoder before loading D to limit GPU memory.
    // NOTE(review): this makes the module-level session A unusable for a
    // second vision call — confirm single-shot usage is intended.
    await ortSessionA.release();
    ortSessionA = null;

    console.log("session d create");
    ortSessionD = await ort.InferenceSession.create(
      await getModelFile(ONNX_MODEL, `onnx/QwenVL_D_${QUANT}.onnx`),
      {
        executionProviders: ["webgpu"],
      }
    );

    // D: splice the image embedding into the prompt hidden states.
    ({ hidden_states, position_ids } = await ortSessionD.run({
      "hidden_states.1": hidden_states,
      image_embed,
      ids_len,
      ids_len_minus,
      split_factor,
    }));
    await ortSessionD.release();
    ortSessionD = null;
  }

  // ---- Autoregressive decode loop ----------------------------------------
  let output = '';
  while (
    num_decode < MAX_SINGLE_CHAT_LENGTH &&
    Number(history_len.data[0]) < MAX_SEQ_LENGTH
  ) {
    let token_id;

    // E (decoder with KV cache) runs on wasm; created lazily on first step.
    if (!ortSessionE) {
      console.log("Create ortSessionE");
      ortSessionE = await ort.InferenceSession.create(
        await getModelFile(ONNX_MODEL, `onnx/QwenVL_E_${QUANT}.onnx`),
        {
          executionProviders: ["wasm"],
        },
      );
    }

    ({
      max_logit_ids: token_id,
      past_key_states: past_key_states,
      past_value_states: past_value_states,
    } = await ortSessionE.run({
      hidden_states,
      attention_mask,
      "past_key_states.1": past_key_states,
      "past_value_states.1": past_value_states,
      history_len,
      ids_len,
      position_ids,
      pos_factor,
    }));

    // BUG FIX: token_id is a tensor, so comparing the tensor object against
    // a number never matched and EOS never stopped decoding. Compare the
    // contained value: 151643 = <|endoftext|>, 151645 = <|im_end|>.
    const next_id = Number(token_id.data[0]);
    if (next_id === 151643 || next_id === 151645) {
      break;
    }

    num_decode++;
    if (num_decode < 2) {
      // First generated token: switch from prefill to single-token decode.
      history_len = history_len.add(BigInt(ids_len.data[0]));
      ids_len = new ort.Tensor("int64", new BigInt64Array([1n]), [1]);
      attention_mask = new ort.Tensor("float16", new Uint16Array([0]), [1]);
      // int64ToFloat16 / float16ToInt64 are assumed to be helpers imported
      // elsewhere in this module — TODO confirm they are in scope.
      if (vision) {
        pos_factor = new Tensor(
          "float16",
          new Uint16Array([int64ToFloat16(pos_factor_v + ids_len.data[0])]),
          [1]
        );
      } else {
        pos_factor = new Tensor(
          "float16",
          new Uint16Array([int64ToFloat16(history_len.data[0] + BigInt(1))]),
          [1]
        );
      }
    } else {
      // Subsequent tokens: advance history and position by exactly one.
      history_len = history_len.add(BigInt(1));
      pos_factor = pos_factor.map((v) =>
        int64ToFloat16(float16ToInt64(v) + BigInt(1))
      );
    }

    // Feed the new token back through B to get its embedding for next step.
    (input_ids.data)[0] = next_id;
    const result_B = await ortSessionB.run({
      input_ids: input_ids,
      ids_len: ids_len,
    });
    hidden_states = result_B.hidden_states;

    if (
      !Number.isInteger(token_id.data[0]) &&
      !["bigint", "number"].includes(typeof token_id.data[0])
    ) {
      throw new Error(`Token ID is not an integer`);
    } else {
      const decoded = tokenizer.decode([...token_id.data]);
      output += decoded;
    }
  }

  // BUG FIX: the original never returned, so callers always got undefined
  // and parse() displayed "undefined" instead of the generated text.
  return output;
}
// Kick off model loading as soon as the module is evaluated (top-level await);
// any failure rejects the module's evaluation.
await initializeSessions();