Add Sagemaker support (#401)
Browse files* work on sagemaker support
* fix sagemaker integration
* remove unnecessary deps
* fix default endpoint
* remove unneeded deps, fixed types
* Use conditional validation for endpoints
This was needed because the discriminated union couldn't handle the legacy case where `host` is undefined.
* add note in readme about aws sagemaker
- README.md +18 -0
- package-lock.json +6 -0
- package.json +1 -0
- src/lib/server/generateFromDefaultEndpoint.ts +69 -17
- src/lib/server/modelEndpoint.ts +3 -6
- src/lib/server/models.ts +34 -9
- src/routes/conversation/[id]/+server.ts +35 -10
README.md
CHANGED
@@ -198,6 +198,24 @@ You can then add the generated information and the `authorization` parameter to
|
|
198 |
|
199 |
```
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
#### Client Certificate Authentication (mTLS)
|
202 |
|
203 |
Custom endpoints may require client certificate authentication, depending on how you configure them. To enable mTLS between Chat UI and your custom endpoint, you will need to set the `USE_CLIENT_CERTIFICATE` to `true`, and add the `CERT_PATH` and `KEY_PATH` parameters to your `.env.local`. These parameters should point to the location of the certificate and key files on your local machine. The certificate and key files should be in PEM format. The key file can be encrypted with a passphrase, in which case you will also need to add the `CLIENT_KEY_PASSWORD` parameter to your `.env.local`.
|
|
|
198 |
|
199 |
```
|
200 |
|
201 |
+
### Amazon SageMaker
|
202 |
+
|
203 |
+
You can also specify your Amazon SageMaker instance as an endpoint for chat-ui. The config goes like this:
|
204 |
+
|
205 |
+
```
|
206 |
+
"endpoints": [
|
207 |
+
{
|
208 |
+
"host" : "sagemaker",
|
209 |
+
"url": "", // your aws sagemaker url here
|
210 |
+
"accessKey": "",
|
211 |
+
"secretKey" : "",
|
212 |
+
"sessionToken": "", // optional
|
213 |
+
"weight": 1
|
214 |
+
}
|
215 |
+
```
|
216 |
+
|
217 |
+
You can get the `accessKey` and `secretKey` from your AWS user, under programmatic access.
|
218 |
+
|
219 |
#### Client Certificate Authentication (mTLS)
|
220 |
|
221 |
Custom endpoints may require client certificate authentication, depending on how you configure them. To enable mTLS between Chat UI and your custom endpoint, you will need to set the `USE_CLIENT_CERTIFICATE` to `true`, and add the `CERT_PATH` and `KEY_PATH` parameters to your `.env.local`. These parameters should point to the location of the certificate and key files on your local machine. The certificate and key files should be in PEM format. The key file can be encrypted with a passphrase, in which case you will also need to add the `CLIENT_KEY_PASSWORD` parameter to your `.env.local`.
|
package-lock.json
CHANGED
@@ -11,6 +11,7 @@
|
|
11 |
"@huggingface/hub": "^0.5.1",
|
12 |
"@huggingface/inference": "^2.2.0",
|
13 |
"autoprefixer": "^10.4.14",
|
|
|
14 |
"date-fns": "^2.29.3",
|
15 |
"dotenv": "^16.0.3",
|
16 |
"highlight.js": "^11.7.0",
|
@@ -1465,6 +1466,11 @@
|
|
1465 |
"postcss": "^8.1.0"
|
1466 |
}
|
1467 |
},
|
|
|
|
|
|
|
|
|
|
|
1468 |
"node_modules/balanced-match": {
|
1469 |
"version": "1.0.2",
|
1470 |
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
|
|
|
11 |
"@huggingface/hub": "^0.5.1",
|
12 |
"@huggingface/inference": "^2.2.0",
|
13 |
"autoprefixer": "^10.4.14",
|
14 |
+
"aws4fetch": "^1.0.17",
|
15 |
"date-fns": "^2.29.3",
|
16 |
"dotenv": "^16.0.3",
|
17 |
"highlight.js": "^11.7.0",
|
|
|
1466 |
"postcss": "^8.1.0"
|
1467 |
}
|
1468 |
},
|
1469 |
+
"node_modules/aws4fetch": {
|
1470 |
+
"version": "1.0.17",
|
1471 |
+
"resolved": "https://registry.npmjs.org/aws4fetch/-/aws4fetch-1.0.17.tgz",
|
1472 |
+
"integrity": "sha512-4IbOvsxqxeOSxI4oA+8xEO8SzBMVlzbSTgGy/EF83rHnQ/aKtP6Sc6YV/k0oiW0mqrcxuThlbDosnvetGOuO+g=="
|
1473 |
+
},
|
1474 |
"node_modules/balanced-match": {
|
1475 |
"version": "1.0.2",
|
1476 |
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
|
package.json
CHANGED
@@ -43,6 +43,7 @@
|
|
43 |
"@huggingface/hub": "^0.5.1",
|
44 |
"@huggingface/inference": "^2.2.0",
|
45 |
"autoprefixer": "^10.4.14",
|
|
|
46 |
"date-fns": "^2.29.3",
|
47 |
"dotenv": "^16.0.3",
|
48 |
"highlight.js": "^11.7.0",
|
|
|
43 |
"@huggingface/hub": "^0.5.1",
|
44 |
"@huggingface/inference": "^2.2.0",
|
45 |
"autoprefixer": "^10.4.14",
|
46 |
+
"aws4fetch": "^1.0.17",
|
47 |
"date-fns": "^2.29.3",
|
48 |
"dotenv": "^16.0.3",
|
49 |
"highlight.js": "^11.7.0",
|
src/lib/server/generateFromDefaultEndpoint.ts
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import { defaultModel } from "$lib/server/models";
|
2 |
import { modelEndpoint } from "./modelEndpoint";
|
3 |
-
import { textGeneration } from "@huggingface/inference";
|
4 |
import { trimSuffix } from "$lib/utils/trimSuffix";
|
5 |
import { trimPrefix } from "$lib/utils/trimPrefix";
|
6 |
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
|
|
|
7 |
|
8 |
interface Parameters {
|
9 |
temperature: number;
|
@@ -21,24 +21,76 @@ export async function generateFromDefaultEndpoint(
|
|
21 |
return_full_text: false,
|
22 |
};
|
23 |
|
24 |
-
const
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
inputs: prompt,
|
29 |
-
|
30 |
-
|
31 |
-
{
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
generated_text = trimSuffix(
|
41 |
-
trimPrefix(generated_text, "<|startoftext|>"),
|
42 |
PUBLIC_SEP_TOKEN
|
43 |
).trimEnd();
|
44 |
|
|
|
1 |
import { defaultModel } from "$lib/server/models";
|
2 |
import { modelEndpoint } from "./modelEndpoint";
|
|
|
3 |
import { trimSuffix } from "$lib/utils/trimSuffix";
|
4 |
import { trimPrefix } from "$lib/utils/trimPrefix";
|
5 |
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
|
6 |
+
import { AwsClient } from "aws4fetch";
|
7 |
|
8 |
interface Parameters {
|
9 |
temperature: number;
|
|
|
21 |
return_full_text: false,
|
22 |
};
|
23 |
|
24 |
+
const randomEndpoint = modelEndpoint(defaultModel);
|
25 |
+
|
26 |
+
const abortController = new AbortController();
|
27 |
+
|
28 |
+
let resp: Response;
|
29 |
+
|
30 |
+
if (randomEndpoint.host === "sagemaker") {
|
31 |
+
const requestParams = JSON.stringify({
|
32 |
+
...newParameters,
|
33 |
inputs: prompt,
|
34 |
+
});
|
35 |
+
|
36 |
+
const aws = new AwsClient({
|
37 |
+
accessKeyId: randomEndpoint.accessKey,
|
38 |
+
secretAccessKey: randomEndpoint.secretKey,
|
39 |
+
sessionToken: randomEndpoint.sessionToken,
|
40 |
+
service: "sagemaker",
|
41 |
+
});
|
42 |
+
|
43 |
+
resp = await aws.fetch(randomEndpoint.url, {
|
44 |
+
method: "POST",
|
45 |
+
body: requestParams,
|
46 |
+
signal: abortController.signal,
|
47 |
+
headers: {
|
48 |
+
"Content-Type": "application/json",
|
49 |
+
},
|
50 |
+
});
|
51 |
+
} else {
|
52 |
+
resp = await fetch(randomEndpoint.url, {
|
53 |
+
headers: {
|
54 |
+
"Content-Type": "application/json",
|
55 |
+
Authorization: randomEndpoint.authorization,
|
56 |
+
},
|
57 |
+
method: "POST",
|
58 |
+
body: JSON.stringify({
|
59 |
+
...newParameters,
|
60 |
+
inputs: prompt,
|
61 |
+
}),
|
62 |
+
signal: abortController.signal,
|
63 |
+
});
|
64 |
+
}
|
65 |
+
|
66 |
+
if (!resp.ok) {
|
67 |
+
throw new Error(await resp.text());
|
68 |
+
}
|
69 |
+
|
70 |
+
if (!resp.body) {
|
71 |
+
throw new Error("Response body is empty");
|
72 |
+
}
|
73 |
+
|
74 |
+
const decoder = new TextDecoder();
|
75 |
+
const reader = resp.body.getReader();
|
76 |
+
|
77 |
+
let isDone = false;
|
78 |
+
let result = "";
|
79 |
+
|
80 |
+
while (!isDone) {
|
81 |
+
const { done, value } = await reader.read();
|
82 |
+
|
83 |
+
isDone = done;
|
84 |
+
result += decoder.decode(value, { stream: true }); // Convert current chunk to text
|
85 |
+
}
|
86 |
+
|
87 |
+
// Close the reader when done
|
88 |
+
reader.releaseLock();
|
89 |
+
|
90 |
+
const results = await JSON.parse(result);
|
91 |
|
92 |
+
let generated_text = trimSuffix(
|
93 |
+
trimPrefix(trimPrefix(results[0].generated_text, "<|startoftext|>"), prompt),
|
94 |
PUBLIC_SEP_TOKEN
|
95 |
).trimEnd();
|
96 |
|
src/lib/server/modelEndpoint.ts
CHANGED
@@ -9,7 +9,7 @@ import {
|
|
9 |
REJECT_UNAUTHORIZED,
|
10 |
} from "$env/static/private";
|
11 |
import { sum } from "$lib/utils/sum";
|
12 |
-
import type { BackendModel } from "./models";
|
13 |
|
14 |
import { loadClientCertificates } from "$lib/utils/loadClientCerts";
|
15 |
|
@@ -26,13 +26,10 @@ if (USE_CLIENT_CERTIFICATE === "true") {
|
|
26 |
/**
|
27 |
* Find a random load-balanced endpoint
|
28 |
*/
|
29 |
-
export function modelEndpoint(model: BackendModel): {
|
30 |
-
url: string;
|
31 |
-
authorization: string;
|
32 |
-
weight: number;
|
33 |
-
} {
|
34 |
if (!model.endpoints) {
|
35 |
return {
|
|
|
36 |
url: `${HF_API_ROOT}/${model.name}`,
|
37 |
authorization: `Bearer ${HF_ACCESS_TOKEN}`,
|
38 |
weight: 1,
|
|
|
9 |
REJECT_UNAUTHORIZED,
|
10 |
} from "$env/static/private";
|
11 |
import { sum } from "$lib/utils/sum";
|
12 |
+
import type { BackendModel, Endpoint } from "./models";
|
13 |
|
14 |
import { loadClientCertificates } from "$lib/utils/loadClientCerts";
|
15 |
|
|
|
26 |
/**
|
27 |
* Find a random load-balanced endpoint
|
28 |
*/
|
29 |
+
export function modelEndpoint(model: BackendModel): Endpoint {
|
|
|
|
|
|
|
|
|
30 |
if (!model.endpoints) {
|
31 |
return {
|
32 |
+
host: "tgi",
|
33 |
url: `${HF_API_ROOT}/${model.name}`,
|
34 |
authorization: `Bearer ${HF_ACCESS_TOKEN}`,
|
35 |
weight: 1,
|
src/lib/server/models.ts
CHANGED
@@ -1,6 +1,38 @@
|
|
1 |
import { HF_ACCESS_TOKEN, MODELS, OLD_MODELS } from "$env/static/private";
|
2 |
import { z } from "zod";
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
const modelsRaw = z
|
5 |
.array(
|
6 |
z.object({
|
@@ -29,15 +61,7 @@ const modelsRaw = z
|
|
29 |
})
|
30 |
)
|
31 |
.optional(),
|
32 |
-
endpoints: z
|
33 |
-
.array(
|
34 |
-
z.object({
|
35 |
-
url: z.string().url(),
|
36 |
-
authorization: z.string().min(1).default(`Bearer ${HF_ACCESS_TOKEN}`),
|
37 |
-
weight: z.number().int().positive().default(1),
|
38 |
-
})
|
39 |
-
)
|
40 |
-
.optional(),
|
41 |
parameters: z
|
42 |
.object({
|
43 |
temperature: z.number().min(0).max(1),
|
@@ -77,6 +101,7 @@ export const oldModels = OLD_MODELS
|
|
77 |
: [];
|
78 |
|
79 |
export type BackendModel = (typeof models)[0];
|
|
|
80 |
|
81 |
export const defaultModel = models[0];
|
82 |
|
|
|
1 |
import { HF_ACCESS_TOKEN, MODELS, OLD_MODELS } from "$env/static/private";
|
2 |
import { z } from "zod";
|
3 |
|
4 |
+
const sagemakerEndpoint = z.object({
|
5 |
+
host: z.literal("sagemaker"),
|
6 |
+
url: z.string().url(),
|
7 |
+
accessKey: z.string().min(1),
|
8 |
+
secretKey: z.string().min(1),
|
9 |
+
sessionToken: z.string().optional(),
|
10 |
+
});
|
11 |
+
|
12 |
+
const tgiEndpoint = z.object({
|
13 |
+
host: z.union([z.literal("tgi"), z.undefined()]),
|
14 |
+
url: z.string().url(),
|
15 |
+
authorization: z.string().min(1).default(`Bearer ${HF_ACCESS_TOKEN}`),
|
16 |
+
});
|
17 |
+
|
18 |
+
const commonEndpoint = z.object({
|
19 |
+
weight: z.number().int().positive().default(1),
|
20 |
+
});
|
21 |
+
|
22 |
+
const endpoint = z.lazy(() =>
|
23 |
+
z.union([sagemakerEndpoint.merge(commonEndpoint), tgiEndpoint.merge(commonEndpoint)])
|
24 |
+
);
|
25 |
+
|
26 |
+
const combinedEndpoint = endpoint.transform((data) => {
|
27 |
+
if (data.host === "tgi" || data.host === undefined) {
|
28 |
+
return tgiEndpoint.merge(commonEndpoint).parse(data);
|
29 |
+
} else if (data.host === "sagemaker") {
|
30 |
+
return sagemakerEndpoint.merge(commonEndpoint).parse(data);
|
31 |
+
} else {
|
32 |
+
throw new Error(`Invalid host: ${data.host}`);
|
33 |
+
}
|
34 |
+
});
|
35 |
+
|
36 |
const modelsRaw = z
|
37 |
.array(
|
38 |
z.object({
|
|
|
61 |
})
|
62 |
)
|
63 |
.optional(),
|
64 |
+
endpoints: z.array(combinedEndpoint).optional(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
parameters: z
|
66 |
.object({
|
67 |
temperature: z.number().min(0).max(1),
|
|
|
101 |
: [];
|
102 |
|
103 |
export type BackendModel = (typeof models)[0];
|
104 |
+
export type Endpoint = z.infer<typeof endpoint>;
|
105 |
|
106 |
export const defaultModel = models[0];
|
107 |
|
src/routes/conversation/[id]/+server.ts
CHANGED
@@ -16,6 +16,7 @@ import type { TextGenerationStreamOutput } from "@huggingface/inference";
|
|
16 |
import { error } from "@sveltejs/kit";
|
17 |
import { ObjectId } from "mongodb";
|
18 |
import { z } from "zod";
|
|
|
19 |
|
20 |
export async function POST({ request, fetch, locals, params }) {
|
21 |
const id = z.string().parse(params.id);
|
@@ -101,18 +102,42 @@ export async function POST({ request, fetch, locals, params }) {
|
|
101 |
|
102 |
const abortController = new AbortController();
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
Authorization: randomEndpoint.authorization,
|
108 |
-
},
|
109 |
-
method: "POST",
|
110 |
-
body: JSON.stringify({
|
111 |
...json,
|
112 |
inputs: prompt,
|
113 |
-
})
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
if (!resp.body) {
|
118 |
throw new Error("Response body is empty");
|
|
|
16 |
import { error } from "@sveltejs/kit";
|
17 |
import { ObjectId } from "mongodb";
|
18 |
import { z } from "zod";
|
19 |
+
import { AwsClient } from "aws4fetch";
|
20 |
|
21 |
export async function POST({ request, fetch, locals, params }) {
|
22 |
const id = z.string().parse(params.id);
|
|
|
102 |
|
103 |
const abortController = new AbortController();
|
104 |
|
105 |
+
let resp: Response;
|
106 |
+
if (randomEndpoint.host === "sagemaker") {
|
107 |
+
const requestParams = JSON.stringify({
|
|
|
|
|
|
|
|
|
108 |
...json,
|
109 |
inputs: prompt,
|
110 |
+
});
|
111 |
+
|
112 |
+
const aws = new AwsClient({
|
113 |
+
accessKeyId: randomEndpoint.accessKey,
|
114 |
+
secretAccessKey: randomEndpoint.secretKey,
|
115 |
+
sessionToken: randomEndpoint.sessionToken,
|
116 |
+
service: "sagemaker",
|
117 |
+
});
|
118 |
+
|
119 |
+
resp = await aws.fetch(randomEndpoint.url, {
|
120 |
+
method: "POST",
|
121 |
+
body: requestParams,
|
122 |
+
signal: abortController.signal,
|
123 |
+
headers: {
|
124 |
+
"Content-Type": "application/json",
|
125 |
+
},
|
126 |
+
});
|
127 |
+
} else {
|
128 |
+
resp = await fetch(randomEndpoint.url, {
|
129 |
+
headers: {
|
130 |
+
"Content-Type": request.headers.get("Content-Type") ?? "application/json",
|
131 |
+
Authorization: randomEndpoint.authorization,
|
132 |
+
},
|
133 |
+
method: "POST",
|
134 |
+
body: JSON.stringify({
|
135 |
+
...json,
|
136 |
+
inputs: prompt,
|
137 |
+
}),
|
138 |
+
signal: abortController.signal,
|
139 |
+
});
|
140 |
+
}
|
141 |
|
142 |
if (!resp.body) {
|
143 |
throw new Error("Response body is empty");
|