| const path = require("path"); |
| const fs = require("fs"); |
|
|
/**
 * Reranks documents against a query with a sequence-classification model
 * loaded through `@xenova/transformers`.
 *
 * The transformers module, the model and the tokenizer are stored in static
 * private fields, so every instance in the process shares a single loaded copy.
 */
class NativeEmbeddingReranker {
  /** Loaded model instance shared by all instances (null until loaded). */
  static #model = null;
  /** Loaded tokenizer instance shared by all instances (null until loaded). */
  static #tokenizer = null;
  /** `{ AutoModelForSequenceClassification, AutoTokenizer, env }` once imported. */
  static #transformers = null;
  /** In-flight initialization promise; non-null only while a load is running. */
  static #initializationPromise = null;

  /**
   * Alternate remote host that is tried once when loading from the currently
   * configured remote host fails.
   */
  #fallbackHost = "https://cdn.anythingllm.com/support/models/";

  /**
   * Resolves the on-disk cache location and makes sure it exists.
   * The cache lives in `$STORAGE_DIR/models` when `STORAGE_DIR` is set,
   * otherwise in `../../../storage/models` relative to this file.
   */
  constructor() {
    this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));

    // `recursive` so a missing parent (e.g. a fresh STORAGE_DIR) does not
    // make construction throw ENOENT.
    if (!fs.existsSync(this.cacheDir)) {
      fs.mkdirSync(this.cacheDir, { recursive: true });
    }

    // Only used to decide whether download progress is worth logging.
    this.modelDownloaded = fs.existsSync(this.modelPath);
    this.log("Initialized");
  }

  /**
   * Writes a message to stdout prefixed with a colored class tag.
   * @param {string} text - Message to print.
   * @param {...any} args - Extra values forwarded to `console.log`.
   */
  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
  }

  /**
   * Label for the remote host in use. Within this class it is only used in
   * log messages.
   *
   * Returns the literal `"https://huggingface.co"` before the transformers
   * module has been imported, the host component of `env.remoteHost` after it
   * has, and the fallback host URL if `env.remoteHost` cannot be parsed as a
   * URL. Note that the three cases do not share one format.
   * @returns {string}
   */
  get host() {
    if (!NativeEmbeddingReranker.#transformers) return "https://huggingface.co";
    try {
      return new URL(NativeEmbeddingReranker.#transformers.env.remoteHost).host;
    } catch (e) {
      return this.#fallbackHost;
    }
  }

  /**
   * Best-effort warm-up: initializes the model and tokenizer ahead of the
   * first `rerank` call. Failures are logged and never rethrown, because
   * `rerank` runs the initialization again itself.
   * @returns {Promise<void>}
   */
  async preload() {
    try {
      this.log(`Preloading reranker suite...`);
      await this.initClient();
      this.log(
        `Preloaded reranker suite. Reranking is available as a service now.`
      );
    } catch (e) {
      console.error(e);
      this.log(
        `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
      );
    }
  }

  /**
   * Imports `@xenova/transformers` and loads the model and tokenizer into the
   * shared static fields. Calls made while a load is already in flight await
   * that same promise instead of starting a second load.
   * @returns {Promise<void>}
   * @throws Whatever the import or the model/tokenizer load rejects with.
   */
  async initClient() {
    if (
      NativeEmbeddingReranker.#transformers &&
      NativeEmbeddingReranker.#model &&
      NativeEmbeddingReranker.#tokenizer
    ) {
      this.log(`Reranker suite already fully initialized - reusing.`);
      return;
    }

    if (NativeEmbeddingReranker.#initializationPromise) {
      this.log(`Waiting for existing initialization to complete...`);
      await NativeEmbeddingReranker.#initializationPromise;
      return;
    }

    NativeEmbeddingReranker.#initializationPromise = (async () => {
      try {
        const { AutoModelForSequenceClassification, AutoTokenizer, env } =
          await import("@xenova/transformers");
        this.log(`Loading reranker suite...`);
        NativeEmbeddingReranker.#transformers = {
          AutoModelForSequenceClassification,
          AutoTokenizer,
          env,
        };

        // Sequential on purpose: both loaders may switch the shared
        // `env.remoteHost` to the fallback host when they fail.
        await this.#getPreTrainedModel();
        await this.#getPreTrainedTokenizer();
      } finally {
        // Cleared on success and failure so a later call can retry.
        NativeEmbeddingReranker.#initializationPromise = null;
      }
    })();

    await NativeEmbeddingReranker.#initializationPromise;
  }

  /**
   * Shared loader for the model and the tokenizer.
   *
   * Calls `loader.from_pretrained` for `this.model`. On failure it points
   * `env.remoteHost` at the fallback host and tries again; if the remote host
   * already is the fallback host, the error is rethrown.
   * @param {string} label - "model" or "tokenizer"; only used in log messages.
   * @param {{from_pretrained: Function}} loader - Class exposing `from_pretrained`.
   * @returns {Promise<any>} The loaded artifact.
   * @throws The last load error once the fallback host has failed as well.
   */
  async #loadWithFallback(label, loader) {
    const { env } = NativeEmbeddingReranker.#transformers;

    for (;;) {
      try {
        const loaded = await loader.from_pretrained(this.model, {
          progress_callback: (p) => {
            if (!this.modelDownloaded && p?.status === "progress") {
              this.log(
                `[${this.host}] Loading ${label} ${this.model}... ${p?.progress}%`
              );
            }
          },
          cache_dir: this.cacheDir,
        });
        this.log(`Loaded ${label} ${this.model}`);
        return loaded;
      } catch (e) {
        this.log(
          `Failed to load ${label} ${this.model} from ${this.host}.`,
          e.message,
          e.stack
        );
        if (env.remoteHost === this.#fallbackHost) {
          this.log(`Failed to load ${label} ${this.model} from fallback host.`);
          throw e;
        }

        this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
        env.remoteHost = this.#fallbackHost;
        // NOTE(review): presumably the fallback host serves files directly
        // under "<model>/" - confirm against the CDN layout.
        env.remotePathTemplate = "{model}/";
      }
    }
  }

  /**
   * Returns the shared model, loading it on first use.
   * @returns {Promise<any>} The loaded model.
   * @throws When loading fails from the fallback host too.
   */
  async #getPreTrainedModel() {
    if (NativeEmbeddingReranker.#model) {
      this.log(`Loading model from singleton...`);
      return NativeEmbeddingReranker.#model;
    }

    const model = await this.#loadWithFallback(
      "model",
      NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification
    );
    NativeEmbeddingReranker.#model = model;
    return model;
  }

  /**
   * Returns the shared tokenizer, loading it on first use.
   * @returns {Promise<any>} The loaded tokenizer.
   * @throws When loading fails from the fallback host too.
   */
  async #getPreTrainedTokenizer() {
    if (NativeEmbeddingReranker.#tokenizer) {
      this.log(`Loading tokenizer from singleton...`);
      return NativeEmbeddingReranker.#tokenizer;
    }

    const tokenizer = await this.#loadWithFallback(
      "tokenizer",
      NativeEmbeddingReranker.#transformers.AutoTokenizer
    );
    NativeEmbeddingReranker.#tokenizer = tokenizer;
    return tokenizer;
  }

  /**
   * Scores every document against the query and returns the best matches.
   *
   * @param {string} query - Text each document is paired with.
   * @param {Array<{text: string}>} documents - Documents to score; `text` is
   *   the field passed to the tokenizer.
   * @param {{topK?: number}} [options] - `topK` caps the number of results
   *   and defaults to 4 when the option (or the whole object) is missing.
   * @returns {Promise<Array<object>>} Copies of the input documents, each with
   *   `rerank_corpus_id` (index in `documents`) and `rerank_score` (sigmoid of
   *   the model logit), sorted by descending score and cut to `topK`. The
   *   document's own properties are spread last, so they take precedence if
   *   they use the same key names. An empty `documents` array yields `[]`
   *   without loading the model.
   */
  async rerank(query, documents, options = { topK: 4 }) {
    // Nothing to score: skip model initialization and the tokenizer call.
    if (documents.length === 0) return [];

    // Resolve topK per-field so `{}` or `null` options still get the default.
    const topK = options?.topK ?? 4;

    await this.initClient();
    const model = NativeEmbeddingReranker.#model;
    const tokenizer = NativeEmbeddingReranker.#tokenizer;

    const start = Date.now();
    this.log(`Reranking ${documents.length} documents...`);
    // One (query, document) pair per document.
    const inputs = tokenizer(new Array(documents.length).fill(query), {
      text_pair: documents.map((doc) => doc.text),
      padding: true,
      truncation: true,
    });
    const { logits } = await model(inputs);
    const reranked = logits
      .sigmoid()
      .tolist()
      .map(([score], i) => ({
        rerank_corpus_id: i,
        rerank_score: score,
        ...documents[i],
      }))
      .sort((a, b) => b.rerank_score - a.rerank_score)
      .slice(0, topK);

    this.log(
      `Reranking ${documents.length} documents to top ${topK} took ${Date.now() - start}ms`
    );
    return reranked;
  }
}
|
|
| module.exports = { |
| NativeEmbeddingReranker, |
| }; |
|
|