Spaces:
Running
Running
File size: 3,102 Bytes
3a01622 41f8b74 3a01622 6e0b0ea 3a01622 f7db219 3a01622 6e0b0ea 3a01622 f7db219 04b77a5 3a01622 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import { TEXT_EMBEDDING_MODELS } from "$env/static/private";
import { z } from "zod";
import { sum } from "$lib/utils/sum";
import {
embeddingEndpoints,
embeddingEndpointSchema,
type EmbeddingEndpoint,
} from "$lib/server/embeddingEndpoints/embeddingEndpoints";
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
import JSON5 from "json5";
/** Non-empty string; base for the required and optional text fields below. */
const nonEmptyText = z.string().min(1);
/** Optional, URL-formatted string shared by the link fields below. */
const optionalLink = z.string().url().optional();

/**
 * Schema for one entry of the embedding model configuration array.
 * Each entry must name the model, list at least one endpoint and give a
 * positive chunk length; query/passage prefixes default to "".
 */
const modelConfig = z.object({
	/** Used as an identifier in DB */
	id: z.string().optional(),
	/** Used to link to the model page, and for inference */
	name: nonEmptyText,
	displayName: nonEmptyText.optional(),
	description: nonEmptyText.optional(),
	websiteUrl: optionalLink,
	modelUrl: optionalLink,
	endpoints: z.array(embeddingEndpointSchema).nonempty(),
	chunkCharLength: z.number().positive(),
	maxBatchSize: z.number().positive().optional(),
	preQuery: z.string().default(""),
	prePassage: z.string().default(""),
});
/**
 * Default embedding model for backward compatibility: a single
 * "Xenova/gte-small" entry served through the "transformersjs" endpoint type.
 * Used when TEXT_EMBEDDING_MODELS is unset or an empty string.
 */
const defaultEmbeddingModelsJSON = `[
{
"name": "Xenova/gte-small",
"chunkCharLength": 512,
"endpoints": [
{ "type": "transformersjs" }
]
}
]`;

// `||` (not `??`) on purpose: an empty env value also selects the default.
const rawEmbeddingModelJSON = TEXT_EMBEDDING_MODELS || defaultEmbeddingModelsJSON;

// Parsed as JSON5 (a superset of JSON), then every entry is validated by `modelConfig`.
const parsedEmbeddingModels: unknown = JSON5.parse(rawEmbeddingModelJSON);
const embeddingModelsRaw = z.array(modelConfig).parse(parsedEmbeddingModels);
/**
 * Normalizes a validated model config entry.
 *
 * @param m - one entry parsed by `modelConfig`.
 * @returns a copy of `m` whose `id` falls back to `name` when `id` is missing
 *   or empty.
 */
const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => {
	const resolvedId = m.id || m.name;
	return { ...m, id: resolvedId };
};
/**
 * Attaches a `getEndpoint` method to a processed embedding model.
 *
 * Each call to `getEndpoint` picks one of the model's configured endpoints at
 * random, with probability proportional to its `weight`. It then builds the
 * embedding endpoint matching that entry's `type`.
 *
 * @param m - a model returned by `processEmbeddingModel`.
 * @returns `m` extended with `getEndpoint`.
 * @throws Error (from `getEndpoint`) if the chosen endpoint has an
 *   unrecognized `type`, or if no endpoint can be selected (for example when
 *   every weight is 0).
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
	...m,
	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
		// Defensive fallback: `modelConfig` requires a non-empty `endpoints`
		// array, so validated configs are not expected to reach this branch.
		if (!m.endpoints) {
			return embeddingEndpointTransformersJS({
				type: "transformersjs",
				weight: 1,
				model: m,
			});
		}

		// Weighted random choice. Draw a point in [0, totalWeight), then walk the
		// endpoints, subtracting each weight until the point falls inside one.
		const totalWeight = sum(m.endpoints.map((e) => e.weight));
		let random = Math.random() * totalWeight;
		for (const endpoint of m.endpoints) {
			if (random < endpoint.weight) {
				const args = { ...endpoint, model: m };
				switch (args.type) {
					case "tei":
						return embeddingEndpoints.tei(args);
					case "transformersjs":
						return embeddingEndpoints.transformersjs(args);
					case "openai":
						return embeddingEndpoints.openai(args);
					case "hfapi":
						return embeddingEndpoints.hfapi(args);
					default:
						// Report the type string itself; interpolating the `args`
						// object would only print "[object Object]".
						throw new Error(`Unknown endpoint type: ${endpoint.type}`);
				}
			}
			random -= endpoint.weight;
		}
		// Reached when no endpoint has positive weight (e.g. totalWeight is 0).
		throw new Error(`Failed to select embedding endpoint`);
	},
});
/**
 * Every configured embedding model, normalized by `processEmbeddingModel` and
 * given a `getEndpoint` method by `addEndpoint`. Resolved with top-level await
 * when the module loads.
 */
export const embeddingModels = await Promise.all(
	embeddingModelsRaw.map(async (rawModel) => {
		const processed = await processEmbeddingModel(rawModel);
		return addEndpoint(processed);
	})
);

/**
 * The first configured model, used as the default.
 * NOTE(review): undefined at runtime if the configuration array is empty.
 */
export const defaultEmbeddingModel = embeddingModels[0];
/**
 * Builds a zod enum that accepts exactly the `key` values ("id" or "name") of
 * the given models.
 *
 * @param _models - models whose `key` values form the allowed set.
 * @param key - which model property to validate against.
 * @returns a `z.enum` over those values.
 * @throws Error if `_models` is empty, because `z.enum` needs at least one
 *   value. Previously this case failed with an opaque TypeError on `_models[0]`.
 */
const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
	if (_models.length === 0) {
		throw new Error(
			`Cannot validate embedding model by ${key}: no embedding models are configured`
		);
	}
	// The first value is passed separately so the argument matches z.enum's
	// non-empty tuple type.
	return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
};
/**
 * Zod validator that accepts the `id` of any model in `_models`.
 *
 * @param _models - the models whose ids are allowed.
 */
export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "id");
/**
 * Zod validator that accepts the `name` of any model in `_models`.
 *
 * @param _models - the models whose names are allowed.
 */
export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "name");
/**
 * Type of a fully processed embedding model: the validated config entry, with
 * `id` normalized and a `getEndpoint` method attached.
 * Derived from `defaultEmbeddingModel`, which has the same element type as
 * `embeddingModels`.
 */
export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
|