/** BUSINESS
 *
 * All utils that are bound to business logic
 * (and wouldn't be useful in another project)
 * should be here.
 *
 **/
import ctxLengthData from "$lib/data/context_length.json";
import { pricing } from "$lib/state/pricing.svelte.js";
import { snippets } from "@huggingface/inference";
import { ConversationClass, type ConversationEntityMembers } from "$lib/state/conversations.svelte";
import { token } from "$lib/state/token.svelte";
import { isMcpEnabled } from "$lib/constants.js";
import {
  isCustomModel,
  isHFModel,
  Provider,
  type Conversation,
  type ConversationMessage,
  type CustomModel,
  type Model,
} from "$lib/types.js";
import { safeParse } from "$lib/utils/json.js";
import { omit } from "$lib/utils/object.svelte.js";
import type { ChatCompletionInputMessage, InferenceSnippet } from "@huggingface/tasks";
import { type ChatCompletionOutputMessage } from "@huggingface/tasks";
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import { images } from "$lib/state/images.svelte.js";
import { projects } from "$lib/state/projects.svelte.js";
import { mcpServers } from "$lib/state/mcps.svelte.js";
import { modifySnippet } from "$lib/utils/snippets.js";
import { models } from "$lib/state/models.svelte";
import { StreamReader } from "$lib/utils/stream.js";

type ChatCompletionInputMessageChunk =
  NonNullable<ChatCompletionInputMessage["content"]> extends string | (infer U)[] ? U : never;
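// Converts a stored ConversationMessage into a ChatCompletionInputMessage,
// resolving stored image keys into image_url content chunks alongside the text.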
async function parseMessage(message: ConversationMessage): Promise<ChatCompletionInputMessage> {
  if (!message.images) return message;

  const urls = await Promise.all(message.images?.map(k => images.get(k)) ?? []);

  return {
    ...omit(message, "images"),
    content: [
      {
        type: "text",
        text: message.content ?? "",
      },
      ...message.images.map((_imgKey, i) => {
        return {
          type: "image_url",
          image_url: { url: urls[i] as string },
        } satisfies ChatCompletionInputMessageChunk;
      }),
    ],
  };
}
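/**
 * Returns the maximum context length (in tokens) for a conversation's model,
 * checking router/pricing data first, then the local context_length.json data,
 * and finally the hardcoded customMaxTokens fallback (100k by default).
 */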
export function maxAllowedTokens(conversation: ConversationClass) {
  const model = conversation.model;
  const { provider } = conversation.data;
  if (!provider || !isHFModel(model)) {
    return customMaxTokens[conversation.model.id] ?? 100000;
  }

  // Try to get context length from pricing/router data first
  const ctxLength = pricing.getContextLength(model.id, provider);
  if (ctxLength) return ctxLength;

  // Fall back to local context length data if available
  const providerData = ctxLengthData[provider as keyof typeof ctxLengthData] as Record<string, number> | undefined;
  const localCtxLength = providerData?.[model.id];
  if (localCtxLength) return localCtxLength;

  // Final fallback to custom max tokens
  return customMaxTokens[conversation.model.id] ?? 100000;
}
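// Returns the serializable config of every enabled MCP server, or an empty list
// when the MCP feature flag is off.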
function getEnabledMCPs() {
  if (!isMcpEnabled()) return [];
  return mcpServers.enabled.map(server => ({
    id: server.id,
    name: server.name,
    url: server.url,
    protocol: server.protocol,
    headers: server.headers,
  }));
}
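// Builds the provider-specific `response_format` payload for structured output.
// Returns undefined when structured output is disabled, the schema doesn't parse,
// or the model/provider combination doesn't support it.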
function getResponseFormatObj(conversation: ConversationClass | Conversation) {
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
  const json = safeParse(data.structuredOutput?.schema ?? "");

  if (json && data.structuredOutput?.enabled && models.supportsStructuredOutput(conversation.model, data.provider)) {
    switch (data.provider) {
      case "cohere": {
        return {
          type: "json_object",
          ...json,
        };
      }
      case Provider.Cerebras: {
        return {
          type: "json_schema",
          json_schema: { ...json, name: "schema" },
        };
      }
      default: {
        return {
          type: "json_schema",
          json_schema: json,
        };
      }
    }
  }
}
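/**
 * Sends the conversation to /api/generate with streaming enabled and reports the
 * accumulated assistant text through `onChunk` as content arrives. Throws on stream errors.
 */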
export async function handleStreamingResponse(
  conversation: ConversationClass | Conversation,
  onChunk: (content: string) => void,
  abortController: AbortController,
): Promise<void> {
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
  const model = conversation.model;
  const systemMessage = projects.current?.systemMessage;

  const messages: ConversationMessage[] = [
    ...(isSystemPromptSupported(model) && systemMessage?.length ? [{ role: "system", content: systemMessage }] : []),
    ...(data.messages || []),
  ];
  const parsed = await Promise.all(messages.map(parseMessage));

  const requestBody = {
    model: {
      id: model.id,
      isCustom: isCustomModel(model),
      accessToken: isCustomModel(model) ? model.accessToken : undefined,
      endpointUrl: isCustomModel(model) ? model.endpointUrl : undefined,
    },
    messages: parsed,
    config: data.config,
    provider: data.provider,
    streaming: true,
    response_format: getResponseFormatObj(conversation),
    accessToken: token.value,
    enabledMCPs: getEnabledMCPs(),
  };

  const reader = await StreamReader.fromFetch("/api/generate", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify(requestBody),
    signal: abortController.signal,
  });

  let out = "";
  for await (const chunk of reader.read()) {
    if (chunk.type === "chunk" && chunk.content) {
      out += chunk.content;
      onChunk(out);
    } else if (chunk.type === "error") {
      throw new Error(chunk.error || "Stream error");
    }
  }
}
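/**
 * Sends the conversation to /api/generate without streaming and returns the full
 * assistant message plus the completion token count reported by the server.
 */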
export async function handleNonStreamingResponse(
  conversation: ConversationClass | Conversation,
): Promise<{ message: ChatCompletionOutputMessage; completion_tokens: number }> {
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
  const model = conversation.model;
  const systemMessage = projects.current?.systemMessage;

  const messages: ConversationMessage[] = [
    ...(isSystemPromptSupported(model) && systemMessage?.length ? [{ role: "system", content: systemMessage }] : []),
    ...(data.messages || []),
  ];
  const parsed = await Promise.all(messages.map(parseMessage));

  const requestBody = {
    model: {
      id: model.id,
      isCustom: isCustomModel(model),
      accessToken: isCustomModel(model) ? model.accessToken : undefined,
      endpointUrl: isCustomModel(model) ? model.endpointUrl : undefined,
    },
    messages: parsed,
    config: data.config,
    provider: data.provider,
    streaming: false,
    response_format: getResponseFormatObj(conversation),
    accessToken: token.value,
    enabledMCPs: getEnabledMCPs(),
  };

  const response = await fetch("/api/generate", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify(requestBody),
  });

  if (!response.ok) {
    const error = await response.json();
    throw new Error(error.error || "Failed to generate response");
  }

  return await response.json();
}
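// Custom (OpenAI-compatible) endpoints are assumed to accept system messages;
// for HF models, the chat template is inspected for a "system" role.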
export function isSystemPromptSupported(model: Model | CustomModel) {
  if (isCustomModel(model)) return true; // OpenAI-compatible models support system messages
  const template = model?.config.tokenizer_config?.chat_template;
  if (typeof template !== "string") return false;
  return template.includes("system");
}

export const defaultSystemMessage: { [key: string]: string } = {
  "Qwen/QwQ-32B-Preview":
    "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
} as const;
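// Hardcoded max-token fallbacks per model id, used by maxAllowedTokens when neither
// the router/pricing data nor context_length.json has an entry.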
export const customMaxTokens: { [key: string]: number } = {
  "01-ai/Yi-1.5-34B-Chat": 2048,
  "HuggingFaceM4/idefics-9b-instruct": 2048,
  "deepseek-ai/DeepSeek-Coder-V2-Instruct": 16384,
  "bigcode/starcoder": 8192,
  "bigcode/starcoderplus": 8192,
  "HuggingFaceH4/starcoderbase-finetuned-oasst1": 8192,
  "google/gemma-7b": 8192,
  "google/gemma-1.1-7b-it": 8192,
  "google/gemma-2b": 8192,
  "google/gemma-1.1-2b-it": 8192,
  "google/gemma-2-27b-it": 8192,
  "google/gemma-2-9b-it": 4096,
  "google/gemma-2-2b-it": 8192,
  "tiiuae/falcon-7b": 8192,
  "tiiuae/falcon-7b-instruct": 8192,
  "timdettmers/guanaco-33b-merged": 2048,
  "mistralai/Mixtral-8x7B-Instruct-v0.1": 32768,
  "Qwen/Qwen2.5-72B-Instruct": 32768,
  "Qwen/Qwen2.5-Coder-32B-Instruct": 32768,
  "meta-llama/Meta-Llama-3-70B-Instruct": 8192,
  "CohereForAI/c4ai-command-r-plus-08-2024": 32768,
  "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
  "meta-llama/Llama-2-70b-chat-hf": 8192,
  "HuggingFaceH4/zephyr-7b-alpha": 17432,
  "HuggingFaceH4/zephyr-7b-beta": 32768,
  "mistralai/Mistral-7B-Instruct-v0.1": 32768,
  "mistralai/Mistral-7B-Instruct-v0.2": 32768,
  "mistralai/Mistral-7B-Instruct-v0.3": 32768,
  "mistralai/Mistral-Nemo-Instruct-2407": 32768,
  "meta-llama/Meta-Llama-3-8B-Instruct": 8192,
  "mistralai/Mistral-7B-v0.1": 32768,
  "bigcode/starcoder2-3b": 16384,
  "bigcode/starcoder2-15b": 16384,
  "HuggingFaceH4/starchat2-15b-v0.1": 16384,
  "codellama/CodeLlama-7b-hf": 8192,
  "codellama/CodeLlama-13b-hf": 8192,
  "codellama/CodeLlama-34b-Instruct-hf": 8192,
  "meta-llama/Llama-2-7b-chat-hf": 8192,
  "meta-llama/Llama-2-13b-chat-hf": 8192,
  "OpenAssistant/oasst-sft-6-llama-30b": 2048,
  "TheBloke/vicuna-7B-v1.5-GPTQ": 2048,
  "HuggingFaceH4/starchat-beta": 8192,
  "bigcode/octocoder": 8192,
  "vwxyzjn/starcoderbase-triviaqa": 8192,
  "lvwerra/starcoderbase-gsm8k": 8192,
  "NousResearch/Hermes-3-Llama-3.1-8B": 16384,
  "microsoft/Phi-3.5-mini-instruct": 32768,
  "meta-llama/Llama-3.1-70B-Instruct": 32768,
  "meta-llama/Llama-3.1-8B-Instruct": 8192,
} as const;
// Order of the elements in InferenceModal.svelte is determined by this const
export const inferenceSnippetLanguages = ["python", "js", "sh"] as const;

export type InferenceSnippetLanguage = (typeof inferenceSnippetLanguages)[number];
export type GetInferenceSnippetReturn = InferenceSnippet[];
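/**
 * Generates copy-pastable inference snippets for the conversation's model and
 * provider via @huggingface/inference, optionally patching the structured-output
 * response_format into the snippet. Returns an empty list for custom models or
 * when no provider mapping is found.
 */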
export function getInferenceSnippet(
  conversation: ConversationClass,
  language: InferenceSnippetLanguage,
  opts?: {
    accessToken?: string;
    messages?: ConversationEntityMembers["messages"];
    streaming?: ConversationEntityMembers["streaming"];
    max_tokens?: ConversationEntityMembers["config"]["max_tokens"];
    temperature?: ConversationEntityMembers["config"]["temperature"];
    top_p?: ConversationEntityMembers["config"]["top_p"];
    structured_output?: ConversationEntityMembers["structuredOutput"];
    billTo?: string;
  },
): GetInferenceSnippetReturn {
  const model = conversation.model;
  const data = conversation.data;
  const provider = (isCustomModel(model) ? "hf-inference" : data.provider) as Provider;

  // If it's a custom model, we don't generate inference snippets
  if (isCustomModel(model)) {
    return [];
  }

  const providerMapping = model.inferenceProviderMapping.find(p => p.provider === provider);
  if (!providerMapping && provider !== "auto") return [];

  const allSnippets = snippets.getInferenceSnippets(
    { ...model, inference: "" },
    provider,
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    { ...providerMapping, hfModelId: model.id } as any,
    { ...opts, directRequest: false },
  );

  return allSnippets
    .filter(s => s.language === language)
    .map(s => {
      if (
        opts?.structured_output?.schema &&
        opts.structured_output.enabled &&
        models.supportsStructuredOutput(conversation.model, provider)
      ) {
        return {
          ...s,
          content: modifySnippet(s.content, {
            response_format: getResponseFormatObj(conversation),
          }),
        };
      }
      return s;
    });
}
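// Per-model tokenizer cache: `null` marks models whose tokenizer failed to load,
// so the load is not retried and callers fall back to estimateTokens.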
// eslint-disable-next-line svelte/prefer-svelte-reactivity
const tokenizers = new Map<string, PreTrainedTokenizer | null>();

export async function getTokenizer(model: Model) {
  if (tokenizers.has(model.id)) return tokenizers.get(model.id)!;
  try {
    const tokenizer = await AutoTokenizer.from_pretrained(model.id);
    tokenizers.set(model.id, tokenizer);
    return tokenizer;
  } catch {
    tokenizers.set(model.id, null);
    return null;
  }
}
// When you don't have access to a tokenizer, guesstimate
export function estimateTokens(conversation: ConversationClass) {
  if (!conversation.data.messages) return 0;
  const content = conversation.data.messages?.reduce((acc, curr) => {
    return acc + (curr?.content ?? "");
  }, "");
  return content.length / 4; // 1 token ~ 4 characters
}
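/**
 * Counts tokens in a conversation. Custom models and models without a loadable
 * tokenizer use the character-based estimate; otherwise the messages are wrapped
 * in a Llama-3-style chat template and run through the model's tokenizer.
 */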
export async function getTokens(conversation: ConversationClass): Promise<number> {
  const model = conversation.model;
  if (isCustomModel(model)) return estimateTokens(conversation);

  const tokenizer = await getTokenizer(model);
  if (tokenizer === null) return estimateTokens(conversation);

  // This is a simplified version - you might need to adjust based on your exact needs
  let formattedText = "";
  conversation.data.messages?.forEach((message, index) => {
    let content = `<|start_header_id|>${message.role}<|end_header_id|>\n\n${message.content?.trim()}<|eot_id|>`;
    // Add BOS token to the first message
    if (index === 0) {
      content = "<|begin_of_text|>" + content;
    }
    formattedText += content;
  });

  // Encode the text to get tokens
  const encodedInput = tokenizer.encode(formattedText);
  // Return the number of tokens
  return encodedInput.length;
}