import { env, AutoTokenizer, RawImage, Tensor } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2";
import { getModelJSON } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";
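// transformers.js supplies the tokenizer and image utilities; onnxruntime-web
// (WebGPU build) runs the exported ONNX graph parts. Both transformers imports
// are pinned to the same version so they resolve to one module instance.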
// Since we will download the model from the Hugging Face Hub, we can skip the local model check
env.allowLocalModels = false;
// Reference the elements that we will need
const status = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');
const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const INPUT_IMAGE_SIZE = [960, 960];
// The vision encoder pools the image into a HEIGHT_FACTOR x WIDTH_FACTOR grid,
// i.e. IMAGE_EMBED_SIZE = 100 image-embedding tokens.
const HEIGHT_FACTOR = 10;
const WIDTH_FACTOR = 10;
const IMAGE_EMBED_SIZE = WIDTH_FACTOR * HEIGHT_FACTOR;
const MAX_SEQ_LENGTH = 1024;
// Base URL of the server hosting the exported ONNX graph parts.
const BASE_URL = "http://localhost:3004/onnx";
const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
const QUANTIZATION = "q4f16";
const MAX_SINGLE_CHAT_LENGTH = 10;
// The ONNX sessions are created lazily inside imageTextToText(), so the demo is ready immediately.
status.textContent = 'Ready';
example.addEventListener('click', (e) => {
  e.preventDefault();
  detect(EXAMPLE_URL);
});
fileUpload.addEventListener('change', function (e) {
  const file = e.target.files[0];
  if (!file) {
    return;
  }
  const reader = new FileReader();
  // Set up a callback when the file is loaded
  reader.onload = e2 => detect(e2.target.result);
  reader.readAsDataURL(file);
});
// Run the model on the given image and display the generated answer.
// The query string below is just an illustrative default.
async function detect(img) {
  imageContainer.innerHTML = '';
  imageContainer.style.backgroundImage = `url(${img})`;
  status.textContent = 'Analysing...';
  const output = await imageTextToText(img, 'Describe this image.');
  status.textContent = output ?? '';
}
export async function imageTextToText(imagePath, query, vision = true) {
  const suffix = QUANTIZATION ? `_${QUANTIZATION}` : "";
  const config = await getModelJSON(BASE_MODEL, "config.json");
  // Number of tokens in the chat template before the vision placeholder.
  const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
  let position_ids; // produced by part D in the vision path
  let num_decode = 0;
  let history_len = new Tensor("int64", new BigInt64Array([0n]), [1]);
  // Accumulates the decoded answer returned to the caller.
  let output = "";
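  // Pre-allocate the fp16 key/value cache with shape
  // [num_layers, num_kv_heads, MAX_SEQ_LENGTH, head_dim]; both caches start as
  // the same zero-filled tensor and are replaced by session outputs during decoding.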
  let past_key_states = new ort.Tensor(
    "float16",
    new Uint16Array(
      config.num_hidden_layers *
        config.num_key_value_heads *
        MAX_SEQ_LENGTH *
        (config.hidden_size / config.num_attention_heads)
    ).fill(0),
    [
      config.num_hidden_layers,
      config.num_key_value_heads,
      MAX_SEQ_LENGTH,
      config.hidden_size / config.num_attention_heads,
    ]
  );
  let past_value_states = past_key_states;
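  // fp16 tensors store raw IEEE-754 half-precision bit patterns in Uint16Arrays,
  // since JavaScript has no native float16 array type. The additive attention
  // mask below starts at -65504 (the most negative fp16 value), presumably to
  // mask not-yet-filled cache positions during prefill; pos_factor is the
  // rotary-position offset advanced by the decode loop.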
  let attention_mask = new ort.Tensor(
    "float16",
    new Uint16Array([0xfbff]), // -65504.0 in float16
    [1]
  );
  let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);
  const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);
  // Qwen2-VL chat template; part D splices the image embeddings in between
  // <|vision_start|> and <|vision_end|>.
  const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
  const token = await tokenizer(prompt, {
    return_tensors: "pt",
    add_generation_prompt: false,
    tokenize: true,
  }).input_ids;
  const seq_length = token.dims[1];
  let ids_len = new Tensor("int64", new BigInt64Array([BigInt(seq_length)]), [1]);
  // Fixed-length int32 token buffer; only the first seq_length entries are set.
  let input_ids = new ort.Tensor(
    "int32",
    new Int32Array(MAX_SEQ_LENGTH).fill(0),
    [MAX_SEQ_LENGTH]
  );
  input_ids.data.set(Array.from(token.data.slice(0, seq_length), Number));
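  // Vision path: encode the image with part A, then let part D splice the
  // resulting IMAGE_EMBED_SIZE embeddings into the sequence and compute the
  // multimodal position ids.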
  if (vision) {
    let image = await RawImage.fromURL(imagePath);
    // RGB, channels-first, float32 in [0, 1], plus a batch dimension.
    image = image.rgb().toTensor("CHW").to("float32").div_(255.0);
    const pixel_values = image.unsqueeze(0);
    const ortSessionA = await ort.InferenceSession.create(
      `${BASE_URL}/QwenVL_A${suffix}.onnx`,
      { executionProviders: ["webgpu"] }
    );
    const { image_embed } = await ortSessionA.run({ pixel_values });
    // Account for the image tokens inserted into the sequence.
    ids_len = ids_len.add(BigInt(IMAGE_EMBED_SIZE));
    const ortSessionD = await ort.InferenceSession.create(
      `${BASE_URL}/QwenVL_D${suffix}.onnx`,
      { executionProviders: ["webgpu"] }
    );
    // The merged hidden states are written into `past_key_states`, reusing it
    // as scratch space until the first decode step overwrites it.
    ({ hidden_states: past_key_states, position_ids } = await ortSessionD.run({
      "hidden_states.1": past_key_states,
      image_embed,
      ids_len,
      "ids_len_minus": new Tensor(
        "int32",
        new Int32Array([
          Number(ids_len.item()) - Number(prompt_head_len.item()),
        ]),
        [1]
      ),
      "split_factor": new Tensor(
        "int32",
        new Int32Array([
          MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE,
        ]),
        [1]
      ),
    }));
  }
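  // Part B embeds token ids into hidden states; it is reused inside the decode
  // loop for each newly generated token.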
  const ortSessionB = await ort.InferenceSession.create(
    `${BASE_URL}/QwenVL_B${suffix}.onnx`,
    { executionProviders: ["webgpu"] }
  );
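  // Decode phase. Part E runs one decoder step over the KV cache and returns
  // the argmax token id; part B then embeds that token for the next step. The
  // part-E session is created once, before the loop. Generation stops at a
  // Qwen2 stop token or after MAX_SINGLE_CHAT_LENGTH steps.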
  const ortSessionE = await ort.InferenceSession.create(
    `${BASE_URL}/QwenVL_E${suffix}.onnx`,
    { executionProviders: ["wasm"] }
  );
  while (
    num_decode < MAX_SINGLE_CHAT_LENGTH &&
    Number(history_len.data[0]) < MAX_SEQ_LENGTH
  ) {
    const result = await ortSessionE.run({
      hidden_states: past_key_states,
      attention_mask,
      "past_key_states.1": past_key_states,
      "past_value_states.1": past_value_states,
      history_len,
      ids_len,
      position_ids,
      pos_factor,
    });
    const token_id = Number(result.max_logit_ids.data[0]);
    // 151643 = <|endoftext|>, 151645 = <|im_end|>: Qwen2 stop tokens.
    if (token_id === 151643 || token_id === 151645) break;
    num_decode++;
    if (num_decode === 1) {
      // Assumption: once prefill is done, the additive mask is cleared to 0.0
      // so the decoder can attend to all cached positions.
      attention_mask = new ort.Tensor("float16", new Uint16Array([0]), [1]);
    }
    history_len = history_len.add(BigInt(1));
    // pos_factor tracks the decode step as an fp16 value (raw-bit storage).
    pos_factor = new Tensor(
      "float16",
      new Uint16Array([toFloat16Bits(num_decode)]),
      [1]
    );
    past_key_states = result.past_key_states;
    past_value_states = result.past_value_states;
    output += tokenizer.decode([token_id], { skip_special_tokens: true });
    // Embed the new token with part B; its hidden state drives the next step.
    input_ids.data[0] = token_id;
    const { hidden_states } = await ortSessionB.run({ input_ids, ids_len });
    past_key_states = hidden_states;
  }
  return output;
}
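// Minimal fp16 helper used by the decode loop above: converts a JS number to
// its IEEE-754 half-precision bit pattern (exact for the small step counts
// used here). A hand-rolled sketch; replace it if your runtime exposes a
// native float16 conversion.
function toFloat16Bits(value) {
  const f32 = new Float32Array(1);
  const u32 = new Uint32Array(f32.buffer);
  f32[0] = value;
  const bits = u32[0];
  const sign = (bits >>> 16) & 0x8000;
  const exp = ((bits >>> 23) & 0xff) - 127 + 15;
  const mantissa = (bits >>> 13) & 0x3ff;
  if (value === 0 || exp <= 0) return sign; // zero / underflow
  if (exp >= 31) return sign | 0x7c00; // overflow to infinity
  return sign | (exp << 10) | mantissa;
}

// Example usage (assumes the ONNX parts are served under BASE_URL):
// const answer = await imageTextToText(EXAMPLE_URL, "Describe this image.");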