File size: 2,852 Bytes
3a01622
 
 
 
 
 
 
 
41f8b74
3a01622
 
6e0b0ea
 
3a01622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e0b0ea
3a01622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import { TEXT_EMBEDDING_MODELS } from "$env/static/private";

import { z } from "zod";
import { sum } from "$lib/utils/sum";
import {
	embeddingEndpoints,
	embeddingEndpointSchema,
	type EmbeddingEndpoint,
} from "$lib/server/embeddingEndpoints/embeddingEndpoints";
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";

import JSON5 from "json5";

/** Shared field schemas reused by `modelConfig` below. */
const nonEmptyText = z.string().min(1);
const urlText = z.string().url();

/**
 * Schema for one embedding-model entry of the `TEXT_EMBEDDING_MODELS` list.
 *
 * `preQuery` and `prePassage` default to the empty string when omitted.
 */
const modelConfig = z.object({
	/** Used as an identifier in DB */
	id: z.string().optional(),
	/** Used to link to the model page, and for inference */
	name: nonEmptyText,
	displayName: nonEmptyText.optional(),
	description: nonEmptyText.optional(),
	websiteUrl: urlText.optional(),
	modelUrl: urlText.optional(),
	/** At least one endpoint must be configured. */
	endpoints: z.array(embeddingEndpointSchema).nonempty(),
	/** Must be strictly positive. */
	chunkCharLength: z.number().positive(),
	preQuery: z.string().default(""),
	prePassage: z.string().default(""),
});

// Default embedding model for backward compatibility. It is used when TEXT_EMBEDDING_MODELS
// is unset or empty; `||` (not `??`) deliberately treats an empty string as unset.
const rawEmbeddingModelJSON =
	TEXT_EMBEDDING_MODELS ||
	`[
	{
	  "name": "Xenova/gte-small",
	  "chunkCharLength": 512,
	  "endpoints": [
		{ "type": "transformersjs" }
	  ]
	}
]`;

// Reject an empty list at startup. `defaultEmbeddingModel` is `embeddingModels[0]`, and the
// model validators build a `z.enum` starting from the first entry, so both need at least one
// model. Without this check an empty list surfaces later as `undefined` or a TypeError.
const embeddingModelsRaw = z
	.array(modelConfig)
	.nonempty({ message: "TEXT_EMBEDDING_MODELS must define at least one embedding model" })
	.parse(JSON5.parse(rawEmbeddingModelJSON));

/**
 * Normalize a parsed model config. When no `id` is given (or it is empty), the model's
 * `name` doubles as its id.
 */
const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => {
	const id = m.id || m.name;
	return { ...m, id };
};

/**
 * Extend a processed embedding model with a `getEndpoint` factory.
 *
 * Each call to `getEndpoint` picks one of `m.endpoints` at random, with probability
 * proportional to the endpoint's `weight`, and builds it through `embeddingEndpoints`.
 *
 * @throws Error (from `getEndpoint`) if the endpoint weights do not sum to a positive number,
 *   or if the selected endpoint's `type` has no builder here.
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
	...m,
	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
		// Defensive fallback: `modelConfig` requires a non-empty `endpoints` array, so parsed
		// configs are not expected to reach this branch.
		if (!m.endpoints) {
			return embeddingEndpointTransformersJS({
				type: "transformersjs",
				weight: 1,
				model: m,
			});
		}

		const totalWeight = sum(m.endpoints.map((e) => e.weight));
		// `!(x > 0)` also rejects NaN. Without a positive total nothing can be selected.
		if (!(totalWeight > 0)) {
			throw new Error(
				`Failed to select embedding endpoint for model "${m.name}": endpoint weights must sum to a positive number`
			);
		}

		// Walk the cumulative weights until the random draw falls inside an endpoint's share.
		let random = Math.random() * totalWeight;
		let selected: (typeof m.endpoints)[number] | undefined;
		for (const endpoint of m.endpoints) {
			if (random < endpoint.weight) {
				selected = endpoint;
				break;
			}
			random -= endpoint.weight;
		}

		// Floating-point rounding in the subtraction above can leave `random` marginally past
		// the last share. Fall back to the last endpoint that can be chosen at all.
		selected ??= [...m.endpoints].reverse().find((e) => e.weight > 0);
		if (!selected) {
			throw new Error(`Failed to select embedding endpoint`);
		}

		// Captured before the switch so the default branch can report it after narrowing.
		const endpointType: string = selected.type;
		const args = { ...selected, model: m };

		switch (args.type) {
			case "tei":
				return embeddingEndpoints.tei(args);
			case "transformersjs":
				return embeddingEndpoints.transformersjs(args);
			default:
				// Fail loudly instead of silently moving on to a different endpoint.
				throw new Error(`Unsupported embedding endpoint type: ${endpointType}`);
		}
	},
});

/**
 * Every configured embedding model, normalized and given a `getEndpoint` factory, in the
 * same order as the configuration list.
 */
export const embeddingModels = await Promise.all(
	embeddingModelsRaw.map(async (rawModel) => addEndpoint(await processEmbeddingModel(rawModel)))
);

/** The first configured embedding model. */
export const defaultEmbeddingModel = embeddingModels[0];

/**
 * Build a zod enum that accepts exactly the `key` values ("id" or "name") of `_models`.
 *
 * @throws Error if `_models` is empty, because `z.enum` needs at least one value.
 */
const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
	// Without this guard an empty list fails with an opaque TypeError on `_models[0][key]`.
	if (_models.length === 0) {
		throw new Error(`Cannot build embedding model validator: no embedding models provided`);
	}
	return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
};

/** Zod enum schema accepting exactly the `id`s of `_models` (which must be non-empty). */
export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "id");

/** Zod enum schema accepting exactly the `name`s of `_models` (which must be non-empty). */
export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "name");

/** Shape of a loaded embedding model, including its `getEndpoint` factory. */
export type EmbeddingBackendModel = typeof defaultEmbeddingModel;