import { buildPrompt } from "$lib/buildPrompt";
import { textGenerationStream } from "@huggingface/inference";
import { z } from "zod";
import type { Endpoint } from "../endpoints";

/**
 * Configuration accepted by the AWS endpoint.
 *
 * `weight` defaults to 1 and `service` defaults to "sagemaker"; `sessionToken`
 * and `region` are optional. `model` is not validated by this schema.
 */
export const endpointAwsParametersSchema = z.object({
  weight: z.number().int().positive().default(1),
  model: z.any(),
  type: z.literal("aws"),
  url: z.string().url(),
  accessKey: z.string().min(1),
  secretKey: z.string().min(1),
  sessionToken: z.string().optional(),
  service: z.union([z.literal("sagemaker"), z.literal("lambda")]).default("sagemaker"),
  region: z.string().optional(),
});

/**
 * Builds an endpoint that streams text generation through an `aws4fetch`
 * client created from the supplied credentials.
 *
 * @param input - Raw endpoint configuration; it is validated (and defaults are
 *   applied) with {@link endpointAwsParametersSchema}.
 * @returns A function that builds the prompt from the conversation and returns
 *   the stream produced by `textGenerationStream`.
 * @throws Error if `aws4fetch` cannot be imported, or a zod validation error
 *   if `input` does not match the schema.
 */
export async function endpointAws(
  input: z.input<typeof endpointAwsParametersSchema>
): Promise<Endpoint> {
  // aws4fetch is imported dynamically when this function is called; any import
  // failure is reported as a single descriptive error.
  let AwsClient;
  try {
    AwsClient = (await import("aws4fetch")).AwsClient;
  } catch {
    throw new Error("Failed to import aws4fetch");
  }

  const { url, accessKey, secretKey, sessionToken, model, region, service } =
    endpointAwsParametersSchema.parse(input);

  const aws = new AwsClient({
    accessKeyId: accessKey,
    secretAccessKey: secretKey,
    sessionToken,
    service,
    region,
  });

  return async ({ messages, preprompt, continueMessage, generateSettings }) => {
    const prompt = await buildPrompt({
      messages,
      continueMessage,
      preprompt,
      model,
    });

    return textGenerationStream(
      {
        // Spread order matters: per-request settings override the model's
        // defaults, and return_full_text is always forced to false.
        parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
        model: url,
        inputs: prompt,
      },
      {
        use_cache: false,
        // Requests go through the AwsClient's fetch instead of the global one;
        // bind keeps `this` pointing at the client instance.
        fetch: aws.fetch.bind(aws) as typeof fetch,
      }
    );
  };
}

export default endpointAws;