// chat-ui: src/lib/server/endpoints/cohere/endpointCohere.ts
// Last change: "Move vars to dynamic, add metrics (#1085)" by nsarrazin (HF staff), commit 98b1c51 (unverified)
import { z } from "zod";
import { env } from "$env/dynamic/private";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Cohere, CohereClient } from "cohere-ai";
import { buildPrompt } from "$lib/buildPrompt";
/**
 * API key used when the endpoint config omits `apiKey`.
 * Read from the private dynamic env once, when this module is evaluated.
 */
const cohereDefaultApiKey = env.COHERE_API_TOKEN;

/**
 * Validation schema for a `"cohere"` endpoint entry.
 *
 * - `weight`: positive integer; defaults to 1.
 * - `model`: model configuration object, accepted without validation (`z.any()`).
 * - `type`: must be the literal `"cohere"`.
 * - `apiKey`: Cohere token; defaults to `COHERE_API_TOKEN` from the private env.
 * - `raw`: when true, `endpointCohere` renders the conversation into a single raw prompt;
 *   defaults to false.
 */
export const endpointCohereParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("cohere"),
	apiKey: z.string().default(cohereDefaultApiKey),
	raw: z.boolean().default(false),
});
/**
 * Creates an {@link Endpoint} that streams completions from Cohere's chat API.
 *
 * The request mode is selected by the `raw` option:
 * - raw: the non-system messages are rendered into one prompt string with `buildPrompt`
 *   and sent with `rawPrompting: true`;
 * - chat (default): non-system messages become Cohere chat history, the last one is sent
 *   as the current `message`, and the system prompt is sent as `preamble`.
 *
 * A leading `"system"` message, if present, replaces the configured `preprompt`.
 *
 * @param input - Endpoint configuration, validated with `endpointCohereParametersSchema`.
 * @returns An endpoint function. It returns an async generator that yields one
 *   `TextGenerationStreamOutput` per `text-generation` event. It then yields a final special
 *   token whose `generated_text` is the full response text.
 * @throws Error if `cohere-ai` cannot be imported or the client cannot be constructed.
 *   The generator itself throws when Cohere ends the stream with an error finish reason.
 *   In chat mode, it also throws when there are no non-system messages to send.
 */
export async function endpointCohere(
	input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
	const { apiKey, model, raw } = endpointCohereParametersSchema.parse(input);

	let cohere: CohereClient;
	try {
		cohere = new (await import("cohere-ai")).CohereClient({
			token: apiKey,
		});
	} catch (e) {
		throw new Error("Failed to import cohere-ai", { cause: e });
	}

	return async ({ messages, preprompt, generateSettings, continueMessage }) => {
		// A leading system message takes precedence over the configured preprompt.
		let system = preprompt;
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
		}

		// Per-request settings override the model's configured defaults.
		const parameters = { ...model.parameters, ...generateSettings };
		const modelId = model.id ?? model.name;
		// Sampling options are identical for both request modes, so map them once.
		const samplingOptions = {
			p: parameters.top_p,
			k: parameters.top_k,
			maxTokens: parameters.max_new_tokens,
			temperature: parameters.temperature,
			stopSequences: parameters.stop,
			frequencyPenalty: parameters.frequency_penalty,
		};

		return (async function* () {
			let stream;
			let tokenId = 0;
			// System content is carried by `system`; it is never sent as a conversation turn.
			const conversation = messages.filter((message) => message.from !== "system");

			if (raw) {
				const prompt = await buildPrompt({
					messages: conversation,
					model,
					preprompt: system,
					continueMessage,
				});
				stream = await cohere.chatStream({
					message: prompt,
					rawPrompting: true,
					model: modelId,
					...samplingOptions,
				});
			} else {
				// NOTE: `continueMessage` is not forwarded in chat mode; only the raw path uses it.
				const formattedMessages = conversation.map((message) => ({
					role: message.from === "user" ? "USER" : "CHATBOT",
					message: message.content,
				})) satisfies Cohere.ChatMessage[];

				const lastMessage = formattedMessages[formattedMessages.length - 1];
				if (!lastMessage) {
					// Without this guard, an empty conversation crashed with an opaque TypeError.
					throw new Error("Cohere chat requires at least one non-system message");
				}

				stream = await cohere.chatStream({
					model: modelId,
					chatHistory: formattedMessages.slice(0, -1),
					message: lastMessage.message,
					preamble: system,
					...samplingOptions,
				});
			}

			for await (const output of stream) {
				if (output.eventType === "text-generation") {
					yield {
						token: {
							id: tokenId++,
							text: output.text,
							logprob: 0,
							special: false,
						},
						generated_text: null,
						details: null,
					} satisfies TextGenerationStreamOutput;
				} else if (output.eventType === "stream-end") {
					// Error finish reasons are surfaced to the caller as exceptions.
					if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
						throw new Error(output.finishReason);
					}
					// Final special token carrying the complete response text.
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
						},
						generated_text: output.response.text,
						details: null,
					} satisfies TextGenerationStreamOutput;
				}
			}
		})();
	};
}