|
import { pipeline, env } from "@huggingface/transformers";
|
|
|
|
|
|
// Turn off transformers.js local-model lookup via `env.allowLocalModels`.
// NOTE(review): presumably this makes every model resolve from the remote
// Hugging Face Hub instead of a local model path — confirm against the
// transformers.js `env` documentation.
Object.assign(env, { allowLocalModels: false });
|
|
|
|
/**
 * Detect whether this runtime can run models on the WebGPU backend.
 *
 * `gpu.requestAdapter()` can resolve to `null` instead of rejecting when
 * WebGPU is exposed but no adapter is available. The adapter itself must
 * therefore be checked; reaching the call without an exception is not enough.
 *
 * @param {object|null|undefined} [nav=globalThis.navigator] - Navigator-like
 *   object to probe. Defaults to the global navigator; overridable for tests.
 * @returns {Promise<boolean>} true only if a WebGPU adapter was obtained.
 */
async function supportsWebGPU(nav = globalThis.navigator) {
  try {
    if (!nav?.gpu) return false;
    const adapter = await nav.gpu.requestAdapter();
    return adapter != null;
  } catch {
    // Any failure while probing counts as "unsupported" so the caller can
    // fall back to another backend.
    return false;
  }
}
|
|
|
|
// Backend passed to every pipeline(): the GPU path when WebGPU is usable,
// the WebAssembly path otherwise. Resolved once at module start-up.
const device = await supportsWebGPU().then((hasGpu) => (hasGpu ? "webgpu" : "wasm"));
|
|
|
|
/**
 * Worker-side manager for transformers.js pipelines.
 *
 * Caches one pipeline per `task:model` pair, tracks the task/model that
 * requests fall back to, and runs inference requests one at a time.
 */
class PipelineManager {
  /** Default model per task, used when a request names no model. */
  static defaultConfigs = {
    "text-classification": {
      model: "onnx-community/rubert-tiny-sentiment-balanced-ONNX",
    },
    "image-classification": {
      model: "onnx-community/mobilenet_v2_1.0_224",
    },
  };

  /**
   * Pipeline cache keyed by `${task}:${modelName}`. Each value is the
   * *promise* of the pipeline, stored as soon as loading starts. Concurrent
   * callers therefore share one load instead of each starting their own.
   */
  static instances = {};

  /** Task used when an incoming request does not specify one. */
  static currentTask = "text-classification";

  /** Model used when an incoming request does not specify one. */
  static currentModel = PipelineManager.defaultConfigs["text-classification"].model;

  /** Pending inference requests, each `{ input, task, modelName }`. */
  static queue = [];

  /** True while processQueue() is draining the queue. */
  static isProcessing = false;

  /**
   * Return the pipeline for `task`/`modelName`, loading it on first use.
   *
   * Posts an `initiate` message before a load and a `ready` message after it.
   * A failed load is evicted from the cache so a later call can retry it.
   *
   * @param {string} task - transformers.js pipeline task name.
   * @param {string} modelName - Model id to load.
   * @param {Function|null} [progress_callback=null] - Forwarded to
   *   `pipeline()`. Only the call that starts a load has its callback used.
   * @returns {Promise<Function>} The loaded pipeline.
   */
  static async getInstance(task, modelName, progress_callback = null) {
    const key = `${task}:${modelName}`;
    if (!this.instances[key]) {
      const loading = (async () => {
        self.postMessage({ status: "initiate", file: modelName, task });
        const instance = await pipeline(task, modelName, { progress_callback, device });
        self.postMessage({ status: "ready", file: modelName, task });
        return instance;
      })();
      // Cache the promise immediately, before it settles, so a second caller
      // arriving mid-load awaits this load instead of starting another one.
      this.instances[key] = loading;
      // Evict failed loads; otherwise every later request would receive the
      // same rejected promise. The caller still sees the rejection.
      loading.catch(() => {
        if (this.instances[key] === loading) delete this.instances[key];
      });
    }
    return this.instances[key];
  }

  /**
   * Drain the request queue, one inference at a time.
   *
   * Each pass runs only the newest pending request and discards older ones as
   * stale. Requests that arrive during a run are picked up by the next pass.
   * Results are posted as `complete` messages and failures as `error`
   * messages rather than thrown.
   */
  static async processQueue() {
    if (this.isProcessing) return;
    this.isProcessing = true;
    try {
      while (this.queue.length > 0) {
        const request = this.queue.at(-1);
        this.queue = [];
        await this.#runRequest(request);
      }
    } finally {
      // Always release the flag; a stuck flag would silence the worker.
      this.isProcessing = false;
    }
  }

  /** Run a single inference request and post its result or error. */
  static async #runRequest({ input, task, modelName }) {
    try {
      const classifier = await this.getInstance(task, modelName, (x) => {
        self.postMessage({
          ...x,
          status: x.status || "progress",
          file: x.file || modelName,
          name: modelName,
          task,
          loaded: x.loaded,
          total: x.total,
          progress: x.loaded && x.total ? (x.loaded / x.total) * 100 : 0,
        });
      });

      // Speech-recognition requests are called without options; every other
      // task asks for the top 5 results.
      const output =
        task === "automatic-speech-recognition"
          ? await classifier(input)
          : await classifier(input, { top_k: 5 });

      self.postMessage({ status: "complete", output, file: modelName, task });
    } catch (error) {
      self.postMessage({
        status: "error",
        // Non-Error throwables have no .message; stringify them instead.
        error: error?.message ?? String(error),
        file: modelName,
        task,
      });
    }
  }
}
|
|
|
|
|
|
/**
 * Worker message entry point.
 *
 * A message with `action: "load-model"` switches the default task/model and
 * preloads that pipeline, forwarding loader progress to the main thread. The
 * task defaults to "text-classification" and the model to the task's
 * configured default. Any other message is an inference request
 * `{ input, task?, modelName? }`. Missing fields fall back to the current
 * defaults, and the request is queued for PipelineManager.processQueue().
 */
self.addEventListener("message", async (event) => {
  const { input, modelName, task, action } = event.data;

  if (action === "load-model") {
    const nextTask = task || "text-classification";
    // `?.` so an unknown task is reported below instead of throwing a TypeError.
    const nextModel = modelName || PipelineManager.defaultConfigs[nextTask]?.model;

    if (!nextModel) {
      // Report before touching the current defaults, so a bad request leaves
      // the worker's state unchanged.
      self.postMessage({
        status: "error",
        error: `No model given and no default model configured for task "${nextTask}"`,
        task: nextTask,
      });
      return;
    }

    PipelineManager.currentTask = nextTask;
    PipelineManager.currentModel = nextModel;

    try {
      await PipelineManager.getInstance(nextTask, nextModel, (x) => {
        // Label progress with this load's task/model (captured above) rather
        // than re-reading the shared defaults, which a later load-model
        // message may already have changed.
        self.postMessage({
          ...x,
          file: nextModel,
          status: x.status || "progress",
          loaded: x.loaded,
          total: x.total,
          task: nextTask,
        });
      });
    } catch (error) {
      // Without this, a failed load escapes the async listener and the main
      // thread is never told about it.
      self.postMessage({
        status: "error",
        error: error?.message ?? String(error),
        file: nextModel,
        task: nextTask,
      });
    }
    return;
  }

  PipelineManager.queue.push({
    input,
    task: task || PipelineManager.currentTask,
    modelName: modelName || PipelineManager.currentModel,
  });
  // Deliberately not awaited: the queue drains in the background while this
  // handler returns.
  void PipelineManager.processQueue();
});