MichaelFried Mishig nsarrazin HF staff mishig HF staff commited on
Commit
3a01622
1 Parent(s): 69c0464

Add embedding models configurable, from both transformers.js and TEI (#646)

Browse files

* Add embedding models configurable, from both Xenova and TEI

* fix lint and format

* Fix bug in sentenceSimilarity

* Batches for TEI using /info route

* Fix web search disapear when finish searching

* Fix lint and format

* Add more options for better embedding model usage

* Fixing CR issues

* Fix websearch disapear in later PR

* Fix lint

* Fix more minor code CR

* Valiadate embeddingModelName field in model config

* Add embeddingModel into shared conversation

* Fix lint and format

* Add default embedding model, and more readme explanation

* Fix minor embedding model readme detailed

* Update settings.json

* Update README.md

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

* Update README.md

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

* Apply suggestions from code review

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

* Resolved more issues

* lint

* Fix more issues

* Fix format

* fix small typo

* lint

* fix default model

* Rn `maxSequenceLength` -> `chunkCharLength`

* format

* add "authorization" example

* format

---------

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>

.env CHANGED
@@ -46,6 +46,18 @@ CA_PATH=#
46
  CLIENT_KEY_PASSWORD=#
47
  REJECT_UNAUTHORIZED=true
48
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # 'name', 'userMessageToken', 'assistantMessageToken' are required
50
  MODELS=`[
51
  {
 
46
  CLIENT_KEY_PASSWORD=#
47
  REJECT_UNAUTHORIZED=true
48
 
49
+ TEXT_EMBEDDING_MODELS = `[
50
+ {
51
+ "name": "Xenova/gte-small",
52
+ "displayName": "Xenova/gte-small",
53
+ "description": "Local embedding model running on the server.",
54
+ "chunkCharLength": 512,
55
+ "endpoints": [
56
+ { "type": "transformersjs" }
57
+ ]
58
+ }
59
+ ]`
60
+
61
  # 'name', 'userMessageToken', 'assistantMessageToken' are required
62
  MODELS=`[
63
  {
.env.template CHANGED
@@ -204,7 +204,6 @@ TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2'
204
  # "stop": ["</s>"]
205
  # }}`
206
 
207
-
208
  APP_BASE="/chat"
209
  PUBLIC_ORIGIN=https://huggingface.co
210
  PUBLIC_SHARE_PREFIX=https://hf.co/chat
 
204
  # "stop": ["</s>"]
205
  # }}`
206
 
 
207
  APP_BASE="/chat"
208
  PUBLIC_ORIGIN=https://huggingface.co
209
  PUBLIC_SHARE_PREFIX=https://hf.co/chat
README.md CHANGED
@@ -20,9 +20,10 @@ A chat interface using open source models, eg OpenAssistant or Llama. It is a Sv
20
  1. [Setup](#setup)
21
  2. [Launch](#launch)
22
  3. [Web Search](#web-search)
23
- 4. [Extra parameters](#extra-parameters)
24
- 5. [Deploying to a HF Space](#deploying-to-a-hf-space)
25
- 6. [Building](#building)
 
26
 
27
  ## No Setup Deploy
28
 
@@ -78,10 +79,50 @@ Chat UI features a powerful Web Search feature. It works by:
78
 
79
  1. Generating an appropriate search query from the user prompt.
80
  2. Performing web search and extracting content from webpages.
81
- 3. Creating embeddings from texts using [transformers.js](https://huggingface.co/docs/transformers.js). Specifically, using [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.
82
  4. From these embeddings, find the ones that are closest to the user query using a vector similarity search. Specifically, we use `inner product` distance.
83
  5. Get the corresponding texts to those closest embeddings and perform [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expand user prompt by adding those texts so that an LLM can use this information).
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  ## Extra parameters
86
 
87
  ### OpenID connect
@@ -425,6 +466,45 @@ If you're using a certificate signed by a private CA, you will also need to add
425
 
426
  If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation, and allow Chat UI to connect to your custom endpoint.
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  ## Deploying to a HF Space
429
 
430
  Create a `DOTENV_LOCAL` secret to your HF space with the content of your .env.local, and they will be picked up automatically when you run.
 
20
  1. [Setup](#setup)
21
  2. [Launch](#launch)
22
  3. [Web Search](#web-search)
23
+ 4. [Text Embedding Models](#text-embedding-models)
24
+ 5. [Extra parameters](#extra-parameters)
25
+ 6. [Deploying to a HF Space](#deploying-to-a-hf-space)
26
+ 7. [Building](#building)
27
 
28
  ## No Setup Deploy
29
 
 
79
 
80
  1. Generating an appropriate search query from the user prompt.
81
  2. Performing web search and extracting content from webpages.
82
+ 3. Creating embeddings from texts using a text embedding model.
83
  4. From these embeddings, find the ones that are closest to the user query using a vector similarity search. Specifically, we use `inner product` distance.
84
  5. Get the corresponding texts to those closest embeddings and perform [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expand user prompt by adding those texts so that an LLM can use this information).
85
 
86
+ ## Text Embedding Models
87
+
88
+ By default (for backward compatibility), when `TEXT_EMBEDDING_MODELS` environment variable is not defined, [transformers.js](https://huggingface.co/docs/transformers.js) embedding models will be used for embedding tasks, specifically, [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.
89
+
90
+ You can customize the embedding model by setting `TEXT_EMBEDDING_MODELS` in your `.env.local` file. For example:
91
+
92
+ ```env
93
+ TEXT_EMBEDDING_MODELS = `[
94
+ {
95
+ "name": "Xenova/gte-small",
96
+ "displayName": "Xenova/gte-small",
97
+ "description": "locally running embedding",
98
+ "chunkCharLength": 512,
99
+ "endpoints": [
100
+ {"type": "transformersjs"}
101
+ ]
102
+ },
103
+ {
104
+ "name": "intfloat/e5-base-v2",
105
+ "displayName": "intfloat/e5-base-v2",
106
+ "description": "hosted embedding model",
107
+ "chunkCharLength": 768,
108
+ "preQuery": "query: ", # See https://huggingface.co/intfloat/e5-base-v2#faq
109
+ "prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq
110
+ "endpoints": [
111
+ {
112
+ "type": "tei",
113
+ "url": "http://127.0.0.1:8080/",
114
+ "authorization": "TOKEN_TYPE TOKEN" // optional authorization field. Example: "Basic VVNFUjpQQVNT"
115
+ }
116
+ ]
117
+ }
118
+ ]`
119
+ ```
120
+
121
+ The required fields are `name`, `chunkCharLength` and `endpoints`.
122
+ Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint.
123
+
124
+ When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModel` to the name of the model.
125
+
126
  ## Extra parameters
127
 
128
  ### OpenID connect
 
466
 
467
  If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation, and allow Chat UI to connect to your custom endpoint.
468
 
469
+ #### Specific Embedding Model
470
+
471
+ A model can use any of the embedding models defined in `.env.local`, (currently used when web searching),
472
+ by default it will use the first embedding model, but it can be changed with the field `embeddingModel`:
473
+
474
+ ```env
475
+ TEXT_EMBEDDING_MODELS = `[
476
+ {
477
+ "name": "Xenova/gte-small",
478
+ "chunkCharLength": 512,
479
+ "endpoints": [
480
+ {"type": "transformersjs"}
481
+ ]
482
+ },
483
+ {
484
+ "name": "intfloat/e5-base-v2",
485
+ "chunkCharLength": 768,
486
+ "endpoints": [
487
+ {"type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT"},
488
+ {"type": "tei", "url": "http://127.0.0.1:8081/"}
489
+ ]
490
+ }
491
+ ]`
492
+
493
+ MODELS=`[
494
+ {
495
+ "name": "Ollama Mistral",
496
+ "chatPromptTemplate": "...",
497
+ "embeddingModel": "intfloat/e5-base-v2"
498
+ "parameters": {
499
+ ...
500
+ },
501
+ "endpoints": [
502
+ ...
503
+ ]
504
+ }
505
+ ]`
506
+ ```
507
+
508
  ## Deploying to a HF Space
509
 
510
  Create a `DOTENV_LOCAL` secret to your HF space with the content of your .env.local, and they will be picked up automatically when you run.
src/lib/components/OpenWebSearchResults.svelte CHANGED
@@ -30,8 +30,8 @@
30
  {:else}
31
  <CarbonCheckmark class="my-auto text-gray-500" />
32
  {/if}
33
- <span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}
34
- >Web search
35
  </span>
36
  <div class="my-auto transition-all" class:rotate-90={detailsOpen}>
37
  <CarbonCaretRight />
 
30
  {:else}
31
  <CarbonCheckmark class="my-auto text-gray-500" />
32
  {/if}
33
+ <span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}>
34
+ Web search
35
  </span>
36
  <div class="my-auto transition-all" class:rotate-90={detailsOpen}>
37
  <CarbonCaretRight />
src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { z } from "zod";
2
+ import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints";
3
+ import { chunk } from "$lib/utils/chunk";
4
+
5
+ export const embeddingEndpointTeiParametersSchema = z.object({
6
+ weight: z.number().int().positive().default(1),
7
+ model: z.any(),
8
+ type: z.literal("tei"),
9
+ url: z.string().url(),
10
+ authorization: z.string().optional(),
11
+ });
12
+
13
+ const getModelInfoByUrl = async (url: string, authorization?: string) => {
14
+ const { origin } = new URL(url);
15
+
16
+ const response = await fetch(`${origin}/info`, {
17
+ headers: {
18
+ Accept: "application/json",
19
+ "Content-Type": "application/json",
20
+ ...(authorization ? { Authorization: authorization } : {}),
21
+ },
22
+ });
23
+
24
+ const json = await response.json();
25
+ return json;
26
+ };
27
+
28
+ export async function embeddingEndpointTei(
29
+ input: z.input<typeof embeddingEndpointTeiParametersSchema>
30
+ ): Promise<EmbeddingEndpoint> {
31
+ const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);
32
+
33
+ const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url);
34
+ const maxBatchSize = Math.min(
35
+ max_client_batch_size,
36
+ Math.floor(max_batch_tokens / model.chunkCharLength)
37
+ );
38
+
39
+ return async ({ inputs }) => {
40
+ const { origin } = new URL(url);
41
+
42
+ const batchesInputs = chunk(inputs, maxBatchSize);
43
+
44
+ const batchesResults = await Promise.all(
45
+ batchesInputs.map(async (batchInputs) => {
46
+ const response = await fetch(`${origin}/embed`, {
47
+ method: "POST",
48
+ headers: {
49
+ Accept: "application/json",
50
+ "Content-Type": "application/json",
51
+ ...(authorization ? { Authorization: authorization } : {}),
52
+ },
53
+ body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
54
+ });
55
+
56
+ const embeddings: Embedding[] = await response.json();
57
+ return embeddings;
58
+ })
59
+ );
60
+
61
+ const flatAllEmbeddings = batchesResults.flat();
62
+
63
+ return flatAllEmbeddings;
64
+ };
65
+ }
src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { z } from "zod";
2
+ import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints";
3
+ import type { Tensor, Pipeline } from "@xenova/transformers";
4
+ import { pipeline } from "@xenova/transformers";
5
+
6
+ export const embeddingEndpointTransformersJSParametersSchema = z.object({
7
+ weight: z.number().int().positive().default(1),
8
+ model: z.any(),
9
+ type: z.literal("transformersjs"),
10
+ });
11
+
12
+ // Use the Singleton pattern to enable lazy construction of the pipeline.
13
+ class TransformersJSModelsSingleton {
14
+ static instances: Array<[string, Promise<Pipeline>]> = [];
15
+
16
+ static async getInstance(modelName: string): Promise<Pipeline> {
17
+ const modelPipelineInstance = this.instances.find(([name]) => name === modelName);
18
+
19
+ if (modelPipelineInstance) {
20
+ const [, modelPipeline] = modelPipelineInstance;
21
+ return modelPipeline;
22
+ }
23
+
24
+ const newModelPipeline = pipeline("feature-extraction", modelName);
25
+ this.instances.push([modelName, newModelPipeline]);
26
+
27
+ return newModelPipeline;
28
+ }
29
+ }
30
+
31
+ export async function calculateEmbedding(modelName: string, inputs: string[]) {
32
+ const extractor = await TransformersJSModelsSingleton.getInstance(modelName);
33
+ const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true });
34
+
35
+ return output.tolist();
36
+ }
37
+
38
+ export function embeddingEndpointTransformersJS(
39
+ input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
40
+ ): EmbeddingEndpoint {
41
+ const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);
42
+
43
+ return async ({ inputs }) => {
44
+ return calculateEmbedding(model.name, inputs);
45
+ };
46
+ }
src/lib/server/embeddingModels.ts ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { TEXT_EMBEDDING_MODELS } from "$env/static/private";
2
+
3
+ import { z } from "zod";
4
+ import { sum } from "$lib/utils/sum";
5
+ import {
6
+ embeddingEndpoints,
7
+ embeddingEndpointSchema,
8
+ type EmbeddingEndpoint,
9
+ } from "$lib/types/EmbeddingEndpoints";
10
+ import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
11
+
12
+ const modelConfig = z.object({
13
+ /** Used as an identifier in DB */
14
+ id: z.string().optional(),
15
+ /** Used to link to the model page, and for inference */
16
+ name: z.string().min(1),
17
+ displayName: z.string().min(1).optional(),
18
+ description: z.string().min(1).optional(),
19
+ websiteUrl: z.string().url().optional(),
20
+ modelUrl: z.string().url().optional(),
21
+ endpoints: z.array(embeddingEndpointSchema).nonempty(),
22
+ chunkCharLength: z.number().positive(),
23
+ preQuery: z.string().default(""),
24
+ prePassage: z.string().default(""),
25
+ });
26
+
27
+ // Default embedding model for backward compatibility
28
+ const rawEmbeddingModelJSON =
29
+ TEXT_EMBEDDING_MODELS ||
30
+ `[
31
+ {
32
+ "name": "Xenova/gte-small",
33
+ "chunkCharLength": 512,
34
+ "endpoints": [
35
+ { "type": "transformersjs" }
36
+ ]
37
+ }
38
+ ]`;
39
+
40
+ const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON));
41
+
42
+ const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
43
+ ...m,
44
+ id: m.id || m.name,
45
+ });
46
+
47
+ const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
48
+ ...m,
49
+ getEndpoint: async (): Promise<EmbeddingEndpoint> => {
50
+ if (!m.endpoints) {
51
+ return embeddingEndpointTransformersJS({
52
+ type: "transformersjs",
53
+ weight: 1,
54
+ model: m,
55
+ });
56
+ }
57
+
58
+ const totalWeight = sum(m.endpoints.map((e) => e.weight));
59
+
60
+ let random = Math.random() * totalWeight;
61
+
62
+ for (const endpoint of m.endpoints) {
63
+ if (random < endpoint.weight) {
64
+ const args = { ...endpoint, model: m };
65
+
66
+ switch (args.type) {
67
+ case "tei":
68
+ return embeddingEndpoints.tei(args);
69
+ case "transformersjs":
70
+ return embeddingEndpoints.transformersjs(args);
71
+ }
72
+ }
73
+
74
+ random -= endpoint.weight;
75
+ }
76
+
77
+ throw new Error(`Failed to select embedding endpoint`);
78
+ },
79
+ });
80
+
81
+ export const embeddingModels = await Promise.all(
82
+ embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
83
+ );
84
+
85
+ export const defaultEmbeddingModel = embeddingModels[0];
86
+
87
+ const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
88
+ return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
89
+ };
90
+
91
+ export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
92
+ return validateEmbeddingModel(_models, "id");
93
+ };
94
+
95
+ export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
96
+ return validateEmbeddingModel(_models, "name");
97
+ };
98
+
99
+ export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
src/lib/server/models.ts CHANGED
@@ -12,6 +12,7 @@ import { z } from "zod";
12
  import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
13
  import endpointTgi from "./endpoints/tgi/endpointTgi";
14
  import { sum } from "$lib/utils/sum";
 
15
 
16
  import JSON5 from "json5";
17
 
@@ -68,6 +69,7 @@ const modelConfig = z.object({
68
  .optional(),
69
  multimodal: z.boolean().default(false),
70
  unlisted: z.boolean().default(false),
 
71
  });
72
 
73
  const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));
 
12
  import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
13
  import endpointTgi from "./endpoints/tgi/endpointTgi";
14
  import { sum } from "$lib/utils/sum";
15
+ import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";
16
 
17
  import JSON5 from "json5";
18
 
 
69
  .optional(),
70
  multimodal: z.boolean().default(false),
71
  unlisted: z.boolean().default(false),
72
+ embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
73
  });
74
 
75
  const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));
src/lib/server/{websearch/sentenceSimilarity.ts → sentenceSimilarity.ts} RENAMED
@@ -1,43 +1,33 @@
1
- import type { Tensor, Pipeline } from "@xenova/transformers";
2
- import { pipeline, dot } from "@xenova/transformers";
 
3
 
4
  // see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
5
- function innerProduct(tensor1: Tensor, tensor2: Tensor) {
6
- return 1.0 - dot(tensor1.data, tensor2.data);
7
  }
8
 
9
- // Use the Singleton pattern to enable lazy construction of the pipeline.
10
- class PipelineSingleton {
11
- static modelId = "Xenova/gte-small";
12
- static instance: Promise<Pipeline> | null = null;
13
- static async getInstance() {
14
- if (this.instance === null) {
15
- this.instance = pipeline("feature-extraction", this.modelId);
16
- }
17
- return this.instance;
18
- }
19
- }
20
-
21
- // see https://huggingface.co/thenlper/gte-small/blob/d8e2604cadbeeda029847d19759d219e0ce2e6d8/README.md?code=true#L2625
22
- export const MAX_SEQ_LEN = 512 as const;
23
-
24
  export async function findSimilarSentences(
 
25
  query: string,
26
  sentences: string[],
27
  { topK = 5 }: { topK: number }
28
- ) {
29
- const input = [query, ...sentences];
 
 
 
30
 
31
- const extractor = await PipelineSingleton.getInstance();
32
- const output: Tensor = await extractor(input, { pooling: "mean", normalize: true });
33
 
34
- const queryTensor: Tensor = output[0];
35
- const sentencesTensor: Tensor = output.slice([1, input.length - 1]);
36
 
37
- const distancesFromQuery: { distance: number; index: number }[] = [...sentencesTensor].map(
38
- (sentenceTensor: Tensor, index: number) => {
39
  return {
40
- distance: innerProduct(queryTensor, sentenceTensor),
41
  index: index,
42
  };
43
  }
 
1
+ import { dot } from "@xenova/transformers";
2
+ import type { EmbeddingBackendModel } from "$lib/server/embeddingModels";
3
+ import type { Embedding } from "$lib/types/EmbeddingEndpoints";
4
 
5
  // see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
6
+ function innerProduct(embeddingA: Embedding, embeddingB: Embedding) {
7
+ return 1.0 - dot(embeddingA, embeddingB);
8
  }
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  export async function findSimilarSentences(
11
+ embeddingModel: EmbeddingBackendModel,
12
  query: string,
13
  sentences: string[],
14
  { topK = 5 }: { topK: number }
15
+ ): Promise<Embedding> {
16
+ const inputs = [
17
+ `${embeddingModel.preQuery}${query}`,
18
+ ...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`),
19
+ ];
20
 
21
+ const embeddingEndpoint = await embeddingModel.getEndpoint();
22
+ const output = await embeddingEndpoint({ inputs });
23
 
24
+ const queryEmbedding: Embedding = output[0];
25
+ const sentencesEmbeddings: Embedding[] = output.slice(1, inputs.length - 1);
26
 
27
+ const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map(
28
+ (sentenceEmbedding: Embedding, index: number) => {
29
  return {
30
+ distance: innerProduct(queryEmbedding, sentenceEmbedding),
31
  index: index,
32
  };
33
  }
src/lib/server/websearch/runWebSearch.ts CHANGED
@@ -4,13 +4,11 @@ import type { WebSearch, WebSearchSource } from "$lib/types/WebSearch";
4
  import { generateQuery } from "$lib/server/websearch/generateQuery";
5
  import { parseWeb } from "$lib/server/websearch/parseWeb";
6
  import { chunk } from "$lib/utils/chunk";
7
- import {
8
- MAX_SEQ_LEN as CHUNK_CAR_LEN,
9
- findSimilarSentences,
10
- } from "$lib/server/websearch/sentenceSimilarity";
11
  import type { Conversation } from "$lib/types/Conversation";
12
  import type { MessageUpdate } from "$lib/types/MessageUpdate";
13
  import { getWebSearchProvider } from "./searchWeb";
 
14
 
15
  const MAX_N_PAGES_SCRAPE = 10 as const;
16
  const MAX_N_PAGES_EMBED = 5 as const;
@@ -63,6 +61,14 @@ export async function runWebSearch(
63
  .filter(({ link }) => !DOMAIN_BLOCKLIST.some((el) => link.includes(el))) // filter out blocklist links
64
  .slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only
65
 
 
 
 
 
 
 
 
 
66
  let paragraphChunks: { source: WebSearchSource; text: string }[] = [];
67
  if (webSearch.results.length > 0) {
68
  appendUpdate("Browsing results");
@@ -78,7 +84,7 @@ export async function runWebSearch(
78
  }
79
  }
80
  const MAX_N_CHUNKS = 100;
81
- const texts = chunk(text, CHUNK_CAR_LEN).slice(0, MAX_N_CHUNKS);
82
  return texts.map((t) => ({ source: result, text: t }));
83
  });
84
  const nestedParagraphChunks = (await Promise.all(promises)).slice(0, MAX_N_PAGES_EMBED);
@@ -93,7 +99,7 @@ export async function runWebSearch(
93
  appendUpdate("Extracting relevant information");
94
  const topKClosestParagraphs = 8;
95
  const texts = paragraphChunks.map(({ text }) => text);
96
- const indices = await findSimilarSentences(prompt, texts, {
97
  topK: topKClosestParagraphs,
98
  });
99
  webSearch.context = indices.map((idx) => texts[idx]).join("");
 
4
  import { generateQuery } from "$lib/server/websearch/generateQuery";
5
  import { parseWeb } from "$lib/server/websearch/parseWeb";
6
  import { chunk } from "$lib/utils/chunk";
7
+ import { findSimilarSentences } from "$lib/server/sentenceSimilarity";
 
 
 
8
  import type { Conversation } from "$lib/types/Conversation";
9
  import type { MessageUpdate } from "$lib/types/MessageUpdate";
10
  import { getWebSearchProvider } from "./searchWeb";
11
+ import { defaultEmbeddingModel, embeddingModels } from "$lib/server/embeddingModels";
12
 
13
  const MAX_N_PAGES_SCRAPE = 10 as const;
14
  const MAX_N_PAGES_EMBED = 5 as const;
 
61
  .filter(({ link }) => !DOMAIN_BLOCKLIST.some((el) => link.includes(el))) // filter out blocklist links
62
  .slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only
63
 
64
+ // fetch the model
65
+ const embeddingModel =
66
+ embeddingModels.find((m) => m.id === conv.embeddingModel) ?? defaultEmbeddingModel;
67
+
68
+ if (!embeddingModel) {
69
+ throw new Error(`Embedding model ${conv.embeddingModel} not available anymore`);
70
+ }
71
+
72
  let paragraphChunks: { source: WebSearchSource; text: string }[] = [];
73
  if (webSearch.results.length > 0) {
74
  appendUpdate("Browsing results");
 
84
  }
85
  }
86
  const MAX_N_CHUNKS = 100;
87
+ const texts = chunk(text, embeddingModel.chunkCharLength).slice(0, MAX_N_CHUNKS);
88
  return texts.map((t) => ({ source: result, text: t }));
89
  });
90
  const nestedParagraphChunks = (await Promise.all(promises)).slice(0, MAX_N_PAGES_EMBED);
 
99
  appendUpdate("Extracting relevant information");
100
  const topKClosestParagraphs = 8;
101
  const texts = paragraphChunks.map(({ text }) => text);
102
+ const indices = await findSimilarSentences(embeddingModel, prompt, texts, {
103
  topK: topKClosestParagraphs,
104
  });
105
  webSearch.context = indices.map((idx) => texts[idx]).join("");
src/lib/types/Conversation.ts CHANGED
@@ -10,6 +10,7 @@ export interface Conversation extends Timestamps {
10
  userId?: User["_id"];
11
 
12
  model: string;
 
13
 
14
  title: string;
15
  messages: Message[];
 
10
  userId?: User["_id"];
11
 
12
  model: string;
13
+ embeddingModel: string;
14
 
15
  title: string;
16
  messages: Message[];
src/lib/types/EmbeddingEndpoints.ts ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { z } from "zod";
2
+ import {
3
+ embeddingEndpointTei,
4
+ embeddingEndpointTeiParametersSchema,
5
+ } from "$lib/server/embeddingEndpoints/tei/embeddingEndpoints";
6
+ import {
7
+ embeddingEndpointTransformersJS,
8
+ embeddingEndpointTransformersJSParametersSchema,
9
+ } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
10
+
11
+ // parameters passed when generating text
12
+ interface EmbeddingEndpointParameters {
13
+ inputs: string[];
14
+ }
15
+
16
+ export type Embedding = number[];
17
+
18
+ // type signature for the endpoint
19
+ export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise<Embedding[]>;
20
+
21
+ export const embeddingEndpointSchema = z.discriminatedUnion("type", [
22
+ embeddingEndpointTeiParametersSchema,
23
+ embeddingEndpointTransformersJSParametersSchema,
24
+ ]);
25
+
26
+ type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
27
+
28
+ // generator function that takes in type discrimantor value for defining the endpoint and return the endpoint
29
+ export type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
30
+ inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }>
31
+ ) => EmbeddingEndpoint | Promise<EmbeddingEndpoint>;
32
+
33
+ // list of all endpoint generators
34
+ export const embeddingEndpoints: {
35
+ [Key in EmbeddingEndpointTypeOptions]: EmbeddingEndpointGenerator<Key>;
36
+ } = {
37
+ tei: embeddingEndpointTei,
38
+ transformersjs: embeddingEndpointTransformersJS,
39
+ };
40
+
41
+ export default embeddingEndpoints;
src/lib/types/SharedConversation.ts CHANGED
@@ -7,6 +7,8 @@ export interface SharedConversation extends Timestamps {
7
  hash: string;
8
 
9
  model: string;
 
 
10
  title: string;
11
  messages: Message[];
12
  preprompt?: string;
 
7
  hash: string;
8
 
9
  model: string;
10
+ embeddingModel: string;
11
+
12
  title: string;
13
  messages: Message[];
14
  preprompt?: string;
src/routes/conversation/+server.ts CHANGED
@@ -6,6 +6,7 @@ import { base } from "$app/paths";
6
  import { z } from "zod";
7
  import type { Message } from "$lib/types/Message";
8
  import { models, validateModel } from "$lib/server/models";
 
9
 
10
  export const POST: RequestHandler = async ({ locals, request }) => {
11
  const body = await request.text();
@@ -22,6 +23,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
22
  .parse(JSON.parse(body));
23
 
24
  let preprompt = values.preprompt;
 
25
 
26
  if (values.fromShare) {
27
  const conversation = await collections.sharedConversations.findOne({
@@ -35,6 +37,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
35
  title = conversation.title;
36
  messages = conversation.messages;
37
  values.model = conversation.model;
 
38
  preprompt = conversation.preprompt;
39
  }
40
 
@@ -44,6 +47,8 @@ export const POST: RequestHandler = async ({ locals, request }) => {
44
  throw error(400, "Invalid model");
45
  }
46
 
 
 
47
  if (model.unlisted) {
48
  throw error(400, "Can't start a conversation with an unlisted model");
49
  }
@@ -59,6 +64,7 @@ export const POST: RequestHandler = async ({ locals, request }) => {
59
  preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt,
60
  createdAt: new Date(),
61
  updatedAt: new Date(),
 
62
  ...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }),
63
  ...(values.fromShare ? { meta: { fromShareId: values.fromShare } } : {}),
64
  });
 
6
  import { z } from "zod";
7
  import type { Message } from "$lib/types/Message";
8
  import { models, validateModel } from "$lib/server/models";
9
+ import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
10
 
11
  export const POST: RequestHandler = async ({ locals, request }) => {
12
  const body = await request.text();
 
23
  .parse(JSON.parse(body));
24
 
25
  let preprompt = values.preprompt;
26
+ let embeddingModel: string;
27
 
28
  if (values.fromShare) {
29
  const conversation = await collections.sharedConversations.findOne({
 
37
  title = conversation.title;
38
  messages = conversation.messages;
39
  values.model = conversation.model;
40
+ embeddingModel = conversation.embeddingModel;
41
  preprompt = conversation.preprompt;
42
  }
43
 
 
47
  throw error(400, "Invalid model");
48
  }
49
 
50
+ embeddingModel ??= model.embeddingModel ?? defaultEmbeddingModel.name;
51
+
52
  if (model.unlisted) {
53
  throw error(400, "Can't start a conversation with an unlisted model");
54
  }
 
64
  preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt,
65
  createdAt: new Date(),
66
  updatedAt: new Date(),
67
+ embeddingModel: embeddingModel,
68
  ...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }),
69
  ...(values.fromShare ? { meta: { fromShareId: values.fromShare } } : {}),
70
  });
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -173,6 +173,7 @@
173
  inputs.forEach(async (el: string) => {
174
  try {
175
  const update = JSON.parse(el) as MessageUpdate;
 
176
  if (update.type === "finalAnswer") {
177
  finalAnswer = update.text;
178
  reader.cancel();
@@ -225,7 +226,7 @@
225
  });
226
  }
227
 
228
- // reset the websearchmessages
229
  webSearchMessages = [];
230
 
231
  await invalidate(UrlDependency.ConversationList);
 
173
  inputs.forEach(async (el: string) => {
174
  try {
175
  const update = JSON.parse(el) as MessageUpdate;
176
+
177
  if (update.type === "finalAnswer") {
178
  finalAnswer = update.text;
179
  reader.cancel();
 
226
  });
227
  }
228
 
229
+ // reset the websearchMessages
230
  webSearchMessages = [];
231
 
232
  await invalidate(UrlDependency.ConversationList);
src/routes/conversation/[id]/share/+server.ts CHANGED
@@ -38,6 +38,7 @@ export async function POST({ params, url, locals }) {
38
  updatedAt: new Date(),
39
  title: conversation.title,
40
  model: conversation.model,
 
41
  preprompt: conversation.preprompt,
42
  };
43
 
 
38
  updatedAt: new Date(),
39
  title: conversation.title,
40
  model: conversation.model,
41
+ embeddingModel: conversation.embeddingModel,
42
  preprompt: conversation.preprompt,
43
  };
44
 
src/routes/login/callback/updateUser.spec.ts CHANGED
@@ -6,6 +6,7 @@ import { ObjectId } from "mongodb";
6
  import { DEFAULT_SETTINGS } from "$lib/types/Settings";
7
  import { defaultModel } from "$lib/server/models";
8
  import { findUser } from "$lib/server/auth";
 
9
 
10
  const userData = {
11
  preferred_username: "new-username",
@@ -46,6 +47,7 @@ const insertRandomConversations = async (count: number) => {
46
  title: "random title",
47
  messages: [],
48
  model: defaultModel.id,
 
49
  createdAt: new Date(),
50
  updatedAt: new Date(),
51
  sessionId: locals.sessionId,
 
6
  import { DEFAULT_SETTINGS } from "$lib/types/Settings";
7
  import { defaultModel } from "$lib/server/models";
8
  import { findUser } from "$lib/server/auth";
9
+ import { defaultEmbeddingModel } from "$lib/server/embeddingModels";
10
 
11
  const userData = {
12
  preferred_username: "new-username",
 
47
  title: "random title",
48
  messages: [],
49
  model: defaultModel.id,
50
+ embeddingModel: defaultEmbeddingModel.id,
51
  createdAt: new Date(),
52
  updatedAt: new Date(),
53
  sessionId: locals.sessionId,