Improve config: add top_p, top_k, repetition_penalty and their correct ranges (#29)
- src/lib/components/InferencePlayground/InferencePlayground.svelte (+9 −4)
- src/lib/components/InferencePlayground/InferencePlaygroundCodeSnippets.svelte (+3 −3)
- src/lib/components/InferencePlayground/InferencePlaygroundConversation.svelte (+2 −1)
- src/lib/components/InferencePlayground/InferencePlaygroundGenerationConfig.svelte (+32 −49)
- src/lib/components/InferencePlayground/generationConfigSettings.ts (+61 −0)
- src/lib/types/index.d.ts (+2 −8)
src/lib/components/InferencePlayground/InferencePlayground.svelte
CHANGED

@@ -12,6 +12,7 @@
 	import { onDestroy } from 'svelte';
 	import { type ChatCompletionInputMessage } from '@huggingface/tasks';
 	import type { ModelEntryWithTokenizer } from '$lib/types';
+	import { defaultGenerationConfig } from './generationConfigSettings';
 
 	export let models: ModelEntryWithTokenizer[];
 
@@ -21,8 +22,9 @@
 		{
 			id: String(Math.random()),
 			model: models[0],
-			config: …
-			messages: startMessages
+			config: defaultGenerationConfig,
+			messages: startMessages,
+			streaming: true
 		}
 	];
 
@@ -121,7 +123,7 @@
 			...conversation.messages
 		];
 
-		if (conversation.config.streaming) {
+		if (conversation.streaming) {
 			const streamingMessage = { role: 'assistant', content: '' };
 			conversation.messages = [...conversation.messages, streamingMessage];
 			const abortController = new AbortController();
@@ -400,7 +402,10 @@
 				</select>
 			</div>
 
-			<PlaygroundOptions bind:config={conversations[0].config} />
+			<PlaygroundOptions
+				bind:config={conversations[0].config}
+				bind:streaming={conversations[0].streaming}
+			/>
 			<div class="mt-auto">
 				<div class="mb-3 flex items-center justify-between gap-2">
 					<label
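To make the new conversation shape concrete, here is a minimal sketch (not part of the diff; `createConversation` is a hypothetical helper): numeric sampling parameters come from `defaultGenerationConfig`, while the `streaming` flag now sits at the top level of the conversation rather than inside `config`.

```ts
import { defaultGenerationConfig } from './generationConfigSettings';
import type { Conversation, ModelEntryWithTokenizer } from '$lib/types';

// Hypothetical helper mirroring the initialization in the hunk above.
function createConversation(model: ModelEntryWithTokenizer): Conversation {
	return {
		id: String(Math.random()),
		model,
		// Spread copies the defaults so edits to one conversation
		// don't mutate the shared default object.
		config: { ...defaultGenerationConfig },
		messages: [],
		streaming: true
	};
}
```

Note that the diff itself assigns `defaultGenerationConfig` by reference; the spread copy here is a defensive variation, not what the commit does.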
src/lib/components/InferencePlayground/InferencePlaygroundCodeSnippets.svelte
CHANGED

@@ -56,7 +56,7 @@
 # or
 yarn add @huggingface/inference`
 		});
-		if (conversation.config.streaming) {
+		if (conversation.streaming) {
 			snippets.push({
 				label: 'Streaming API',
 				code: `import { HfInference } from "@huggingface/inference"
@@ -111,7 +111,7 @@ console.log(out.choices[0].message);`
 			language: 'bash',
 			code: `pip install huggingface_hub`
 		});
-		if (conversation.config.streaming) {
+		if (conversation.streaming) {
 			snippets.push({
 				label: 'Streaming API',
 				code: `from huggingface_hub import InferenceClient
@@ -154,7 +154,7 @@ print(output.choices[0].message)`
 		const messagesStr = getMessages();
 		const snippets: Snippet[] = [];
 
-		if (conversation.config.streaming) {
+		if (conversation.streaming) {
 			snippets.push({
 				label: 'Streaming API',
 				code: `curl 'https://api-inference.huggingface.co/models/${conversation.model.id}/v1/chat/completions' \
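All three snippet builders (JS, Python, cURL) now gate their streaming variants on `conversation.streaming` instead of `conversation.config.streaming`. As a rough sketch in the spirit of the JS snippet emitted in the streaming branch (assuming a recent `@huggingface/inference`; the token and model id are placeholders), with the newly added numeric keys such as `top_p` passing straight through to the request:

```ts
import { HfInference } from '@huggingface/inference';

const hf = new HfInference('hf_***'); // placeholder token

// Streaming branch: content arrives as incremental deltas.
for await (const chunk of hf.chatCompletionStream({
	model: 'mistralai/Mistral-7B-Instruct-v0.2', // placeholder model id
	messages: [{ role: 'user', content: 'Hello!' }],
	temperature: 0.7,
	top_p: 0.7,
	max_tokens: 500
})) {
	process.stdout.write(chunk.choices[0]?.delta?.content ?? '');
}
```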
src/lib/components/InferencePlayground/InferencePlaygroundConversation.svelte
CHANGED

@@ -34,7 +34,7 @@
 <div
 	class="flex max-h-[calc(100dvh-5.8rem)] flex-col overflow-y-auto overflow-x-hidden @container"
 	class:pointer-events-none={loading}
-	class:animate-pulse={loading && !conversation.config.streaming}
+	class:animate-pulse={loading && !conversation.streaming}
 	bind:this={messageContainer}
 >
 	{#if sideBySide}
@@ -65,6 +65,7 @@
 			>
 				<PlaygroundOptions
 					bind:config={conversation.config}
+					bind:streaming={conversation.streaming}
 					classNames="absolute top-8 right-0 w-56 invisible group-focus:visible hover:visible border border-gray-200/80 bg-white z-10 px-4 py-6 text-sm shadow-sm dark:border-gray-800 dark:bg-gray-800 rounded-xl"
 				/>
 			</button>
src/lib/components/InferencePlayground/InferencePlaygroundGenerationConfig.svelte
CHANGED

@@ -1,62 +1,45 @@
 <script lang="ts">
+	import { GENERATION_CONFIG_KEYS, GENERATION_CONFIG_SETTINGS } from './generationConfigSettings';
+
 	export let config;
+	export let streaming;
 	export let classNames = '';
 </script>
 
-<div class={classNames}>
-	<div>
-		<div class="flex items-center justify-between">
-			<label
-				for="temperature-range"
-				class="mb-2 block text-sm font-medium text-gray-900 dark:text-white">Temperature</label
-			>
-			<input
-				type="number"
-				class="w-20 rounded border bg-transparent px-1 py-0.5 text-right text-sm dark:border-gray-700"
-				bind:value={config.temperature}
-				min="0"
-				max="1"
-				step="0.1"
-			/>
-		</div>
-		<input
-			id="temperature-range"
-			type="range"
-			bind:value={config.temperature}
-			min="0"
-			max="1"
-			step="0.1"
-			class="h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-200 accent-black dark:bg-gray-700 dark:accent-blue-500"
-		/>
-	</div>
-	<div>
-		<div class="flex items-center justify-between">
-			<label
-				for="max-tokens-range"
-				class="mb-2 block text-sm font-medium text-gray-900 dark:text-white">Max tokens</label
-			>
-			<input
-				type="number"
-				class="w-20 rounded border bg-transparent px-1 py-0.5 text-right text-sm dark:border-gray-700"
-				bind:value={config.maxTokens}
-				min="0"
-				max="4096"
-				step="512"
-			/>
-		</div>
-		<input
-			id="max-tokens-range"
-			type="range"
-			bind:value={config.maxTokens}
-			min="0"
-			max="4096"
-			step="512"
-			class="h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-200 accent-black dark:bg-gray-700 dark:accent-blue-500"
-		/>
-	</div>
+<div class="flex flex-col gap-y-5 {classNames}">
+	{#each GENERATION_CONFIG_KEYS as key}
+		{@const settings = GENERATION_CONFIG_SETTINGS[key]}
+		<div>
+			<div class="flex items-center justify-between">
+				<label
+					for="temperature-range"
+					class="mb-2 block text-sm font-medium text-gray-900 dark:text-white"
+					>{settings.label}</label
+				>
+				<input
+					type="number"
+					class="w-16 rounded border bg-transparent px-1 py-0.5 text-right text-sm dark:border-gray-700"
+					min={settings.min}
+					max={settings.max}
+					step={settings.step}
+					bind:value={config[key]}
+				/>
+			</div>
+			<input
+				id="temperature-range"
+				type="range"
+				min={settings.min}
+				max={settings.max}
+				step={settings.step}
+				bind:value={config[key]}
+				class="h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-200 accent-black dark:bg-gray-700 dark:accent-blue-500"
+			/>
+		</div>
+	{/each}
+
 	<div class="mt-2">
 		<label class="flex cursor-pointer items-center justify-between">
-			<input type="checkbox" bind:checked={config.streaming} class="peer sr-only" />
+			<input type="checkbox" bind:checked={streaming} class="peer sr-only" />
 			<span class="text-sm font-medium text-gray-900 dark:text-gray-300">Streaming</span>
 			<div
 				class="peer relative h-5 w-9 rounded-full bg-gray-200 after:absolute after:start-[2px] after:top-[2px] after:h-4 after:w-4 after:rounded-full after:border after:border-gray-300 after:bg-white after:transition-all after:content-[''] peer-checked:bg-black peer-checked:after:translate-x-full peer-checked:after:border-white peer-focus:outline-none dark:border-gray-600 dark:bg-gray-700 dark:peer-checked:bg-blue-600"
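With every slider driven by the same metadata table, related logic only needs writing once. For instance, a hypothetical clamp utility (not part of this commit) could reuse the declared ranges to sanitize values arriving from a URL or a saved session:

```ts
import {
	GENERATION_CONFIG_KEYS,
	GENERATION_CONFIG_SETTINGS,
	type GenerationConfig
} from './generationConfigSettings';

// Hypothetical: force each value into the [min, max] range declared in the table.
export function clampConfig(config: GenerationConfig): GenerationConfig {
	const clamped: GenerationConfig = { ...config };
	for (const key of GENERATION_CONFIG_KEYS) {
		const { min, max } = GENERATION_CONFIG_SETTINGS[key];
		clamped[key] = Math.min(max, Math.max(min, config[key]));
	}
	return clamped;
}
```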
src/lib/components/InferencePlayground/generationConfigSettings.ts
ADDED

@@ -0,0 +1,61 @@
+interface GenerationKeySettings {
+	default: number;
+	step: number;
+	min: number;
+	max: number;
+	label: string;
+}
+
+export const GENERATION_CONFIG_SETTINGS: Record<string, GenerationKeySettings> = {
+	temperature: {
+		default: 0.7,
+		step: 0.01,
+		min: 0,
+		max: 2,
+		label: 'Temperature'
+	},
+	max_tokens: {
+		default: 0.7,
+		step: 512,
+		min: 1,
+		max: 8192, // changed dynamically based on model
+		label: 'Output Length'
+	},
+	top_p: {
+		default: 0.7,
+		step: 0.01,
+		min: 0,
+		max: 1,
+		label: 'Top-P'
+	},
+	top_k: {
+		default: 50,
+		step: 1,
+		min: 1,
+		max: 100,
+		label: 'Top-K'
+	},
+	repetition_penalty: {
+		default: 1,
+		step: 0.01,
+		min: 1,
+		max: 2,
+		label: 'Repetition Penalty'
+	}
+};
+
+export type GenerationConfigKey = keyof typeof GENERATION_CONFIG_SETTINGS;
+
+export const GENERATION_CONFIG_KEYS: GenerationConfigKey[] = Object.keys(
+	GENERATION_CONFIG_SETTINGS
+);
+
+export type GenerationConfig = Record<GenerationConfigKey, number>;
+
+export const defaultGenerationConfig = Object.keys(GENERATION_CONFIG_SETTINGS).reduce(
+	(acc, key) => {
+		acc[key] = GENERATION_CONFIG_SETTINGS[key].default;
+		return acc;
+	},
+	{} as GenerationConfig
+);
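For reference, the `reduce` above is just a compact way to collect the `default` fields; with the table as written it evaluates to the literal below (values copied verbatim from the settings):

```ts
import type { GenerationConfig } from './generationConfigSettings';

// Equivalent of defaultGenerationConfig, expanded by hand:
const defaults: GenerationConfig = {
	temperature: 0.7,
	max_tokens: 0.7, // value of max_tokens.default as declared above
	top_p: 0.7,
	top_k: 50,
	repetition_penalty: 1
};
```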
src/lib/types/index.d.ts
CHANGED

@@ -1,19 +1,13 @@
+import type { GenerationConfig } from '$lib/components/InferencePlayground/generationConfigSettings';
 import type { ModelEntry } from '@huggingface/hub';
 import type { ChatCompletionInputMessage } from '@huggingface/tasks';
 
-type Model = string;
-
-type GenerationConfig = {
-	temperature: number;
-	maxTokens: number;
-	streaming: boolean;
-};
-
 type Conversation = {
 	id: string;
 	model: ModelEntryWithTokenizer;
 	config: GenerationConfig;
 	messages: ChatCompletionInputMessage[];
+	streaming: boolean;
 };
 
 interface TokenizerConfig {