Spaces:
Paused
Paused
| import { env } from "$env/dynamic/private"; | |
| import { z } from "zod"; | |
| import { sum } from "$lib/utils/sum"; | |
| import { | |
| embeddingEndpoints, | |
| embeddingEndpointSchema, | |
| type EmbeddingEndpoint, | |
| } from "$lib/server/embeddingEndpoints/embeddingEndpoints"; | |
| import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; | |
| import JSON5 from "json5"; | |
/** Non-empty string field that may be omitted from a model entry. */
const optionalNonEmptyText = z.string().min(1).optional();

/** URL-validated string field that may be omitted from a model entry. */
const optionalUrl = z.string().url().optional();

/**
 * Zod schema describing a single embedding-model entry of the configuration array.
 *
 * Only `name`, `endpoints` and `chunkCharLength` are required; `preQuery` and
 * `prePassage` fall back to the empty string when absent.
 */
const modelConfig = z.object({
	/** Used as an identifier in DB */
	id: z.string().optional(),
	/** Used to link to the model page, and for inference */
	name: z.string().min(1),
	displayName: optionalNonEmptyText,
	description: optionalNonEmptyText,
	websiteUrl: optionalUrl,
	modelUrl: optionalUrl,
	/** At least one endpoint definition is required. */
	endpoints: z.array(embeddingEndpointSchema).nonempty(),
	/** Must be a positive number. */
	chunkCharLength: z.number().positive(),
	/** Optional; must be a positive number when provided. */
	maxBatchSize: z.number().positive().optional(),
	// NOTE(review): presumably prefixes applied to queries / passages before embedding — confirm at the call site.
	preQuery: z.string().default(""),
	prePassage: z.string().default(""),
});
// Raw JSON5 text describing the embedding models. When TEXT_EMBEDDING_MODELS is
// unset or empty, a single default model is used for backward compatibility.
const rawEmbeddingModelJSON = env.TEXT_EMBEDDING_MODELS
	? env.TEXT_EMBEDDING_MODELS
	: `[
	{
		"name": "Xenova/gte-small",
		"chunkCharLength": 512,
		"endpoints": [
			{ "type": "transformersjs" }
		]
	}
]`;

// Parse the JSON5 text first, then validate it against the per-model schema.
// Either step throws at module load if the configuration is malformed.
const parsedEmbeddingModelConfig: unknown = JSON5.parse(rawEmbeddingModelJSON);
const embeddingModelsRaw = z.array(modelConfig).parse(parsedEmbeddingModelConfig);
/**
 * Normalizes a validated model entry so that `id` is always populated.
 *
 * @param m - A model entry that already passed `modelConfig` validation.
 * @returns A copy of `m` whose `id` is the configured id, or `name` when the
 *   id is missing or empty.
 */
const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => {
	// `||` (not `??`) on purpose: an empty-string id also falls back to the name.
	const id = m.id || m.name;
	return { ...m, id };
};
/**
 * Extends a processed embedding model with a `getEndpoint` factory.
 *
 * Each call to `getEndpoint` picks one of the model's configured endpoints at
 * random, with probability proportional to its `weight`, and builds it through
 * the matching factory in `embeddingEndpoints`.
 *
 * @param m - A model entry returned by `processEmbeddingModel`.
 * @returns A copy of `m` with an added async `getEndpoint` method.
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
	...m,
	/**
	 * Selects and instantiates an endpoint for this model.
	 *
	 * @throws Error if the selected endpoint has a type with no known factory,
	 *   or if the weighted walk finishes without selecting any endpoint (for
	 *   instance when the weights sum to 0).
	 */
	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
		// Defensive fallback: with no endpoint list at all, run the model locally
		// through the transformers.js endpoint.
		if (!m.endpoints) {
			return embeddingEndpointTransformersJS({
				type: "transformersjs",
				weight: 1,
				model: m,
			});
		}

		// Weighted random pick: draw a point in [0, totalWeight) and walk the
		// endpoints, subtracting each weight until the point falls inside one.
		const totalWeight = sum(m.endpoints.map((e) => e.weight));
		let random = Math.random() * totalWeight;

		for (const endpoint of m.endpoints) {
			if (random < endpoint.weight) {
				const args = { ...endpoint, model: m };

				switch (args.type) {
					case "tei":
						return embeddingEndpoints.tei(args);
					case "transformersjs":
						return embeddingEndpoints.transformersjs(args);
					case "openai":
						return embeddingEndpoints.openai(args);
					case "hfapi":
						return embeddingEndpoints.hfapi(args);
					default:
						// Report only the offending type: interpolating the whole `args`
						// object would render as "[object Object]", and serializing it
						// would dump the full endpoint and model configuration.
						throw new Error(`Unknown endpoint type: ${endpoint.type}`);
				}
			}
			random -= endpoint.weight;
		}

		throw new Error(`Failed to select embedding endpoint`);
	},
});
/**
 * All configured embedding models, each normalized by `processEmbeddingModel`
 * and extended with `getEndpoint` by `addEndpoint`. Resolved once at module
 * load through top-level await.
 */
export const embeddingModels = await Promise.all(
	embeddingModelsRaw.map(async (rawModel) => addEndpoint(await processEmbeddingModel(rawModel)))
);

/**
 * The first configured embedding model. At runtime this is `undefined` when
 * the configured list is empty.
 */
export const defaultEmbeddingModel = embeddingModels[0];
/**
 * Builds a zod enum that accepts exactly the `id` or `name` values of the
 * given models.
 *
 * @param _models - Models whose keys form the allowed values; must contain at
 *   least one model, otherwise reading the first entry throws a TypeError.
 * @param key - Which property of each model to collect.
 * @returns A `z.enum` over the collected values.
 */
const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
	// z.enum needs a non-empty tuple, hence the explicit head / tail split.
	const [firstModel, ...otherModels] = _models;
	return z.enum([firstModel[key], ...otherModels.map((model) => model[key])]);
};
/**
 * Builds a zod enum that accepts the `id` of any of the given models.
 *
 * @param _models - Models whose ids are the allowed values.
 */
export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "id");
/**
 * Builds a zod enum that accepts the `name` of any of the given models.
 *
 * @param _models - Models whose names are the allowed values.
 */
export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "name");
/**
 * Shape of a fully processed embedding model, derived from
 * `defaultEmbeddingModel` so it always tracks the configuration pipeline.
 */
export type EmbeddingBackendModel = typeof defaultEmbeddingModel;