mishig HF staff commited on
Commit
573aa88
·
unverified ·
1 Parent(s): fd28154

Improve inference functions (#28)

Browse files
src/lib/components/InferencePlayground/InferencePlayground.svelte CHANGED
@@ -129,10 +129,7 @@
129
 
130
  await handleStreamingResponse(
131
  hf,
132
- conversation.model,
133
- requestMessages,
134
- conversation.config.temperature,
135
- conversation.config.maxTokens,
136
  (content) => {
137
  if (streamingMessage) {
138
  streamingMessage.content = content;
@@ -140,17 +137,12 @@
140
  conversations = conversations;
141
  }
142
  },
143
- abortController
 
144
  );
145
  } else {
146
  waitForNonStreaming = true;
147
- const newMessage = await handleNonStreamingResponse(
148
- hf,
149
- conversation.model,
150
- requestMessages,
151
- conversation.config.temperature,
152
- conversation.config.maxTokens
153
- );
154
  // check if the user did not abort the request
155
  if (waitForNonStreaming) {
156
  conversation.messages = [...conversation.messages, newMessage];
 
129
 
130
  await handleStreamingResponse(
131
  hf,
132
+ conversation,
 
 
 
133
  (content) => {
134
  if (streamingMessage) {
135
  streamingMessage.content = content;
 
137
  conversations = conversations;
138
  }
139
  },
140
+ abortController,
141
+ systemMessage
142
  );
143
  } else {
144
  waitForNonStreaming = true;
145
+ const newMessage = await handleNonStreamingResponse(hf, conversation, systemMessage);
 
 
 
 
 
 
146
  // check if the user did not abort the request
147
  if (waitForNonStreaming) {
148
  conversation.messages = [...conversation.messages, newMessage];
src/lib/components/InferencePlayground/inferencePlaygroundUtils.ts CHANGED
@@ -1,6 +1,6 @@
1
  import { type ChatCompletionInputMessage } from '@huggingface/tasks';
2
  import { HfInference } from '@huggingface/inference';
3
- import type { ModelEntryWithTokenizer } from '$lib/types';
4
 
5
  export function createHfInference(token: string): HfInference {
6
  return new HfInference(token);
@@ -8,21 +8,25 @@ export function createHfInference(token: string): HfInference {
8
 
9
  export async function handleStreamingResponse(
10
  hf: HfInference,
11
- model: string,
12
- messages: ChatCompletionInputMessage[],
13
- temperature: number,
14
- maxTokens: number,
15
  onChunk: (content: string) => void,
16
- abortController: AbortController
 
17
  ): Promise<void> {
 
 
 
 
 
 
18
  let out = '';
19
  try {
20
  for await (const chunk of hf.chatCompletionStream(
21
  {
22
- model: model,
23
- messages: messages,
24
- temperature: temperature,
25
- max_tokens: maxTokens
26
  },
27
  { signal: abortController.signal }
28
  )) {
@@ -42,16 +46,21 @@ export async function handleStreamingResponse(
42
 
43
  export async function handleNonStreamingResponse(
44
  hf: HfInference,
45
- model: string,
46
- messages: ChatCompletionInputMessage[],
47
- temperature: number,
48
- maxTokens: number
49
  ): Promise<ChatCompletionInputMessage> {
 
 
 
 
 
 
 
50
  const response = await hf.chatCompletion({
51
- model: model,
52
- messages: messages,
53
- temperature: temperature,
54
- max_tokens: maxTokens
55
  });
56
 
57
  if (response.choices && response.choices.length > 0) {
 
1
  import { type ChatCompletionInputMessage } from '@huggingface/tasks';
2
  import { HfInference } from '@huggingface/inference';
3
+ import type { Conversation, ModelEntryWithTokenizer } from '$lib/types';
4
 
5
  export function createHfInference(token: string): HfInference {
6
  return new HfInference(token);
 
8
 
9
  export async function handleStreamingResponse(
10
  hf: HfInference,
11
+ conversation: Conversation,
 
 
 
12
  onChunk: (content: string) => void,
13
+ abortController: AbortController,
14
+ systemMessage?: ChatCompletionInputMessage
15
  ): Promise<void> {
16
+ const messages = [
17
+ ...(isSystemPromptSupported(conversation.model) && systemMessage?.content?.length
18
+ ? [systemMessage]
19
+ : []),
20
+ ...conversation.messages
21
+ ];
22
  let out = '';
23
  try {
24
  for await (const chunk of hf.chatCompletionStream(
25
  {
26
+ model: conversation.model.id,
27
+ messages,
28
+ temperature: conversation.config.temperature,
29
+ max_tokens: conversation.config.maxTokens
30
  },
31
  { signal: abortController.signal }
32
  )) {
 
46
 
47
  export async function handleNonStreamingResponse(
48
  hf: HfInference,
49
+ conversation: Conversation,
50
+ systemMessage?: ChatCompletionInputMessage
 
 
51
  ): Promise<ChatCompletionInputMessage> {
52
+ const messages = [
53
+ ...(isSystemPromptSupported(conversation.model) && systemMessage?.content?.length
54
+ ? [systemMessage]
55
+ : []),
56
+ ...conversation.messages
57
+ ];
58
+
59
  const response = await hf.chatCompletion({
60
+ model: conversation.model,
61
+ messages,
62
+ temperature: conversation.config.temperature,
63
+ max_tokens: conversation.config.maxTokens
64
  });
65
 
66
  if (response.choices && response.choices.length > 0) {