File size: 1,709 Bytes
683ef2f
ebac87f
 
 
 
 
 
 
683ef2f
 
 
 
 
 
 
 
 
 
 
 
3acc11d
 
ebac87f
 
 
 
 
 
aa07e29
683ef2f
 
ebac87f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import type { Tensor, Pipeline } from "@xenova/transformers";
import { pipeline, dot } from "@xenova/transformers";

/**
 * Inner-product *distance* between two tensors: `1 - dot(a, b)`.
 *
 * Despite the name, the value returned is a distance, not a similarity: a
 * larger dot product gives a smaller result. The formula follows the
 * "Inner product" distance described in the hnswlib README:
 * https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
 *
 * @param tensor1 First operand; only its flat `data` buffer is read.
 * @param tensor2 Second operand; only its flat `data` buffer is read.
 * @returns `1.0` minus the dot product of the two data buffers.
 */
function innerProduct(tensor1: Tensor, tensor2: Tensor) {
	const similarity = dot(tensor1.data, tensor2.data);
	return 1.0 - similarity;
}

/**
 * Lazily constructed, shared feature-extraction pipeline.
 *
 * The pipeline is created on the first call to {@link PipelineSingleton.getInstance}
 * and the resulting promise is cached in `instance`, so concurrent and later
 * callers all await the same load instead of starting a new one.
 */
class PipelineSingleton {
	/** Model identifier passed to `pipeline()`. */
	static modelId = "Xenova/gte-small";
	/** Cached pipeline promise; `null` until first requested (or after a failed load). */
	static instance: Promise<Pipeline> | null = null;

	/**
	 * Returns the shared pipeline, starting the load if none is cached.
	 *
	 * If the load rejects, the cached promise is discarded so that the next
	 * call retries rather than returning the same rejection forever. The
	 * caller that triggered the failed load still observes the rejection.
	 *
	 * @returns A promise for the feature-extraction pipeline.
	 */
	static async getInstance() {
		if (this.instance === null) {
			const loading: Promise<Pipeline> = pipeline("feature-extraction", this.modelId);
			// Side handler only: it clears the cache on failure without swallowing
			// the rejection seen by callers awaiting `loading` itself. The identity
			// check avoids clobbering a newer attempt that replaced this one.
			loading.catch(() => {
				if (this.instance === loading) {
					this.instance = null;
				}
			});
			this.instance = loading;
		}
		return this.instance;
	}
}

// see https://huggingface.co/thenlper/gte-small/blob/d8e2604cadbeeda029847d19759d219e0ce2e6d8/README.md?code=true#L2625
/**
 * Sequence-length limit, fixed at the literal value `512`.
 *
 * NOTE(review): presumably the maximum input length (in tokens) of the
 * gte-small model described in the README linked above — confirm there.
 * It is exported for callers and is not referenced by the other
 * definitions visible in this file, so nothing here enforces it.
 */
export const MAX_SEQ_LEN = 512 as const;

/**
 * Finds the sentences most similar to `query` and returns their indexes.
 *
 * The query and all sentences are embedded in a single batch with the shared
 * feature-extraction pipeline (mean pooling, normalized), each sentence is
 * scored with `innerProduct` against the query embedding, and the sentences
 * are ranked by ascending distance.
 *
 * @param query Text to compare every sentence against.
 * @param sentences Candidate sentences; the returned numbers index into this array.
 * @param options Optional settings. `topK` (default 5) is the maximum number
 *   of indexes to return.
 * @returns Indexes into `sentences` of the closest matches, closest first.
 *   Contains at most `topK` entries, and is empty when `sentences` is empty.
 */
export async function findSimilarSentences(
	query: string,
	sentences: string[],
	{ topK = 5 }: { topK?: number } = {}
): Promise<number[]> {
	// Nothing to rank: avoid loading the model and running inference at all.
	if (sentences.length === 0) {
		return [];
	}

	const input = [query, ...sentences];

	const extractor = await PipelineSingleton.getInstance();
	const output: Tensor = await extractor(input, { pooling: "mean", normalize: true });

	// Row 0 is the query embedding; each following row is one sentence, in
	// input order. The rows are split by iterating the tensor rather than by a
	// range slice, so the last sentence is never lost to an off-by-one in the
	// range's end bound.
	const queryTensor: Tensor = output[0];
	const sentenceTensors: Tensor[] = [...output].slice(1);

	const distancesFromQuery: { distance: number; index: number }[] = sentenceTensors.map(
		(sentenceTensor: Tensor, index: number) => {
			return {
				distance: innerProduct(queryTensor, sentenceTensor),
				index: index,
			};
		}
	);

	// Smallest distance first, i.e. most similar first.
	distancesFromQuery.sort((a, b) => {
		return a.distance - b.distance;
	});

	// Return the indexes of the closest topK sentences
	return distancesFromQuery.slice(0, topK).map((item) => item.index);
}