coyotte508 (HF staff) committed
Commit 9be5ab5
1 parent: c2e468f

⚡️ Load balance endpoints (#106)

.env CHANGED
@@ -3,15 +3,17 @@
 
 MONGODB_URL=#your mongodb URL here
 MONGODB_DB_NAME=chat-ui
-HF_TOKEN=#your huggingface token here
 COOKIE_NAME=hf-chat
 
 PUBLIC_MAX_INPUT_TOKENS=1024
 PUBLIC_ORIGIN=#https://hf.co
-PUBLIC_MODEL_ENDPOINT=https://api-inference.huggingface.co/models/OpenAssistant/oasst-sft-6-llama-30b
 PUBLIC_MODEL_NAME=OpenAssistant/oasst-sft-6-llama-30b # public facing link
 PUBLIC_MODEL_ID=OpenAssistant/oasst-sft-6-llama-30b-xor # used to link to model page
 PUBLIC_DISABLE_INTRO_TILES=false
 PUBLIC_USER_MESSAGE_TOKEN=<|prompter|>
 PUBLIC_ASSISTANT_MESSAGE_TOKEN=<|assistant|>
 PUBLIC_SEP_TOKEN=<|endoftext|>
+
+# Array<{endpoint: string, authorization: "Bearer XXX", weight: number}> to load balance
+# Eg if one endpoint has weight 2 and the other has weight 1, the first endpoint will be called twice as often
+MODEL_ENDPOINTS=`[]`
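
For illustration, a populated MODEL_ENDPOINTS might look like the sketch below (the first URL is the old PUBLIC_MODEL_ENDPOINT; the second endpoint and both tokens are hypothetical). Note that authorization holds the full header value, so each entry can carry its own token:

MODEL_ENDPOINTS=`[
	{"endpoint": "https://api-inference.huggingface.co/models/OpenAssistant/oasst-sft-6-llama-30b", "authorization": "Bearer hf_<token-A>", "weight": 2},
	{"endpoint": "https://my-other-endpoint.example.com/generate", "authorization": "Bearer hf_<token-B>", "weight": 1}
]`

With these weights, the first endpoint receives roughly two out of every three requests.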
src/lib/server/modelEndpoint.ts ADDED
@@ -0,0 +1,21 @@
+import { MODEL_ENDPOINTS } from "$env/static/private";
+import { sum } from "$lib/utils/sum";
+
+const endpoints: Array<{ endpoint: string; authorization: string; weight: number }> =
+	JSON.parse(MODEL_ENDPOINTS);
+const totalWeight = sum(endpoints.map((e) => e.weight));
+
+/**
+ * Find a random load-balanced endpoint
+ */
+export function modelEndpoint(): { endpoint: string; authorization: string; weight: number } {
+	let random = Math.random() * totalWeight;
+	for (const endpoint of endpoints) {
+		if (random < endpoint.weight) {
+			return endpoint;
+		}
+		random -= endpoint.weight;
+	}
+
+	throw new Error("Invalid config, no endpoint found");
+}
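
The loop is standard weighted random sampling: random starts uniform in [0, totalWeight), and each pass either returns the current endpoint (when random falls inside its weight band) or shifts random down by that weight. The throw should be unreachable for a non-empty config with positive weights; it guards against an empty array or malformed weights. A standalone sketch (hypothetical endpoints and weights, no SvelteKit imports) that tallies the resulting distribution:

// weightedPick.ts — the same selection logic as modelEndpoint(), inlined for a quick check
const endpoints = [
	{ endpoint: "https://a.example", weight: 2 },
	{ endpoint: "https://b.example", weight: 1 },
];
const totalWeight = endpoints.reduce((acc, e) => acc + e.weight, 0);

function pick(): string {
	let random = Math.random() * totalWeight;
	for (const e of endpoints) {
		if (random < e.weight) {
			return e.endpoint;
		}
		random -= e.weight;
	}
	throw new Error("Invalid config, no endpoint found");
}

// Tally 30,000 draws; expect roughly a 2:1 split between the two endpoints.
const counts: Record<string, number> = {};
for (let i = 0; i < 30_000; i++) {
	const chosen = pick();
	counts[chosen] = (counts[chosen] ?? 0) + 1;
}
console.log(counts); // e.g. { "https://a.example": 20041, "https://b.example": 9959 }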
src/routes/conversation/[id]/+server.ts CHANGED
@@ -1,7 +1,7 @@
-import { HF_TOKEN } from "$env/static/private";
-import { PUBLIC_MODEL_ENDPOINT, PUBLIC_SEP_TOKEN } from "$env/static/public";
+import { PUBLIC_SEP_TOKEN } from "$env/static/public";
 import { buildPrompt } from "$lib/buildPrompt.js";
 import { collections } from "$lib/server/database.js";
+import { modelEndpoint } from "$lib/server/modelEndpoint.js";
 import type { Message } from "$lib/types/Message.js";
 import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
 import { sum } from "$lib/utils/sum";
@@ -29,10 +29,12 @@ export async function POST({ request, fetch, locals, params }) {
 	const messages = [...conv.messages, { from: "user", content: json.inputs }] satisfies Message[];
 	const prompt = buildPrompt(messages);
 
-	const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
+	const randomEndpoint = modelEndpoint();
+
+	const resp = await fetch(randomEndpoint.endpoint, {
 		headers: {
 			"Content-Type": request.headers.get("Content-Type") ?? "application/json",
-			Authorization: `Bearer ${HF_TOKEN}`,
+			Authorization: randomEndpoint.authorization,
 		},
 		method: "POST",
 		body: JSON.stringify({
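
Outside the handler, the same pattern is two steps: pick an endpoint, then send the request with the configured header. A minimal sketch (hypothetical prompt and parameters, assuming a populated MODEL_ENDPOINTS):

import { modelEndpoint } from "$lib/server/modelEndpoint";

// Each call picks a weighted-random endpoint, so consecutive requests
// spread across the configured backends.
const { endpoint, authorization } = modelEndpoint();
const resp = await fetch(endpoint, {
	method: "POST",
	headers: {
		"Content-Type": "application/json",
		// Sent verbatim from config, so entries can use different tokens.
		Authorization: authorization,
	},
	body: JSON.stringify({ inputs: "Hello", parameters: { max_new_tokens: 20 } }),
});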
src/routes/conversation/[id]/summarize/+server.ts CHANGED
@@ -1,7 +1,7 @@
-import { HF_TOKEN } from "$env/static/private";
-import { PUBLIC_MAX_INPUT_TOKENS, PUBLIC_MODEL_ENDPOINT } from "$env/static/public";
+import { PUBLIC_MAX_INPUT_TOKENS } from "$env/static/public";
 import { buildPrompt } from "$lib/buildPrompt";
 import { collections } from "$lib/server/database.js";
+import { modelEndpoint } from "$lib/server/modelEndpoint.js";
 import { textGeneration } from "@huggingface/inference";
 import { error } from "@sveltejs/kit";
 import { ObjectId } from "mongodb";
@@ -38,14 +38,20 @@ export async function POST({ params, locals, fetch }) {
 		return_full_text: false,
 	};
 
+	const endpoint = modelEndpoint();
 	const { generated_text } = await textGeneration(
 		{
-			model: PUBLIC_MODEL_ENDPOINT,
+			model: endpoint.endpoint,
 			inputs: prompt,
 			parameters,
-			accessToken: HF_TOKEN,
 		},
-		{ fetch }
+		{
+			fetch: (url, options) =>
+				fetch(url, {
+					...options,
+					headers: { ...options?.headers, Authorization: endpoint.authorization },
+				}),
+		}
 	);
 
 	if (generated_text) {
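
Since textGeneration normally authenticates via accessToken, this change instead wraps the fetch it is given, so the Authorization header comes verbatim from the selected endpoint's config. The same wrapper pattern in isolation (a sketch, with hypothetical values in the usage comment):

// authFetch.ts — wrap fetch so every request carries a fixed Authorization header.
// Caller-supplied headers are spread first, so the injected header wins on conflict.
function withAuthorization(fetchFn: typeof fetch, authorization: string): typeof fetch {
	return (url, options) =>
		fetchFn(url, {
			...options,
			headers: { ...options?.headers, Authorization: authorization },
		});
}

// Usage (hypothetical token):
// const authed = withAuthorization(fetch, "Bearer hf_<token>");
// await textGeneration({ model: endpoint.endpoint, inputs: prompt, parameters }, { fetch: authed });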