File size: 1,591 Bytes
9db8ced
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1afcb6
 
 
9db8ced
 
 
 
 
 
 
a1afcb6
 
 
9db8ced
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import { buildPrompt } from "$lib/buildPrompt";
import { textGenerationStream } from "@huggingface/inference";
import { z } from "zod";
import type { Endpoint } from "../endpoints";

/**
 * Zod schema describing the configuration accepted by the AWS endpoint.
 *
 * `endpointAws` runs its input through `parse` on this schema, so the defaults
 * declared here (`weight`, `service`) are applied before the values are used.
 */
export const endpointAwsParametersSchema = z.object({
	// Positive integer, defaults to 1. Not read by endpointAws itself —
	// presumably consumed by the caller to weight endpoints; TODO confirm.
	weight: z.number().int().positive().default(1),
	// Not validated here (z.any()); endpointAws forwards it to buildPrompt and
	// spreads `model.parameters` into the generation request.
	model: z.any(),
	// Must be exactly the string "aws".
	type: z.literal("aws"),
	// Must be a well-formed URL; endpointAws passes it as the `model` argument
	// of textGenerationStream.
	url: z.string().url(),
	// Credentials handed to the AwsClient constructor; both must be non-empty.
	accessKey: z.string().min(1),
	secretKey: z.string().min(1),
	// Optional; forwarded to AwsClient as-is (undefined when omitted).
	sessionToken: z.string().optional(),
	// Either "sagemaker" or "lambda"; falls back to "sagemaker" when omitted.
	service: z.union([z.literal("sagemaker"), z.literal("lambda")]).default("sagemaker"),
	// Optional; forwarded to AwsClient as-is (undefined when omitted).
	region: z.string().optional(),
});

/**
 * Creates an endpoint that streams text generation from an AWS-hosted model,
 * routing every HTTP request through an `aws4fetch` `AwsClient` built from the
 * supplied credentials.
 *
 * `aws4fetch` is imported dynamically so the module is only loaded when an AWS
 * endpoint is actually constructed.
 *
 * @param input - Raw endpoint configuration; validated (and defaulted) with
 *   `endpointAwsParametersSchema.parse`.
 * @returns An endpoint function that builds a prompt from the conversation and
 *   returns the stream produced by `textGenerationStream`.
 * @throws Error with message "Failed to import aws4fetch" when the dynamic
 *   import fails; the underlying error is attached as `cause`.
 * @throws When `input` does not satisfy `endpointAwsParametersSchema`
 *   (`parse` throws on validation failure).
 */
export async function endpointAws(
	input: z.input<typeof endpointAwsParametersSchema>
): Promise<Endpoint> {
	let AwsClient;
	try {
		AwsClient = (await import("aws4fetch")).AwsClient;
	} catch (e) {
		// Keep the message callers may already rely on, but retain the original
		// failure so "module not installed" and "module failed to load" can be
		// told apart when debugging.
		throw Object.assign(new Error("Failed to import aws4fetch"), { cause: e });
	}

	const { url, accessKey, secretKey, sessionToken, model, region, service } =
		endpointAwsParametersSchema.parse(input);

	const aws = new AwsClient({
		accessKeyId: accessKey,
		secretAccessKey: secretKey,
		sessionToken,
		service,
		region,
	});

	return async ({ conversation }) => {
		// The web-search payload is taken from the most recent message. Optional
		// access avoids a TypeError when the conversation has no messages yet.
		const lastMessage = conversation.messages[conversation.messages.length - 1];

		const prompt = await buildPrompt({
			messages: conversation.messages,
			webSearch: lastMessage?.webSearch,
			preprompt: conversation.preprompt,
			model,
		});

		return textGenerationStream(
			{
				// return_full_text is set after the spread so it always wins over
				// whatever the model's own parameters specify.
				parameters: { ...model.parameters, return_full_text: false },
				model: url,
				inputs: prompt,
			},
			{
				use_cache: false,
				// Bind so `this` stays the AwsClient when the library invokes fetch.
				fetch: aws.fetch.bind(aws) as typeof fetch,
			}
		);
	};
}

// Also exposed as the module's default export, alongside the named export.
export default endpointAws;