File size: 1,892 Bytes
3a01622
41f8b74
3a01622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";

/**
 * Configuration accepted by the "tei" embedding endpoint.
 *
 * Parsed by `embeddingEndpointTei` before any network request is made.
 */
export const embeddingEndpointTeiParametersSchema = z.object({
	// Positive integer, defaults to 1. Not read in this file — presumably used by
	// the caller to balance between several endpoints (confirm against caller).
	weight: z.number().int().positive().default(1),
	// Left unvalidated here; `embeddingEndpointTei` reads `model.chunkCharLength` from it.
	model: z.any(),
	// Discriminator identifying this endpoint kind.
	type: z.literal("tei"),
	// Must be a valid URL; only its origin is used when building request URLs.
	url: z.string().url(),
	// When set, sent verbatim as the `Authorization` request header.
	authorization: z.string().optional(),
});

/**
 * Fetches the server metadata exposed at `<origin>/info` for the given URL.
 *
 * Only the origin of `url` is used; any path, query string or hash is dropped.
 *
 * @param url - Absolute URL pointing at the embedding server.
 * @param authorization - Optional value sent verbatim as the `Authorization` header.
 * @returns The parsed JSON body of the `/info` response.
 * @throws Error when the server answers with a non-2xx status, so that an
 *   error payload is never mistaken for model information.
 */
const getModelInfoByUrl = async (url: string, authorization?: string) => {
	const { origin } = new URL(url);
	const infoUrl = `${origin}/info`;

	const response = await fetch(infoUrl, {
		headers: {
			Accept: "application/json",
			"Content-Type": "application/json",
			...(authorization ? { Authorization: authorization } : {}),
		},
	});

	// `fetch` only rejects on network failures; HTTP errors must be checked explicitly.
	if (!response.ok) {
		throw new Error(
			`Failed to fetch model info from ${infoUrl}: ${response.status} ${response.statusText}`
		);
	}

	return response.json();
};

/**
 * Builds an embedding endpoint that talks to a "tei" server over HTTP.
 *
 * On creation, the server's `/info` route is queried once to derive how many
 * inputs may be sent per `/embed` request. The returned function splits its
 * inputs into batches of that size, posts all batches concurrently and returns
 * the embeddings flattened back into a single array in input order.
 *
 * @param input - Raw endpoint configuration; validated against
 *   `embeddingEndpointTeiParametersSchema`.
 * @returns A function mapping `{ inputs }` to one embedding per input.
 * @throws ZodError when `input` does not match the schema.
 * @throws Error when `/info` does not yield a usable batch size, or when an
 *   `/embed` request answers with a non-2xx status.
 */
export async function embeddingEndpointTei(
	input: z.input<typeof embeddingEndpointTeiParametersSchema>
): Promise<EmbeddingEndpoint> {
	const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);

	// Only the origin is used, so `url` may point at any route of the server.
	const { origin } = new URL(url);
	const embedUrl = `${origin}/embed`;

	// Forward the credentials: a protected server would reject an anonymous `/info` call.
	const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(
		url,
		authorization
	);

	// Stay below both the per-request input limit and the token budget;
	// `model.chunkCharLength` is used as the per-input size estimate.
	const computedBatchSize = Math.min(
		max_client_batch_size,
		Math.floor(max_batch_tokens / model.chunkCharLength)
	);
	if (!Number.isFinite(computedBatchSize)) {
		throw new Error(
			`Unable to derive a batch size from ${origin}/info ` +
				`(max_client_batch_size=${max_client_batch_size}, max_batch_tokens=${max_batch_tokens})`
		);
	}
	// A request needs at least one input, even if a single chunk exceeds the token budget
	// (the server is asked to truncate in that case).
	const maxBatchSize = Math.max(1, computedBatchSize);

	return async ({ inputs }) => {
		const batchesInputs = chunk(inputs, maxBatchSize);

		const batchesResults = await Promise.all(
			batchesInputs.map(async (batchInputs) => {
				const response = await fetch(embedUrl, {
					method: "POST",
					headers: {
						Accept: "application/json",
						"Content-Type": "application/json",
						...(authorization ? { Authorization: authorization } : {}),
					},
					body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
				});

				// `fetch` resolves on HTTP errors too; never return an error body as embeddings.
				if (!response.ok) {
					throw new Error(
						`Failed to get embeddings from ${embedUrl}: ${response.status} ${response.statusText}`
					);
				}

				const embeddings: Embedding[] = await response.json();
				return embeddings;
			})
		);

		// `Promise.all` preserves batch order, so flattening restores input order.
		return batchesResults.flat();
	};
}