File size: 2,170 Bytes
3ebd3a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import { VertexAI, HarmCategory, HarmBlockThreshold } from "@google-cloud/vertexai";
import { buildPrompt } from "$lib/buildPrompt";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { z } from "zod";

/**
 * Zod schema for the configuration object accepted by the Vertex AI endpoint.
 *
 * `weight` and `location` fall back to defaults when omitted, `apiEndpoint`
 * may be left out entirely, and `type` must be the literal `"vertex"`.
 */
export const endpointVertexParametersSchema = z.object({
	// Positive integer; 1 when not provided.
	weight: z.number().int().positive().default(1),
	// Accepted as-is (`z.any()`), so it may be absent; nothing is checked here.
	model: z.any(),
	// Tag identifying this endpoint kind.
	type: z.literal("vertex"),
	// Location string handed to the Vertex AI client.
	location: z.string().default("europe-west1"),
	// Project string handed to the Vertex AI client (required).
	project: z.string(),
	// Optional API endpoint handed to the Vertex AI client.
	apiEndpoint: z.optional(z.string()),
});

/**
 * Creates an `Endpoint` that streams completions from a Vertex AI generative model.
 *
 * The configuration is validated with `endpointVertexParametersSchema`, then a
 * `VertexAI` client and a generative model handle are created once and reused
 * by every call of the returned function.
 *
 * @param input - Endpoint configuration (project, location, model, optional API endpoint).
 * @returns An async function that builds a prompt from the conversation, sends it
 *   through a fresh chat session and resolves to an async generator yielding one
 *   `TextGenerationStreamOutput` per streamed chunk that carries text.
 * @throws Whatever `endpointVertexParametersSchema.parse` throws for an invalid `input`.
 */
export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
	const { project, location, model, apiEndpoint } = endpointVertexParametersSchema.parse(input);

	const vertex_ai = new VertexAI({
		project,
		location,
		apiEndpoint,
	});

	const generativeModel = vertex_ai.getGenerativeModel({
		// Prefer the model's `id`, falling back to its `name`.
		model: model.id ?? model.name,
		safety_settings: [
			{
				category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
				threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
			},
		],
		generation_config: {},
	});

	return async ({ messages, preprompt, continueMessage }) => {
		const prompt = await buildPrompt({
			messages,
			continueMessage,
			preprompt,
			model,
		});

		// Each request gets its own chat session; the whole prompt is sent as one message.
		const chat = generativeModel.startChat();
		const result = await chat.sendMessageStream(prompt);
		let tokenId = 0;

		return (async function* () {
			// Running concatenation of every text chunk yielded so far.
			let generatedText = "";

			for await (const data of result.stream) {
				const candidates = data?.candidates;
				// A chunk without candidates ends the stream.
				if (!Array.isArray(candidates) || candidates.length === 0) break;

				// Guard each level: a candidate without content, without parts, or with an
				// empty parts array must not crash the generator with a TypeError.
				const firstPart = candidates[0]?.content?.parts?.[0];
				if (firstPart && "text" in firstPart) {
					// Normalise once so the accumulated text never contains "undefined".
					const content = firstPart.text ?? "";
					generatedText += content;
					const output: TextGenerationStreamOutput = {
						token: {
							id: tokenId++,
							text: content,
							logprob: 0,
							special: false,
						},
						generated_text: generatedText,
						details: null,
					};
					yield output;
				}

				// NOTE(review): iteration stops as soon as the last candidate has no
				// `finishReason`. This is only correct if the SDK sets `finishReason` on
				// every streamed chunk; if intermediate chunks omit it, the stream is cut
				// after the first chunk. Confirm against the SDK version in use.
				if (!candidates[candidates.length - 1]?.finishReason) break;
			}
		})();
	};
}
export default endpointVertex;