import { pipeline, env } from "@huggingface/transformers";
// Skip local model check
env.allowLocalModels = false;
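// Feature-detect WebGPU support in this browser.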
async function supportsWebGPU() {
  try {
    if (!navigator.gpu) return false;
    // requestAdapter() resolves to null when no suitable adapter exists,
    // so check the result instead of only awaiting the call.
    const adapter = await navigator.gpu.requestAdapter();
    return adapter !== null;
  } catch (e) {
    return false;
  }
}
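// Pick the inference backend once at startup: WebGPU when available, otherwise WASM.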
const device = (await supportsWebGPU()) ? "webgpu" : "wasm";
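// Lazily creates one pipeline per task/model pair and funnels all inference
// requests through a single latest-wins queue.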
class PipelineManager {
  static defaultConfigs = {
    "text-classification": {
      model: "onnx-community/rubert-tiny-sentiment-balanced-ONNX",
    },
    "image-classification": {
      model: "onnx-community/mobilenet_v2_1.0_224",
    },
  };

  static instances = {}; // key: `${task}:${modelName}` -> pipeline instance
  static currentTask = "text-classification";
  static currentModel =
    PipelineManager.defaultConfigs["text-classification"].model;
  static queue = [];
  static isProcessing = false;
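  // Return the cached pipeline for this task/model pair, creating it
  // (and downloading the model) on first use.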
  static async getInstance(task, modelName, progress_callback = null) {
    const key = `${task}:${modelName}`;
    if (!this.instances[key]) {
      self.postMessage({ status: "initiate", file: modelName, task });
      this.instances[key] = await pipeline(task, modelName, {
        progress_callback,
        device,
      });
      self.postMessage({ status: "ready", file: modelName, task });
    }
    return this.instances[key];
  }
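  // Drain the queue with latest-wins semantics: only the newest request is
  // run, and any inputs that queued up while inference was busy are dropped.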
  static async processQueue() {
    if (this.isProcessing || this.queue.length === 0) return;
    this.isProcessing = true;

    const { input, task, modelName } = this.queue[this.queue.length - 1];
    this.queue = [];

    try {
      const classifier = await this.getInstance(task, modelName, (x) => {
        self.postMessage({
          ...x,
          status: x.status || "progress",
          file: x.file || modelName,
          name: modelName,
          task,
          loaded: x.loaded,
          total: x.total,
          progress: x.loaded && x.total ? (x.loaded / x.total) * 100 : 0,
        });
      });

      let output;
      if (task === "image-classification") {
        // input is a data URL or Blob
        output = await classifier(input, { top_k: 5 });
      } else if (task === "automatic-speech-recognition") {
        output = await classifier(input);
      } else {
        output = await classifier(input, { top_k: 5 });
      }

      self.postMessage({
        status: "complete",
        output,
        file: modelName,
        task,
      });
    } catch (error) {
      self.postMessage({
        status: "error",
        error: error.message,
        file: modelName,
        task,
      });
    }

    this.isProcessing = false;
    if (this.queue.length > 0) {
      this.processQueue();
    }
  }
}
// Listen for messages from the main thread
self.addEventListener("message", async (event) => {
  const { input, modelName, task, action } = event.data;

  // Explicit preload request: instantiate the pipeline without running inference.
  if (action === "load-model") {
    PipelineManager.currentTask = task || "text-classification";
    PipelineManager.currentModel =
      modelName ||
      PipelineManager.defaultConfigs[PipelineManager.currentTask].model;
    await PipelineManager.getInstance(
      PipelineManager.currentTask,
      PipelineManager.currentModel,
      (x) => {
        self.postMessage({
          ...x,
          file: PipelineManager.currentModel,
          status: x.status || "progress",
          loaded: x.loaded,
          total: x.total,
          task: PipelineManager.currentTask,
        });
      }
    );
    return;
  }

  // Any other message is an inference request.
  PipelineManager.queue.push({
    input,
    task: task || PipelineManager.currentTask,
    modelName: modelName || PipelineManager.currentModel,
  });
  PipelineManager.processQueue();
});