Galén, nsarrazin (HF staff) committed
Commit a1afcb6
1 Parent(s): 2da78f5

Add support for passing an API key or any other custom token in the authorization header (#579)


* Add support for passing an API key or any other custom token in the authorization header

* Make linter happy

* Fix README as per linter suggestions

* Refactor endpoints to actually parse zod config

* Remove top level env var and simplify header addition

* Skip section on API key or other, remove obsolete comment in endpointTgi.ts and remove CUSTOM_AUTHORIZATION_TOKEN from .env

---------

Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>

README.md CHANGED
@@ -397,6 +397,8 @@ You can then add the generated information and the `authorization` parameter to
   ]
   ```
 
+Please note that if `HF_ACCESS_TOKEN` is also set or not empty, it will take precedence.
+
 #### Models hosted on multiple custom endpoints
 
 If the model being hosted will be available on multiple servers/instances add the `weight` parameter to your `.env.local`. The `weight` will be used to determine the probability of requesting a particular endpoint.
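The `authorization` value mentioned here is configured per endpoint inside the `MODELS` entry of `.env.local`. A minimal sketch of such an entry, assuming a single TGI endpoint and following the MODELS format shown earlier in the README (the model name, URL, and token below are placeholders, not values from this commit):

```env
MODELS=`[
  {
    "name": "my-model",
    "endpoints": [
      {
        "type": "tgi",
        "url": "https://my-inference-server.example/",
        "authorization": "Bearer my-api-key"
      }
    ]
  }
]`
```

Because `HF_ACCESS_TOKEN` takes precedence when it is set and non-empty, it has to be left empty for the custom `authorization` value to be sent.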
src/lib/server/endpoints/aws/endpointAws.ts CHANGED
@@ -15,15 +15,9 @@ export const endpointAwsParametersSchema = z.object({
   region: z.string().optional(),
 });
 
-export async function endpointAws({
-  url,
-  accessKey,
-  secretKey,
-  sessionToken,
-  model,
-  region,
-  service,
-}: z.infer<typeof endpointAwsParametersSchema>): Promise<Endpoint> {
+export async function endpointAws(
+  input: z.input<typeof endpointAwsParametersSchema>
+): Promise<Endpoint> {
   let AwsClient;
   try {
     AwsClient = (await import("aws4fetch")).AwsClient;
@@ -31,6 +25,9 @@ export async function endpointAws({
     throw new Error("Failed to import aws4fetch");
   }
 
+  const { url, accessKey, secretKey, sessionToken, model, region, service } =
+    endpointAwsParametersSchema.parse(input);
+
   const aws = new AwsClient({
     accessKeyId: accessKey,
     secretAccessKey: secretKey,
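The same refactor pattern repeats in each endpoint below: the factory now accepts the schema's `z.input<...>` type and calls `.parse()` itself, so zod validation and defaults apply even when the caller passes a partially specified config. A minimal standalone sketch of the pattern (the schema and names here are illustrative, not chat-ui code):

```ts
import { z } from "zod";

// Illustrative schema: `region` has a default, so it is optional in the
// input type but guaranteed to be present after parsing.
const exampleParametersSchema = z.object({
  url: z.string().url(),
  region: z.string().default("us-east-1"),
});

function exampleEndpoint(input: z.input<typeof exampleParametersSchema>) {
  // parse() validates the config and fills in defaults; it throws a
  // ZodError if, for example, `url` is missing or not a valid URL.
  const { url, region } = exampleParametersSchema.parse(input);
  return { url, region };
}

// The caller may omit `region`; parsing supplies the default.
console.log(exampleEndpoint({ url: "https://example.com" }));
// -> { url: "https://example.com", region: "us-east-1" }
```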
src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts CHANGED
@@ -12,10 +12,10 @@ export const endpointLlamacppParametersSchema = z.object({
   accessToken: z.string().min(1).default(HF_ACCESS_TOKEN),
 });
 
-export function endpointLlamacpp({
-  url,
-  model,
-}: z.infer<typeof endpointLlamacppParametersSchema>): Endpoint {
+export function endpointLlamacpp(
+  input: z.input<typeof endpointLlamacppParametersSchema>
+): Endpoint {
+  const { url, model } = endpointLlamacppParametersSchema.parse(input);
   return async ({ conversation }) => {
     const prompt = await buildPrompt({
       messages: conversation.messages,
src/lib/server/endpoints/ollama/endpointOllama.ts CHANGED
@@ -11,11 +11,9 @@ export const endpointOllamaParametersSchema = z.object({
   ollamaName: z.string().min(1).optional(),
 });
 
-export function endpointOllama({
-  url,
-  model,
-  ollamaName,
-}: z.infer<typeof endpointOllamaParametersSchema>): Endpoint {
+export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
+  const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
+
   return async ({ conversation }) => {
     const prompt = await buildPrompt({
       messages: conversation.messages,
src/lib/server/endpoints/openai/endpointOai.ts CHANGED
@@ -16,12 +16,10 @@ export const endpointOAIParametersSchema = z.object({
     .default("chat_completions"),
 });
 
-export async function endpointOai({
-  baseURL,
-  apiKey,
-  completion,
-  model,
-}: z.infer<typeof endpointOAIParametersSchema>): Promise<Endpoint> {
+export async function endpointOai(
+  input: z.input<typeof endpointOAIParametersSchema>
+): Promise<Endpoint> {
+  const { baseURL, apiKey, completion, model } = endpointOAIParametersSchema.parse(input);
   let OpenAI;
   try {
     OpenAI = (await import("openai")).OpenAI;
src/lib/server/endpoints/tgi/endpointTgi.ts CHANGED
@@ -10,13 +10,11 @@ export const endpointTgiParametersSchema = z.object({
   type: z.literal("tgi"),
   url: z.string().url(),
   accessToken: z.string().default(HF_ACCESS_TOKEN),
+  authorization: z.string().optional(),
 });
 
-export function endpointTgi({
-  url,
-  accessToken,
-  model,
-}: z.infer<typeof endpointTgiParametersSchema>): Endpoint {
+export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
+  const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
   return async ({ conversation }) => {
     const prompt = await buildPrompt({
       messages: conversation.messages,
@@ -33,7 +31,19 @@ export function endpointTgi({
         inputs: prompt,
         accessToken,
       },
-      { use_cache: false }
+      {
+        use_cache: false,
+        fetch: async (endpointUrl, info) => {
+          if (info && authorization && !accessToken) {
+            // Set authorization header if it is defined and HF_ACCESS_TOKEN is empty
+            info.headers = {
+              ...info.headers,
+              Authorization: authorization,
+            };
+          }
+          return fetch(endpointUrl, info);
+        },
+      }
     );
   };
 }
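The header injection added above is easier to read in isolation. A standalone sketch of the same logic, with an assumed wrapper name purely for illustration:

```ts
// Sketch: wrap the global fetch so that a custom `authorization` value is
// forwarded verbatim ("Bearer ...", "Basic ...", an API key, etc.), but only
// when no HF access token is configured, matching the precedence rule noted
// in the README change above.
function makeAuthorizedFetch(authorization?: string, accessToken?: string) {
  return async (endpointUrl: RequestInfo | URL, info?: RequestInit) => {
    if (info && authorization && !accessToken) {
      // Merge the Authorization header into whatever headers were already set.
      info.headers = {
        ...info.headers,
        Authorization: authorization,
      };
    }
    return fetch(endpointUrl, info);
  };
}

// Usage: pass the wrapper wherever a custom fetch implementation is accepted.
const authorizedFetch = makeAuthorizedFetch("Bearer my-api-key", "");
```

Because the condition requires `!accessToken`, setting `HF_ACCESS_TOKEN` disables the custom header, which is exactly the precedence behavior documented in the README hunk.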