import { type ChatCompletionInputMessage } from "@huggingface/tasks";
import type { Conversation, ModelEntryWithTokenizer } from "./types";
import { HfInference } from "@huggingface/inference";
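
/** Creates an HfInference client bound to the given Hugging Face access token. */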
export function createHfInference(token: string): HfInference {
  return new HfInference(token);
}
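
/**
 * Streams a chat completion for the given conversation, invoking `onChunk` with the
 * cumulative response text each time a new delta arrives.
 *
 * Example (illustrative sketch; assumes a `token` string and a `conversation` object
 * built elsewhere in the playground):
 *
 *   const hf = createHfInference(token);
 *   const abortController = new AbortController();
 *   await handleStreamingResponse(hf, conversation, content => {
 *     // `content` is the full text generated so far, not just the latest delta
 *     console.log(content);
 *   }, abortController);
 */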
export async function handleStreamingResponse(
  hf: HfInference,
  conversation: Conversation,
  onChunk: (content: string) => void,
  abortController: AbortController
): Promise<void> {
  const { model, systemMessage } = conversation;
  // Prepend the system message only when the model's chat template supports a system role.
  const messages = [
    ...(isSystemPromptSupported(model) && systemMessage.content?.length ? [systemMessage] : []),
    ...conversation.messages,
  ];
  let out = "";
  for await (const chunk of hf.chatCompletionStream(
    {
      model: model.id,
      messages,
      temperature: conversation.config.temperature,
      max_tokens: conversation.config.maxTokens,
    },
    { signal: abortController.signal, use_cache: false }
  )) {
    if (chunk.choices && chunk.choices.length > 0 && chunk.choices[0]?.delta?.content) {
      // Accumulate deltas so onChunk always receives the full text generated so far.
      out += chunk.choices[0].delta.content;
      onChunk(out);
    }
  }
}
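
/**
 * Requests a complete (non-streamed) chat completion and returns the first choice's
 * message together with the completion token count reported by the API.
 *
 * Example (illustrative sketch; assumes `hf` and `conversation` as above):
 *
 *   const { message, completion_tokens } = await handleNonStreamingResponse(hf, conversation);
 *   console.log(message.content, completion_tokens);
 */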
export async function handleNonStreamingResponse(
  hf: HfInference,
  conversation: Conversation
): Promise<{ message: ChatCompletionInputMessage; completion_tokens: number }> {
  const { model, systemMessage } = conversation;
  // Prepend the system message only when the model's chat template supports a system role.
  const messages = [
    ...(isSystemPromptSupported(model) && systemMessage.content?.length ? [systemMessage] : []),
    ...conversation.messages,
  ];
  const response = await hf.chatCompletion(
    {
      model: model.id,
      messages,
      temperature: conversation.config.temperature,
      max_tokens: conversation.config.maxTokens,
    },
    { use_cache: false }
  );
  if (response.choices && response.choices.length > 0) {
    const { message } = response.choices[0];
    const { completion_tokens } = response.usage;
    return { message, completion_tokens };
  }
  throw new Error("No response from the model");
}
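
/**
 * Heuristic: a model is assumed to support a system prompt when its tokenizer's
 * chat template mentions a "system" role.
 */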
export function isSystemPromptSupported(model: ModelEntryWithTokenizer) {
  return model.tokenizerConfig?.chat_template?.includes("system");
}
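
// Model ids treated as featured defaults (presumably surfaced first in the model picker).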
export const FEATUED_MODELS_IDS = [
  "meta-llama/Meta-Llama-3.1-70B-Instruct",
  "meta-llama/Meta-Llama-3.1-8B-Instruct",
  "google/gemma-2-9b-it",
  "mistralai/Mistral-7B-Instruct-v0.3",
  "mistralai/Mistral-Nemo-Instruct-2407",
];