// chat-ui / src/lib/server/endpoints/google/endpointVertex.ts
// Last commit by madppiper: "Google Vertex API support (#950)" (3ebd3a8, unverified)
// Hugging Face file page metadata: raw · history · blame · contribute · delete · "No virus" · 2.17 kB
import { VertexAI, HarmCategory, HarmBlockThreshold } from "@google-cloud/vertexai";
import { buildPrompt } from "$lib/buildPrompt";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { z } from "zod";
/** Region passed to the Vertex AI client when the config does not specify `location`. */
const DEFAULT_VERTEX_LOCATION = "europe-west1";

/**
 * Zod schema for an endpoint entry of `type: "vertex"`.
 *
 * - `weight`: positive integer, defaults to 1 (presumably a selection weight
 *   among several endpoints — confirm with the endpoint registry).
 * - `model`: accepted as any value, including `undefined`; this schema does not
 *   check it for emptiness, so consumers must validate it themselves.
 * - `location`: Vertex AI location, defaults to `DEFAULT_VERTEX_LOCATION`.
 * - `project`: Google Cloud project id (required).
 * - `apiEndpoint`: optional override of the Vertex AI API host.
 */
export const endpointVertexParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("vertex"),
	location: z.string().default(DEFAULT_VERTEX_LOCATION),
	project: z.string(),
	apiEndpoint: z.optional(z.string()),
});
/**
 * Creates a chat-ui endpoint backed by a Google Vertex AI generative model.
 *
 * The returned endpoint builds a prompt from the conversation, sends it as a
 * single message on a fresh chat session, and re-emits the streamed Vertex
 * response as `TextGenerationStreamOutput` chunks. Intermediate chunks carry
 * `generated_text: null`; the chunk whose candidate reports a `finishReason`
 * carries the full accumulated text and ends the stream.
 *
 * @param input - Raw endpoint configuration, validated with
 *   `endpointVertexParametersSchema`.
 * @returns An endpoint function that yields one output per streamed chunk.
 * @throws Error if no `model` is supplied (the schema accepts any value for it).
 */
export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
	const { project, location, model, apiEndpoint } = endpointVertexParametersSchema.parse(input);

	// The schema types `model` as `z.any()`, so emptiness must be checked here;
	// otherwise `model.id` below would fail with an opaque TypeError.
	if (!model) {
		throw new Error("Vertex endpoint requires a model");
	}

	const vertexAI = new VertexAI({
		project,
		location,
		apiEndpoint,
	});

	const generativeModel = vertexAI.getGenerativeModel({
		model: model.id ?? model.name,
		safety_settings: [
			{
				category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
				threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
			},
		],
		generation_config: {},
	});

	return async ({ messages, preprompt, continueMessage }) => {
		const prompt = await buildPrompt({
			messages,
			continueMessage,
			preprompt,
			model,
		});

		const chat = generativeModel.startChat();
		const result = await chat.sendMessageStream(prompt);

		let tokenId = 0;
		return (async function* () {
			let generatedText = "";

			for await (const data of result.stream) {
				// Content and finish state are both read from the first candidate.
				const candidate = data?.candidates?.[0];
				if (!candidate) break; // no candidate: nothing more can be streamed

				// `content`/`parts` may be absent (e.g. a chunk that only carries a
				// finish reason), so fall back to no text instead of throwing.
				// All text parts of the chunk are kept, not just the first one.
				let content = "";
				for (const part of candidate.content?.parts ?? []) {
					if ("text" in part) content += part.text ?? "";
				}

				// A finish reason marks the candidate as complete; stop after this chunk.
				const isLastChunk = Boolean(candidate.finishReason);

				// Intermediate chunks without any text have nothing to report.
				if (!content && !isLastChunk) continue;

				generatedText += content;
				const output: TextGenerationStreamOutput = {
					token: {
						id: tokenId++,
						text: content,
						logprob: 0,
						special: false,
					},
					// `generated_text` is only populated once generation has finished;
					// intermediate chunks report `null`.
					generated_text: isLastChunk ? generatedText : null,
					details: null,
				};
				yield output;

				if (isLastChunk) break;
			}
		})();
	};
}

export default endpointVertex;