chat-ui / src /lib /server /endpoints /google /endpointVertex.ts
Luke
Support Gemini 1.5 Pro from Vertex AI (#1041)
2ce3b4b unverified
import {
VertexAI,
HarmCategory,
HarmBlockThreshold,
type Content,
type TextPart,
} from "@google-cloud/vertexai";
import type { Endpoint } from "../endpoints";
import { z } from "zod";
import type { Message } from "$lib/types/Message";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
/**
 * Zod schema describing the configuration accepted by the Vertex AI endpoint.
 *
 * `endpointVertex` parses its input with this schema before creating the
 * Vertex AI client, so the defaults declared here apply to any field that the
 * configuration omits.
 */
export const endpointVertexParametersSchema = z.object({
// Positive integer, defaults to 1. Not read by `endpointVertex` itself;
// presumably consumed by whatever selects between endpoints — TODO confirm.
weight: z.number().int().positive().default(1),
// Deliberately untyped: any value (including undefined) passes validation here.
// `endpointVertex` later reads `model.id`, falling back to `model.name`.
model: z.any(), // allow optional and validate against emptiness
// Discriminator: this schema only matches configurations of type "vertex".
type: z.literal("vertex"),
// Location passed to the Vertex AI client; defaults to "europe-west1".
location: z.string().default("europe-west1"),
// Project passed to the Vertex AI client. Required.
project: z.string(),
// Optional API endpoint override forwarded to the Vertex AI client.
apiEndpoint: z.string().optional(),
// Optional harm-block threshold. When set, `endpointVertex` applies the same
// threshold to every harm category it lists; when omitted, no safety settings
// are passed to the model.
safetyThreshold: z
.enum([
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmBlockThreshold.BLOCK_NONE,
HarmBlockThreshold.BLOCK_ONLY_HIGH,
])
.optional(),
});
/**
 * Creates a chat endpoint backed by a generative model served from Vertex AI.
 *
 * The input is validated with `endpointVertexParametersSchema`, a single
 * `VertexAI` client is created, and the returned async function performs one
 * streaming generation per call.
 *
 * @param input - Endpoint configuration (project, location, model, optional
 *   API endpoint and safety threshold).
 * @returns An endpoint function that takes `{ messages, preprompt,
 *   generateSettings }` and resolves to an async generator of
 *   `TextGenerationStreamOutput` chunks. The chunk that carries the model's
 *   `finishReason` is marked `special` and carries the full accumulated text in
 *   `generated_text`; every earlier chunk has `generated_text: null`.
 * @throws Whatever `endpointVertexParametersSchema.parse` throws when the
 *   configuration is invalid.
 */
export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
	const { project, location, model, apiEndpoint, safetyThreshold } =
		endpointVertexParametersSchema.parse(input);

	const vertex_ai = new VertexAI({
		project,
		location,
		apiEndpoint,
	});

	// The configured threshold is applied uniformly to every harm category.
	// Without a threshold no safety settings are sent at all. The list depends
	// only on the configuration, so it is built once rather than on every call.
	const safetySettings = safetyThreshold
		? [
				HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
				HarmCategory.HARM_CATEGORY_HARASSMENT,
				HarmCategory.HARM_CATEGORY_HATE_SPEECH,
				HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
				HarmCategory.HARM_CATEGORY_UNSPECIFIED,
		  ].map((category) => ({ category, threshold: safetyThreshold }))
		: undefined;

	return async ({ messages, preprompt, generateSettings }) => {
		const generativeModel = vertex_ai.getGenerativeModel({
			model: model.id ?? model.name,
			safetySettings,
			generationConfig: {
				maxOutputTokens: generateSettings?.max_new_tokens ?? 4096,
				stopSequences: generateSettings?.stop,
				temperature: generateSettings?.temperature ?? 1,
			},
		});

		// Preprompt is the same as the first system message: a leading system
		// message overrides the preprompt and is not sent as a chat turn.
		// The caller's array is left untouched (it may be reused, e.g. on retry),
		// and an empty conversation no longer throws on `messages[0]`.
		const hasLeadingSystemMessage = messages[0]?.from === "system";
		const systemMessage = hasLeadingSystemMessage ? messages[0].content : preprompt;
		const chatMessages = hasLeadingSystemMessage ? messages.slice(1) : messages;

		// Anything that is not a user turn is sent with the "model" role.
		const vertexMessages = chatMessages.map(
			({ from, content }: Omit<Message, "id">): Content => ({
				role: from === "user" ? "user" : "model",
				parts: [{ text: content }],
			})
		);

		const result = await generativeModel.generateContentStream({
			contents: vertexMessages,
			// An empty or missing system message means no system instruction at all.
			systemInstruction: systemMessage
				? {
						role: "system",
						parts: [{ text: systemMessage }],
				  }
				: undefined,
		});

		return (async function* () {
			let tokenId = 0;
			let generatedText = "";

			for await (const data of result.stream) {
				if (!data?.candidates?.length) break; // Handle case where no candidates are present

				const candidate = data.candidates[0];
				const isLastChunk = !!candidate.finishReason;

				// Only the first text part of a chunk is used.
				const textPart = candidate.content?.parts?.find((part) => "text" in part) as
					| TextPart
					| undefined;

				// Text-less chunks are skipped, except the finishing one: it must
				// still produce the terminal output, otherwise the consumer never
				// receives `generated_text` (e.g. when the closing chunk is empty).
				if (!textPart && !isLastChunk) continue;

				const content = textPart?.text ?? "";
				generatedText += content;

				const output: TextGenerationStreamOutput = {
					token: {
						id: tokenId++,
						text: content,
						logprob: 0,
						special: isLastChunk,
					},
					generated_text: isLastChunk ? generatedText : null,
					details: null,
				};
				yield output;

				if (isLastChunk) break;
			}
		})();
	};
}
export default endpointVertex;