Martok88 nsarrazin HF staff committed on
Commit
f7db219
1 Parent(s): 0081568

Add openai embeddings (#915)

Browse files

* Add OpenAI embedding compatibility

* Use OPENAI_API_KEY by default

* lint

* Add default OpenAI URL
Replace `authorization` with `apiKey`

* Add a note in readme

---------

Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>

README.md CHANGED
@@ -120,7 +120,7 @@ TEXT_EMBEDDING_MODELS = `[
120
  ```
121
 
122
  The required fields are `name`, `chunkCharLength` and `endpoints`.
123
- Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint.
124
 
125
  When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModel` to the name of the model.
126
 
 
120
  ```
121
 
122
  The required fields are `name`, `chunkCharLength` and `endpoints`.
123
+ Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js), [`TEI`](https://github.com/huggingface/text-embeddings-inference) and [`OpenAI`](https://platform.openai.com/docs/guides/embeddings). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a separate environment and are accessed through an API endpoint. `OpenAI` models (endpoint type `openai`) are accessed remotely through the [OpenAI API](https://platform.openai.com/docs/guides/embeddings).
124
 
125
  When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModel` to the name of the model.
126
 
src/lib/server/embeddingEndpoints/embeddingEndpoints.ts CHANGED
@@ -7,6 +7,10 @@ import {
7
  embeddingEndpointTransformersJS,
8
  embeddingEndpointTransformersJSParametersSchema,
9
  } from "./transformersjs/embeddingEndpoints";
 
 
 
 
10
 
11
  // parameters passed when generating text
12
  interface EmbeddingEndpointParameters {
@@ -21,6 +25,7 @@ export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise
21
  export const embeddingEndpointSchema = z.discriminatedUnion("type", [
22
  embeddingEndpointTeiParametersSchema,
23
  embeddingEndpointTransformersJSParametersSchema,
 
24
  ]);
25
 
26
  type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
@@ -36,6 +41,7 @@ export const embeddingEndpoints: {
36
  } = {
37
  tei: embeddingEndpointTei,
38
  transformersjs: embeddingEndpointTransformersJS,
 
39
  };
40
 
41
  export default embeddingEndpoints;
 
7
  embeddingEndpointTransformersJS,
8
  embeddingEndpointTransformersJSParametersSchema,
9
  } from "./transformersjs/embeddingEndpoints";
10
+ import {
11
+ embeddingEndpointOpenAI,
12
+ embeddingEndpointOpenAIParametersSchema,
13
+ } from "./openai/embeddingEndpoints";
14
 
15
  // parameters passed when generating text
16
  interface EmbeddingEndpointParameters {
 
25
  export const embeddingEndpointSchema = z.discriminatedUnion("type", [
26
  embeddingEndpointTeiParametersSchema,
27
  embeddingEndpointTransformersJSParametersSchema,
28
+ embeddingEndpointOpenAIParametersSchema,
29
  ]);
30
 
31
  type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
 
41
  } = {
42
  tei: embeddingEndpointTei,
43
  transformersjs: embeddingEndpointTransformersJS,
44
+ openai: embeddingEndpointOpenAI,
45
  };
46
 
47
  export default embeddingEndpoints;
src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { z } from "zod";
2
+ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
3
+ import { chunk } from "$lib/utils/chunk";
4
+ import { OPENAI_API_KEY } from "$env/static/private";
5
+
6
+ export const embeddingEndpointOpenAIParametersSchema = z.object({
7
+ weight: z.number().int().positive().default(1),
8
+ model: z.any(),
9
+ type: z.literal("openai"),
10
+ url: z.string().url().default("https://api.openai.com/v1/embeddings"),
11
+ apiKey: z.string().default(OPENAI_API_KEY),
12
+ });
13
+
14
+ export async function embeddingEndpointOpenAI(
15
+ input: z.input<typeof embeddingEndpointOpenAIParametersSchema>
16
+ ): Promise<EmbeddingEndpoint> {
17
+ const { url, model, apiKey } = embeddingEndpointOpenAIParametersSchema.parse(input);
18
+
19
+ const maxBatchSize = model.maxBatchSize || 100;
20
+
21
+ return async ({ inputs }) => {
22
+ const requestURL = new URL(url);
23
+
24
+ const batchesInputs = chunk(inputs, maxBatchSize);
25
+
26
+ const batchesResults = await Promise.all(
27
+ batchesInputs.map(async (batchInputs) => {
28
+ const response = await fetch(requestURL, {
29
+ method: "POST",
30
+ headers: {
31
+ Accept: "application/json",
32
+ "Content-Type": "application/json",
33
+ ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
34
+ },
35
+ body: JSON.stringify({ input: batchInputs, model: model.name }),
36
+ });
37
+
38
+ const embeddings: Embedding[] = [];
39
+ const responseObject = await response.json();
40
+ for (const embeddingObject of responseObject.data) {
41
+ embeddings.push(embeddingObject.embedding);
42
+ }
43
+ return embeddings;
44
+ })
45
+ );
46
+
47
+ const flatAllEmbeddings = batchesResults.flat();
48
+
49
+ return flatAllEmbeddings;
50
+ };
51
+ }
src/lib/server/embeddingModels.ts CHANGED
@@ -22,6 +22,7 @@ const modelConfig = z.object({
22
  modelUrl: z.string().url().optional(),
23
  endpoints: z.array(embeddingEndpointSchema).nonempty(),
24
  chunkCharLength: z.number().positive(),
 
25
  preQuery: z.string().default(""),
26
  prePassage: z.string().default(""),
27
  });
@@ -70,6 +71,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
70
  return embeddingEndpoints.tei(args);
71
  case "transformersjs":
72
  return embeddingEndpoints.transformersjs(args);
 
 
73
  }
74
  }
75
 
 
22
  modelUrl: z.string().url().optional(),
23
  endpoints: z.array(embeddingEndpointSchema).nonempty(),
24
  chunkCharLength: z.number().positive(),
25
+ maxBatchSize: z.number().positive().optional(),
26
  preQuery: z.string().default(""),
27
  prePassage: z.string().default(""),
28
  });
 
71
  return embeddingEndpoints.tei(args);
72
  case "transformersjs":
73
  return embeddingEndpoints.transformersjs(args);
74
+ case "openai":
75
+ return embeddingEndpoints.openai(args);
76
  }
77
  }
78