File size: 1,456 Bytes
3cbea34
9db8ced
 
 
 
 
 
 
 
 
 
3cbea34
a1afcb6
9db8ced
 
a1afcb6
 
77399ca
d4016bc
e6addfc
 
 
9db8ced
e6addfc
9db8ced
 
dd4c436
 
d4016bc
dd4c436
 
 
 
a1afcb6
 
 
 
3cbea34
a1afcb6
 
 
 
 
 
 
 
dd4c436
9db8ced
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import { HF_ACCESS_TOKEN, HF_TOKEN } from "$env/static/private";
import { buildPrompt } from "$lib/buildPrompt";
import { textGenerationStream } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { z } from "zod";

/**
 * Zod schema describing the configuration accepted by {@link endpointTgi}.
 *
 * - `weight`: positive integer, defaults to `1`
 *   (NOTE(review): presumably used for weighted endpoint selection — confirm with callers).
 * - `model`: unvalidated (`z.any()`); the endpoint reads `model.parameters` from it
 *   and forwards the whole value to `buildPrompt`.
 * - `type`: discriminator, must be the literal `"tgi"`.
 * - `url`: must be a syntactically valid URL.
 * - `accessToken`: defaults to the token found in the private environment.
 * - `authorization`: optional raw `Authorization` header value, used by the endpoint
 *   only when no access token is available.
 */
export const endpointTgiParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("tgi"),
	url: z.string().url(),
	// Deliberately `||` rather than `??`: an unset token is represented by an empty
	// string, which `??` would treat as a real value and therefore never fall back to
	// the legacy HF_ACCESS_TOKEN. The trailing "" keeps the default a string.
	accessToken: z.string().default(HF_TOKEN || HF_ACCESS_TOKEN || ""),
	authorization: z.string().optional(),
});

/**
 * Creates an {@link Endpoint} that generates text by calling `textGenerationStream`
 * against the configured `url`.
 *
 * @param input - Raw endpoint configuration; validated with
 *   `endpointTgiParametersSchema.parse`, which throws when validation fails.
 * @returns An async function that builds a prompt from the conversation
 *   (`messages`, `preprompt`, `continueMessage`) and returns the result of
 *   `textGenerationStream`. Generation parameters are `model.parameters`
 *   overridden by `generateSettings`, with `return_full_text` always forced to `false`.
 */
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
	const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);

	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
		const prompt = await buildPrompt({
			messages,
			preprompt,
			model,
			continueMessage,
		});

		return textGenerationStream(
			{
				parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
				model: url,
				inputs: prompt,
				accessToken,
			},
			{
				use_cache: false,
				fetch: async (endpointUrl, info) => {
					// The custom Authorization value is only a fallback: skip it when it
					// is not configured or when an access token is available.
					if (!authorization || accessToken) {
						return fetch(endpointUrl, info);
					}

					// Copy through `Headers` so existing headers survive whatever
					// HeadersInit form they arrive in (plain object, tuple array or
					// Headers instance), and so a missing `info` is handled too.
					const headers = new Headers(info?.headers);
					headers.set("Authorization", authorization);

					// Pass a fresh init object instead of mutating the caller's one.
					return fetch(endpointUrl, { ...info, headers });
				},
			}
		);
	};
}

// `endpointTgi` is also exposed as this module's default export.
export default endpointTgi;