Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	File size: 3,089 Bytes
			
			| a8a9533 b17a5c8 d89c536 b17a5c8 922385b b17a5c8 8c28b44 b17a5c8 a8a9533 b17a5c8 922385b b17a5c8 8c28b44 cc2dafe b17a5c8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | import { env } from "$env/dynamic/private";
import { z } from "zod";
import { sum } from "$lib/utils/sum";
import {
	embeddingEndpoints,
	embeddingEndpointSchema,
	type EmbeddingEndpoint,
} from "$lib/server/embeddingEndpoints/embeddingEndpoints";
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
import JSON5 from "json5";
/** A string that must contain at least one character. */
const nonEmptyString = z.string().min(1);
/** An optional string that must be a valid URL when provided. */
const optionalUrl = z.string().url().optional();
/** A number strictly greater than zero. */
const positiveNumber = z.number().positive();

/**
 * Schema for a single embedding model entry of the `TEXT_EMBEDDING_MODELS` configuration.
 *
 * Only `name`, `endpoints` and `chunkCharLength` are mandatory; `preQuery` and `prePassage`
 * fall back to the empty string when omitted.
 */
const modelConfig = z.object({
	/** Used as an identifier in DB */
	id: z.string().optional(),
	/** Used to link to the model page, and for inference */
	name: nonEmptyString,
	displayName: nonEmptyString.optional(),
	description: nonEmptyString.optional(),
	websiteUrl: optionalUrl,
	modelUrl: optionalUrl,
	/** At least one endpoint must be configured. */
	endpoints: z.array(embeddingEndpointSchema).nonempty(),
	chunkCharLength: positiveNumber,
	maxBatchSize: positiveNumber.optional(),
	preQuery: z.string().default(""),
	prePassage: z.string().default(""),
});
/**
 * Default embedding model for backward compatibility.
 * Used whenever `TEXT_EMBEDDING_MODELS` is unset or empty.
 */
const fallbackEmbeddingModelJSON = `[
	{
	  "name": "Xenova/gte-small",
	  "chunkCharLength": 512,
	  "endpoints": [
		{ "type": "transformersjs" }
	  ]
	}
]`;

/** JSON5 source of the embedding model list: the environment value, or the fallback above. */
const rawEmbeddingModelJSON = env.TEXT_EMBEDDING_MODELS || fallbackEmbeddingModelJSON;

/** Untyped result of decoding the JSON5 text; it is validated on the next line. */
const decodedEmbeddingModels: unknown = JSON5.parse(rawEmbeddingModelJSON);

/** Embedding model entries after validation against `modelConfig`. */
const embeddingModelsRaw = z.array(modelConfig).parse(decodedEmbeddingModels);
/**
 * Normalizes a validated model entry so that it always carries an `id`.
 *
 * @param m - A model entry that passed `modelConfig` validation.
 * @returns A copy of `m` whose `id` is the configured id, or the model `name`
 *          when no (or an empty) id was configured.
 */
const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => {
	// `||` on purpose: an empty-string id also falls back to the name.
	const id = m.id || m.name;
	return { ...m, id };
};
/**
 * Extends a processed embedding model with a `getEndpoint` method.
 *
 * @param m - Model entry as returned by `processEmbeddingModel`.
 * @returns A shallow copy of `m` with an additional `getEndpoint` method.
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
	...m,
	/**
	 * Picks one of the model's endpoints at random, with a probability proportional
	 * to its `weight`, and instantiates it through the matching `embeddingEndpoints` factory.
	 *
	 * @returns The instantiated embedding endpoint.
	 * @throws Error when the selected endpoint has an unsupported `type`, or when the
	 *         weighted draw ends without selecting any endpoint.
	 */
	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
		// `modelConfig` requires a non-empty `endpoints` array, so this is only a
		// defensive fallback to a local transformersjs endpoint.
		if (!m.endpoints) {
			return embeddingEndpointTransformersJS({
				type: "transformersjs",
				weight: 1,
				model: m,
			});
		}
		const totalWeight = sum(m.endpoints.map((e) => e.weight));
		// Weighted draw: walk the endpoints, consuming each weight from `random`
		// until it falls inside the current endpoint's share.
		let random = Math.random() * totalWeight;
		for (const endpoint of m.endpoints) {
			if (random < endpoint.weight) {
				const args = { ...endpoint, model: m };
				switch (args.type) {
					case "tei":
						return embeddingEndpoints.tei(args);
					case "transformersjs":
						return embeddingEndpoints.transformersjs(args);
					case "openai":
						return embeddingEndpoints.openai(args);
					case "hfapi":
						return embeddingEndpoints.hfapi(args);
					default:
						// Report the offending type itself: interpolating the whole `args`
						// object would only print "[object Object]".
						throw new Error(`Unknown endpoint type: ${endpoint.type}`);
				}
			}
			random -= endpoint.weight;
		}
		throw new Error(`Failed to select embedding endpoint`);
	},
});
// One pending promise per configured model: normalize the entry, then attach `getEndpoint`.
const embeddingModelPromises = embeddingModelsRaw.map(async (rawModel) =>
	addEndpoint(await processEmbeddingModel(rawModel))
);

/** All configured embedding models, in configuration order, each with a `getEndpoint` method. */
export const embeddingModels = await Promise.all(embeddingModelPromises);

/** The first configured embedding model. */
export const defaultEmbeddingModel = embeddingModels[0];
/**
 * Builds a zod enum that only accepts the `key` values of the given models.
 *
 * @param _models - Models whose `id` or `name` values are allowed; the first entry is
 *                  read unconditionally, so the list is expected to be non-empty.
 * @param key - Which property of each model forms the set of allowed values.
 * @returns A zod enum schema over the collected values.
 */
const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
	const [first, ...rest] = _models;
	// The first value is listed separately from the spread remainder, as in a non-empty tuple.
	return z.enum([first[key], ...rest.map((model) => model[key])]);
};
/**
 * Builds a zod enum accepting the `id` of any of the given embedding models.
 *
 * @param _models - Models whose ids are allowed.
 * @returns A zod enum schema over the model ids.
 */
export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "id");
/**
 * Builds a zod enum accepting the `name` of any of the given embedding models.
 *
 * @param _models - Models whose names are allowed.
 * @returns A zod enum schema over the model names.
 */
export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "name");
/**
 * Shape of a fully processed embedding model: derived from `defaultEmbeddingModel`,
 * i.e. a validated config entry with a resolved `id` and a `getEndpoint` method.
 */
export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
 | 
 
			
