|
import type { Tensor } from "@xenova/transformers"; |
|
import { pipeline, dot } from "@xenova/transformers"; |
|
|
|
|
|
/**
 * Converts the dot product of two tensors' raw data into a distance-style
 * score: `1 - dot(a, b)`.
 *
 * A larger dot product yields a smaller result, so sorting ascending by this
 * value puts the best-matching tensors first. In this file the inputs come
 * from an extractor call made with `normalize: true`; assuming that yields
 * unit-length vectors, the result is the cosine distance (TODO confirm).
 *
 * @param tensor1 - First tensor; only its `data` buffer is read.
 * @param tensor2 - Second tensor; only its `data` buffer is read.
 * @returns `1.0` minus the dot product of the two data buffers.
 */
function innerProduct(tensor1: Tensor, tensor2: Tensor) {
  const similarity = dot(tensor1.data, tensor2.data);
  return 1.0 - similarity;
}
|
|
|
/** Identifier of the embedding model handed to `pipeline`. */
const modelId = "Xenova/gte-small";

/**
 * Shared feature-extraction pipeline for `modelId`.
 *
 * Created with a top-level `await`, so evaluation of this module does not
 * finish until the pipeline promise resolves.
 */
const extractor = await pipeline("feature-extraction", modelId);

/**
 * Exported numeric limit of 512.
 *
 * NOTE(review): presumably the maximum input sequence length for the model;
 * nothing in the visible code references or enforces it — confirm with callers.
 * A `const` initialised with a numeric literal already has the literal type
 * `512`, so no `as const` assertion is needed.
 */
export const MAX_SEQ_LEN = 512;
|
|
|
/**
 * Ranks `sentences` by embedding distance to `query` and returns the indices
 * of the closest ones, nearest first.
 *
 * The query and every sentence are embedded together in a single `extractor`
 * call (mean pooling, `normalize: true`). Each sentence embedding is then
 * scored against the query embedding with `innerProduct`, where a smaller
 * score means a closer match.
 *
 * @param query - Text the sentences are compared against.
 * @param sentences - Candidate sentences. An empty array yields `[]` without
 *   invoking the model.
 * @param options - Optional settings. `topK` is the maximum number of indices
 *   to return (default 5); negative values are treated as 0.
 * @returns Up to `topK` indices into `sentences`, ordered by ascending
 *   distance from the query.
 */
export async function findSimilarSentences(
  query: string,
  sentences: string[],
  { topK = 5 }: { topK?: number } = {}
): Promise<number[]> {
  // Nothing to rank: skip the model call entirely.
  if (sentences.length === 0) {
    return [];
  }

  const input = [query, ...sentences];
  const output: Tensor = await extractor(input, { pooling: "mean", normalize: true });

  // Row 0 of the output corresponds to the query; row i + 1 to sentences[i].
  const queryTensor: Tensor = output[0];

  // Index rows directly instead of slicing a range: a range end bound of
  // `input.length - 1` is exclusive and would drop the final sentence.
  const distancesFromQuery: { distance: number; index: number }[] = sentences.map(
    (_sentence, index) => {
      const sentenceTensor: Tensor = output[index + 1];
      return {
        distance: innerProduct(queryTensor, sentenceTensor),
        index,
      };
    }
  );

  // Ascending distance: best matches first.
  distancesFromQuery.sort((a, b) => a.distance - b.distance);

  // Clamp so a negative topK cannot turn into "all but the last N".
  const limit = Math.max(0, topK);
  return distancesFromQuery.slice(0, limit).map((item) => item.index);
}
|
|