import { z } from "zod";
import { env } from "$env/dynamic/private";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Cohere, CohereClient } from "cohere-ai";
import { buildPrompt } from "$lib/buildPrompt";

export const endpointCohereParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("cohere"),
	apiKey: z.string().default(env.COHERE_API_TOKEN),
	raw: z.boolean().default(false),
});

export async function endpointCohere(
	input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
	const { apiKey, model, raw } = endpointCohereParametersSchema.parse(input);

	let cohere: CohereClient;

	// Import cohere-ai lazily so the dependency is only loaded when this endpoint is used.
	try {
		cohere = new (await import("cohere-ai")).CohereClient({
			token: apiKey,
		});
	} catch (e) {
		throw new Error("Failed to import cohere-ai", { cause: e });
	}

	return async ({ messages, preprompt, generateSettings, continueMessage }) => {
		// A leading system message takes precedence over the configured preprompt.
		let system = preprompt;
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
		}

		// Per-request settings override the model's default parameters.
		const parameters = { ...model.parameters, ...generateSettings };

		return (async function* () {
			let stream;
			let tokenId = 0;

			if (raw) {
				// Raw mode: render the whole conversation into a single prompt string
				// and send it with rawPrompting enabled.
				const prompt = await buildPrompt({
					messages: messages.filter((message) => message.from !== "system"),
					model,
					preprompt: system,
					continueMessage,
				});

				stream = await cohere.chatStream({
					message: prompt,
					rawPrompting: true,
					model: model.id ?? model.name,
					p: parameters?.top_p,
					k: parameters?.top_k,
					maxTokens: parameters?.max_new_tokens,
					temperature: parameters?.temperature,
					stopSequences: parameters?.stop,
					frequencyPenalty: parameters?.frequency_penalty,
				});
			} else {
				// Chat mode: map messages onto Cohere's USER / CHATBOT roles.
				const formattedMessages = messages
					.filter((message) => message.from !== "system")
					.map((message) => ({
						role: message.from === "user" ? "USER" : "CHATBOT",
						message: message.content,
					})) satisfies Cohere.ChatMessage[];

				// The last message is the current turn; everything before it is history.
				stream = await cohere.chatStream({
					model: model.id ?? model.name,
					chatHistory: formattedMessages.slice(0, -1),
					message: formattedMessages[formattedMessages.length - 1].message,
					preamble: system,
					p: parameters?.top_p,
					k: parameters?.top_k,
					maxTokens: parameters?.max_new_tokens,
					temperature: parameters?.temperature,
					stopSequences: parameters?.stop,
					frequencyPenalty: parameters?.frequency_penalty,
				});
			}

			for await (const output of stream) {
				if (output.eventType === "text-generation") {
					// Forward each generated chunk as a regular token.
					yield {
						token: {
							id: tokenId++,
							text: output.text,
							logprob: 0,
							special: false,
						},
						generated_text: null,
						details: null,
					} satisfies TextGenerationStreamOutput;
				} else if (output.eventType === "stream-end") {
					// Surface error finish reasons instead of ending the stream silently.
					if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
						throw new Error(output.finishReason);
					}

					// Final event: an empty special token that carries the full generated text.
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
						},
						generated_text: output.response.text,
						details: null,
					};
				}
			}
		})();
	};
}