Spaces:
Running
Running
File size: 3,813 Bytes
2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 2ce3b4b 3ebd3a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import {
VertexAI,
HarmCategory,
HarmBlockThreshold,
type Content,
type TextPart,
} from "@google-cloud/vertexai";
import type { Endpoint } from "../endpoints";
import { z } from "zod";
import type { Message } from "$lib/types/Message";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
/**
 * Zod schema for the configuration object accepted by `endpointVertex`.
 *
 * `endpointVertex` parses its input with this schema, so the defaults declared
 * here (`weight`, `location`) are applied before the Vertex AI client is built.
 */
export const endpointVertexParametersSchema = z.object({
// Positive integer, defaults to 1. Not read by `endpointVertex` itself —
// NOTE(review): presumably consumed by whatever selects between endpoints; confirm.
weight: z.number().int().positive().default(1),
// Accepts any value; `endpointVertex` reads `model.id`, falling back to `model.name`.
model: z.any(), // allow optional and validate against emptiness
// Literal tag: only the exact string "vertex" passes validation.
type: z.literal("vertex"),
// Location passed to the VertexAI client; defaults to "europe-west1" when omitted.
location: z.string().default("europe-west1"),
// Required project string passed to the VertexAI client.
project: z.string(),
// Optional API endpoint string forwarded as-is to the VertexAI client.
apiEndpoint: z.string().optional(),
// Optional harm-block threshold. When set, `endpointVertex` applies this single
// value to every harm category it configures; when omitted, no safety settings
// are sent.
safetyThreshold: z
.enum([
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmBlockThreshold.BLOCK_NONE,
HarmBlockThreshold.BLOCK_ONLY_HIGH,
])
.optional(),
});
/**
 * Creates a chat endpoint that streams completions from a Vertex AI generative model.
 *
 * The configuration is validated with `endpointVertexParametersSchema` and a single
 * `VertexAI` client is created up front; a generative model handle is then obtained
 * per request so that per-request generation settings can be applied.
 *
 * @param input - Raw endpoint configuration; parsed (and defaulted) by the schema.
 * @returns An `Endpoint` function. Given `messages`, `preprompt` and `generateSettings`
 *   it resolves to an async generator of `TextGenerationStreamOutput` chunks. The chunk
 *   that carries a `finishReason` is emitted with `token.special === true` and
 *   `generated_text` set to the full concatenated text; all earlier chunks have
 *   `generated_text === null`.
 * @throws ZodError when `input` does not satisfy `endpointVertexParametersSchema`.
 */
export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
	const { project, location, model, apiEndpoint, safetyThreshold } =
		endpointVertexParametersSchema.parse(input);

	const vertex_ai = new VertexAI({
		project,
		location,
		apiEndpoint,
	});

	// The one configured threshold is applied uniformly to each of these categories.
	// Built once per endpoint because it does not depend on the request.
	const safetyCategories = [
		HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
		HarmCategory.HARM_CATEGORY_HARASSMENT,
		HarmCategory.HARM_CATEGORY_HATE_SPEECH,
		HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
		HarmCategory.HARM_CATEGORY_UNSPECIFIED,
	];
	const safetySettings = safetyThreshold
		? safetyCategories.map((category) => ({ category, threshold: safetyThreshold }))
		: undefined;

	return async ({ messages, preprompt, generateSettings }) => {
		const generativeModel = vertex_ai.getGenerativeModel({
			model: model.id ?? model.name,
			safetySettings,
			generationConfig: {
				maxOutputTokens: generateSettings?.max_new_tokens ?? 4096,
				stopSequences: generateSettings?.stop,
				temperature: generateSettings?.temperature ?? 1,
			},
		});

		// Preprompt is the same as the first system message: a leading system message,
		// when present, replaces the preprompt and is not sent as a conversation turn.
		// The caller's array is left untouched (sliced, not shifted), and an empty
		// `messages` array is tolerated.
		const firstMessage = messages[0];
		let systemMessage = preprompt;
		let conversation = messages;
		if (firstMessage?.from === "system") {
			systemMessage = firstMessage.content;
			conversation = messages.slice(1);
		}

		// Every turn that is not from the user is sent with the "model" role.
		const vertexMessages = conversation.map(
			({ from, content }: Omit<Message, "id">): Content => ({
				role: from === "user" ? "user" : "model",
				parts: [{ text: content }],
			})
		);

		const result = await generativeModel.generateContentStream({
			contents: vertexMessages,
			systemInstruction: systemMessage
				? {
						role: "system",
						parts: [{ text: systemMessage }],
				  }
				: undefined,
		});

		return (async function* () {
			let tokenId = 0;
			let generatedText = "";

			for await (const data of result.stream) {
				// A chunk without candidates ends the stream.
				if (!data?.candidates?.length) break;

				// Only the first candidate is consumed.
				const candidate = data.candidates[0];
				const isLastChunk = !!candidate.finishReason;

				// Only the first text part of the chunk is used.
				const textPart = candidate.content?.parts?.find((part) => "text" in part) as
					| TextPart
					| undefined;

				// Text-less chunks are skipped, except the terminating one: it must still
				// be emitted so the consumer receives the final `generated_text`.
				if (!textPart && !isLastChunk) continue;

				const content = textPart?.text ?? "";
				generatedText += content;

				const output: TextGenerationStreamOutput = {
					token: {
						id: tokenId++,
						text: content,
						logprob: 0,
						special: isLastChunk,
					},
					generated_text: isLastChunk ? generatedText : null,
					details: null,
				};
				yield output;

				if (isLastChunk) break;
			}
		})();
	};
}
// Default export of the endpoint factory; it is also exported by name at its definition.
export default endpointVertex;
|