import React, { createContext, useState, useRef, useCallback } from "react";
import { AutoProcessor, AutoModelForImageTextToText, RawImage, TextStreamer } from "@huggingface/transformers";
import type { LlavaProcessor, PreTrainedModel, Tensor } from "@huggingface/transformers";
import type { VLMContextValue } from "../types/vlm";

const VLMContext = createContext<VLMContextValue | null>(null);

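// FastVLM-0.5B ONNX export; MAX_NEW_TOKENS bounds the length of each generated response.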
const MODEL_ID = "onnx-community/FastVLM-0.5B-ONNX";
const MAX_NEW_TOKENS = 512;

export { VLMContext };

export const VLMProvider: React.FC<React.PropsWithChildren> = ({ children }) => {
  const [isLoaded, setIsLoaded] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

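  // Mutable refs that should not trigger re-renders: the loaded processor and model,
  // the in-flight load promise, a re-entrancy lock for inference, and a reusable canvas.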
  const processorRef = useRef<LlavaProcessor | null>(null);
  const modelRef = useRef<PreTrainedModel | null>(null);
  const loadPromiseRef = useRef<Promise<void> | null>(null);
  const inferenceLock = useRef(false);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);

  const loadModel = useCallback(
    async (onProgress?: (msg: string) => void) => {
      if (isLoaded) {
        onProgress?.("Model already loaded!");
        return;
      }

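      // If a load is already in flight, return that promise instead of starting a second download.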
      if (loadPromiseRef.current) {
        return loadPromiseRef.current;
      }

      setIsLoading(true);
      setError(null);

      loadPromiseRef.current = (async () => {
        try {
          onProgress?.("Loading processor...");
          processorRef.current = await AutoProcessor.from_pretrained(MODEL_ID);
          onProgress?.("Processor loaded. Loading model...");
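          // Per-module precision: fp16 token embeddings with 4-bit quantized vision encoder
          // and decoder, executed on WebGPU.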
          modelRef.current = await AutoModelForImageTextToText.from_pretrained(MODEL_ID, {
            dtype: {
              embed_tokens: "fp16",
              vision_encoder: "q4",
              decoder_model_merged: "q4",
            },
            device: "webgpu",
          });
          onProgress?.("Model loaded successfully!");
          setIsLoaded(true);
        } catch (e) {
          const errorMessage = e instanceof Error ? e.message : String(e);
          setError(errorMessage);
          console.error("Error loading model:", e);
          throw e;
        } finally {
          setIsLoading(false);
          loadPromiseRef.current = null;
        }
      })();

      return loadPromiseRef.current;
    },
    [isLoaded],
  );

  const runInference = useCallback(
    async (media: HTMLVideoElement | HTMLImageElement, instruction: string, onTextUpdate?: (text: string) => void): Promise<string> => {
      // Drop this frame if a previous inference is still in progress.
      if (inferenceLock.current) {
        console.log("Inference already running, skipping frame");
        return "";
      }
      if (!processorRef.current || !modelRef.current) {
        throw new Error("Model/processor not loaded");
      }
      inferenceLock.current = true;

      try {
        // Lazily create a single offscreen canvas that is reused for every frame grab.
        if (!canvasRef.current) {
          canvasRef.current = document.createElement("canvas");
        }
        const canvas = canvasRef.current;

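        // Match the canvas to the media's intrinsic resolution before drawing.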
        let width = 0;
        let height = 0;
        if (media instanceof HTMLVideoElement) {
          width = media.videoWidth;
          height = media.videoHeight;
        } else if (media instanceof HTMLImageElement) {
          width = media.naturalWidth;
          height = media.naturalHeight;
        } else {
          throw new Error("Unsupported media type");
        }
        canvas.width = width;
        canvas.height = height;

        const ctx = canvas.getContext("2d", { willReadFrequently: true });
        if (!ctx) throw new Error("Could not get canvas context");

        ctx.drawImage(media, 0, 0, width, height);

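        // Copy the frame's pixels into a RawImage (RGBA, 4 channels) for the processor.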
        const frame = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const rawImg = new RawImage(frame.data, frame.width, frame.height, 4);

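        // Build the chat prompt; the <image> placeholder marks where the frame is inserted.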
        const messages = [
          {
            role: "system",
            content: `You are a helpful visual AI assistant. Respond concisely and accurately to the user's query in one sentence.`,
          },
          { role: "user", content: `<image>${instruction}` },
        ];
        const prompt = processorRef.current.apply_chat_template(messages, {
          add_generation_prompt: true,
        });
        const inputs = await processorRef.current(rawImg, prompt, {
          add_special_tokens: false,
        });

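        // Stream partial text to the caller as tokens are generated.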
        let streamed = "";
        const streamer = new TextStreamer(processorRef.current.tokenizer!, {
          skip_prompt: true,
          skip_special_tokens: true,
          callback_function: (t: string) => {
            streamed += t;
            onTextUpdate?.(streamed.trim());
          },
        });

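        // Greedy decoding with a mild repetition penalty, capped at MAX_NEW_TOKENS.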
        const outputs = (await modelRef.current.generate({
          ...inputs,
          max_new_tokens: MAX_NEW_TOKENS,
          do_sample: false,
          streamer,
          repetition_penalty: 1.2,
        })) as Tensor;

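        // Decode only the newly generated tokens, i.e. everything after the prompt length.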
        const decoded = processorRef.current.batch_decode(outputs.slice(null, [inputs.input_ids.dims.at(-1), null]), {
          skip_special_tokens: true,
        });
        return decoded[0].trim();
      } finally {
        // Always release the lock, even if inference throws.
        inferenceLock.current = false;
      }
    },
    [],
  );

  return (
    <VLMContext.Provider
      value={{
        isLoaded,
        isLoading,
        error,
        loadModel,
        runInference,
      }}
    >
      {children}
    </VLMContext.Provider>
  );
};