nsarrazin HF staff committed on
Commit
04b77a5
1 Parent(s): 86bc2ea

Use inference API for embeddings in huggingchat prod (#1037)

Browse files

* Let the user use HF_TOKEN as auth bearer token in TEI endpoint

* Use inference API for embeddings in huggingchat prod

.env.template CHANGED
@@ -232,6 +232,18 @@ MODELS=`[
232
  }
233
  ]`
234
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  OLD_MODELS=`[
236
  {"name":"bigcode/starcoder"},
237
  {"name":"OpenAssistant/oasst-sft-6-llama-30b-xor"},
 
232
  }
233
  ]`
234
 
235
+ TEXT_EMBEDDING_MODELS = `[
236
+ {
237
+ "name": "BAAI/bge-small-en-v1.5",
238
+ "id": "BAAI/bge-small-en-v1.5",
239
+ "chunkCharLength": 2048,
240
+ "endpoints": [
241
+ { "type": "hfapi" }
242
+ ]
243
+ }
244
+ ]`
245
+
246
+
247
  OLD_MODELS=`[
248
  {"name":"bigcode/starcoder"},
249
  {"name":"OpenAssistant/oasst-sft-6-llama-30b-xor"},
src/lib/server/embeddingEndpoints/embeddingEndpoints.ts CHANGED
@@ -11,6 +11,7 @@ import {
11
  embeddingEndpointOpenAI,
12
  embeddingEndpointOpenAIParametersSchema,
13
  } from "./openai/embeddingEndpoints";
 
14
 
15
  // parameters passed when generating text
16
  interface EmbeddingEndpointParameters {
@@ -26,6 +27,7 @@ export const embeddingEndpointSchema = z.discriminatedUnion("type", [
26
  embeddingEndpointTeiParametersSchema,
27
  embeddingEndpointTransformersJSParametersSchema,
28
  embeddingEndpointOpenAIParametersSchema,
 
29
  ]);
30
 
31
  type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
@@ -42,6 +44,7 @@ export const embeddingEndpoints: {
42
  tei: embeddingEndpointTei,
43
  transformersjs: embeddingEndpointTransformersJS,
44
  openai: embeddingEndpointOpenAI,
 
45
  };
46
 
47
  export default embeddingEndpoints;
 
11
  embeddingEndpointOpenAI,
12
  embeddingEndpointOpenAIParametersSchema,
13
  } from "./openai/embeddingEndpoints";
14
+ import { embeddingEndpointHfApi, embeddingEndpointHfApiSchema } from "./hfApi/embeddingHfApi";
15
 
16
  // parameters passed when generating text
17
  interface EmbeddingEndpointParameters {
 
27
  embeddingEndpointTeiParametersSchema,
28
  embeddingEndpointTransformersJSParametersSchema,
29
  embeddingEndpointOpenAIParametersSchema,
30
+ embeddingEndpointHfApiSchema,
31
  ]);
32
 
33
  type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
 
44
  tei: embeddingEndpointTei,
45
  transformersjs: embeddingEndpointTransformersJS,
46
  openai: embeddingEndpointOpenAI,
47
+ hfapi: embeddingEndpointHfApi,
48
  };
49
 
50
  export default embeddingEndpoints;
src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { z } from "zod";
2
+ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
3
+ import { chunk } from "$lib/utils/chunk";
4
+ import { HF_TOKEN } from "$env/static/private";
5
+
6
+ export const embeddingEndpointHfApiSchema = z.object({
7
+ weight: z.number().int().positive().default(1),
8
+ model: z.any(),
9
+ type: z.literal("hfapi"),
10
+ authorization: z
11
+ .string()
12
+ .optional()
13
+ .transform((v) => (!v && HF_TOKEN ? "Bearer " + HF_TOKEN : v)), // if the header is not set but HF_TOKEN is, use it as the authorization header
14
+ });
15
+
16
/**
 * Builds an embedding endpoint backed by the Hugging Face Inference API.
 *
 * The endpoint config is validated with `embeddingEndpointHfApiSchema`; the
 * request URL is the inference API base followed by `model.id`.
 *
 * @param input - Raw endpoint configuration (schema input shape).
 * @returns A function that takes `{ inputs }`, sends them in batches of 128
 *   (all batches are requested concurrently) and returns the embeddings of
 *   every batch flattened into one array, in batch order.
 * @throws Error when the API answers with a non-OK status or with a payload
 *   that is not an array.
 */
export async function embeddingEndpointHfApi(
  input: z.input<typeof embeddingEndpointHfApiSchema>
): Promise<EmbeddingEndpoint> {
  const { model, authorization } = embeddingEndpointHfApiSchema.parse(input);
  const url = "https://api-inference.huggingface.co/models/" + model.id;
  const batchSize = 128;

  return async ({ inputs }) => {
    const batchesInputs = chunk(inputs, batchSize);

    const batchesResults = await Promise.all(
      batchesInputs.map(async (batchInputs) => {
        const response = await fetch(url, {
          method: "POST",
          headers: {
            Accept: "application/json",
            "Content-Type": "application/json",
            ...(authorization ? { Authorization: authorization } : {}),
          },
          body: JSON.stringify({ inputs: batchInputs }),
        });

        if (!response.ok) {
          // A response body can only be read once: read it as text for the
          // error report and stop here instead of trying to parse it again as JSON.
          const errorBody = await response.text();
          console.error("Failed to get embeddings from Hugging Face API", response.status, errorBody);
          throw new Error(
            `Failed to get embeddings from Hugging Face API: ${response.status} ${response.statusText} - ${errorBody}`
          );
        }

        const payload: unknown = await response.json();
        if (!Array.isArray(payload)) {
          throw new Error("Unexpected response from Hugging Face API: expected an array of embeddings");
        }

        const embeddings: Embedding[] = payload;
        return embeddings;
      })
    );

    return batchesResults.flat();
  };
}
src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts CHANGED
@@ -1,13 +1,17 @@
1
  import { z } from "zod";
2
  import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
3
  import { chunk } from "$lib/utils/chunk";
 
4
 
5
  export const embeddingEndpointTeiParametersSchema = z.object({
6
  weight: z.number().int().positive().default(1),
7
  model: z.any(),
8
  type: z.literal("tei"),
9
  url: z.string().url(),
10
- authorization: z.string().optional(),
 
 
 
11
  });
12
 
13
  const getModelInfoByUrl = async (url: string, authorization?: string) => {
 
1
  import { z } from "zod";
2
  import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
3
  import { chunk } from "$lib/utils/chunk";
4
+ import { HF_TOKEN } from "$env/static/private";
5
 
6
/**
 * Zod schema describing an embedding endpoint of type "tei".
 *
 * - `weight`: positive integer, defaults to 1.
 * - `model`: accepted as-is (no validation is applied here).
 * - `type`: must be the literal "tei".
 * - `url`: string that must be a valid URL.
 * - `authorization`: optional header value. When it is missing or empty and
 *   HF_TOKEN is set, "Bearer " followed by HF_TOKEN is used instead.
 */
export const embeddingEndpointTeiParametersSchema = z.object({
  weight: z.number().int().positive().default(1),
  model: z.any(),
  type: z.literal("tei"),
  url: z.string().url(),
  authorization: z
    .string()
    .optional()
    .transform((configured) => {
      // Fall back to the HF token only when no header was configured.
      const useToken = !configured && HF_TOKEN;
      if (useToken) {
        return "Bearer " + HF_TOKEN;
      }
      return configured;
    }),
});
16
 
17
  const getModelInfoByUrl = async (url: string, authorization?: string) => {
src/lib/server/embeddingModels.ts CHANGED
@@ -73,6 +73,10 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
73
  return embeddingEndpoints.transformersjs(args);
74
  case "openai":
75
  return embeddingEndpoints.openai(args);
 
 
 
 
76
  }
77
  }
78
 
 
73
  return embeddingEndpoints.transformersjs(args);
74
  case "openai":
75
  return embeddingEndpoints.openai(args);
76
+ case "hfapi":
77
+ return embeddingEndpoints.hfapi(args);
78
+ default:
79
+ throw new Error(`Unknown endpoint type: ${args}`);
80
  }
81
  }
82