mishig HF staff committed on
Commit
d5e14b5
1 Parent(s): 5f94ff7

Improve config: add top_p, top_k, repetition_penalty and their correct ranges (#29)

Browse files
src/lib/components/InferencePlayground/InferencePlayground.svelte CHANGED
@@ -12,6 +12,7 @@
12
  import { onDestroy } from 'svelte';
13
  import { type ChatCompletionInputMessage } from '@huggingface/tasks';
14
  import type { ModelEntryWithTokenizer } from '$lib/types';
 
15
 
16
  export let models: ModelEntryWithTokenizer[];
17
 
@@ -21,8 +22,9 @@
21
  {
22
  id: String(Math.random()),
23
  model: models[0],
24
- config: { temperature: 0.5, maxTokens: 2048, streaming: true },
25
- messages: startMessages
 
26
  }
27
  ];
28
 
@@ -121,7 +123,7 @@
121
  ...conversation.messages
122
  ];
123
 
124
- if (conversation.config.streaming) {
125
  const streamingMessage = { role: 'assistant', content: '' };
126
  conversation.messages = [...conversation.messages, streamingMessage];
127
  const abortController = new AbortController();
@@ -400,7 +402,10 @@
400
  </select>
401
  </div>
402
 
403
- <PlaygroundOptions bind:config={conversations[0].config} />
 
 
 
404
  <div class="mt-auto">
405
  <div class="mb-3 flex items-center justify-between gap-2">
406
  <label
 
12
  import { onDestroy } from 'svelte';
13
  import { type ChatCompletionInputMessage } from '@huggingface/tasks';
14
  import type { ModelEntryWithTokenizer } from '$lib/types';
15
+ import { defaultGenerationConfig } from './generationConfigSettings';
16
 
17
  export let models: ModelEntryWithTokenizer[];
18
 
 
22
  {
23
  id: String(Math.random()),
24
  model: models[0],
25
+ config: defaultGenerationConfig,
26
+ messages: startMessages,
27
+ streaming: true
28
  }
29
  ];
30
 
 
123
  ...conversation.messages
124
  ];
125
 
126
+ if (conversation.streaming) {
127
  const streamingMessage = { role: 'assistant', content: '' };
128
  conversation.messages = [...conversation.messages, streamingMessage];
129
  const abortController = new AbortController();
 
402
  </select>
403
  </div>
404
 
405
+ <PlaygroundOptions
406
+ bind:config={conversations[0].config}
407
+ bind:streaming={conversations[0].streaming}
408
+ />
409
  <div class="mt-auto">
410
  <div class="mb-3 flex items-center justify-between gap-2">
411
  <label
src/lib/components/InferencePlayground/InferencePlaygroundCodeSnippets.svelte CHANGED
@@ -56,7 +56,7 @@
56
  # or
57
  yarn add @huggingface/inference`
58
  });
59
- if (conversation.config.streaming) {
60
  snippets.push({
61
  label: 'Streaming API',
62
  code: `import { HfInference } from "@huggingface/inference"
@@ -111,7 +111,7 @@ console.log(out.choices[0].message);`
111
  language: 'bash',
112
  code: `pip install huggingface_hub`
113
  });
114
- if (conversation.config.streaming) {
115
  snippets.push({
116
  label: 'Streaming API',
117
  code: `from huggingface_hub import InferenceClient
@@ -154,7 +154,7 @@ print(output.choices[0].message)`
154
  const messagesStr = getMessages();
155
  const snippets: Snippet[] = [];
156
 
157
- if (conversation.config.streaming) {
158
  snippets.push({
159
  label: 'Streaming API',
160
  code: `curl 'https://api-inference.huggingface.co/models/${conversation.model.id}/v1/chat/completions' \
 
56
  # or
57
  yarn add @huggingface/inference`
58
  });
59
+ if (conversation.streaming) {
60
  snippets.push({
61
  label: 'Streaming API',
62
  code: `import { HfInference } from "@huggingface/inference"
 
111
  language: 'bash',
112
  code: `pip install huggingface_hub`
113
  });
114
+ if (conversation.streaming) {
115
  snippets.push({
116
  label: 'Streaming API',
117
  code: `from huggingface_hub import InferenceClient
 
154
  const messagesStr = getMessages();
155
  const snippets: Snippet[] = [];
156
 
157
+ if (conversation.streaming) {
158
  snippets.push({
159
  label: 'Streaming API',
160
  code: `curl 'https://api-inference.huggingface.co/models/${conversation.model.id}/v1/chat/completions' \
src/lib/components/InferencePlayground/InferencePlaygroundConversation.svelte CHANGED
@@ -34,7 +34,7 @@
34
  <div
35
  class="flex max-h-[calc(100dvh-5.8rem)] flex-col overflow-y-auto overflow-x-hidden @container"
36
  class:pointer-events-none={loading}
37
- class:animate-pulse={loading && !conversation.config.streaming}
38
  bind:this={messageContainer}
39
  >
40
  {#if sideBySide}
@@ -65,6 +65,7 @@
65
  >
66
  <PlaygroundOptions
67
  bind:config={conversation.config}
 
68
  classNames="absolute top-8 right-0 w-56 invisible group-focus:visible hover:visible border border-gray-200/80 bg-white z-10 px-4 py-6 text-sm shadow-sm dark:border-gray-800 dark:bg-gray-800 rounded-xl"
69
  />
70
  </button>
 
34
  <div
35
  class="flex max-h-[calc(100dvh-5.8rem)] flex-col overflow-y-auto overflow-x-hidden @container"
36
  class:pointer-events-none={loading}
37
+ class:animate-pulse={loading && !conversation.streaming}
38
  bind:this={messageContainer}
39
  >
40
  {#if sideBySide}
 
65
  >
66
  <PlaygroundOptions
67
  bind:config={conversation.config}
68
+ bind:streaming={conversation.streaming}
69
  classNames="absolute top-8 right-0 w-56 invisible group-focus:visible hover:visible border border-gray-200/80 bg-white z-10 px-4 py-6 text-sm shadow-sm dark:border-gray-800 dark:bg-gray-800 rounded-xl"
70
  />
71
  </button>
src/lib/components/InferencePlayground/InferencePlaygroundGenerationConfig.svelte CHANGED
@@ -1,62 +1,45 @@
1
  <script lang="ts">
 
 
2
  export let config;
 
3
  export let classNames = '';
4
  </script>
5
 
6
- <div class={classNames}>
7
- <div>
8
- <div class="flex items-center justify-between">
9
- <label
10
- for="temperature-range"
11
- class="mb-2 block text-sm font-medium text-gray-900 dark:text-white">Temperature</label
12
- >
 
 
 
 
 
 
 
 
 
 
 
 
13
  <input
14
- type="number"
15
- class="w-16 rounded border bg-transparent px-1 py-0.5 text-right text-sm dark:border-gray-700"
16
- bind:value={config.temperature}
17
- min="0"
18
- max="1"
19
- step="0.1"
 
20
  />
21
  </div>
22
- <input
23
- id="temperature-range"
24
- type="range"
25
- bind:value={config.temperature}
26
- min="0"
27
- max="1"
28
- step="0.1"
29
- class="h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-200 accent-black dark:bg-gray-700 dark:accent-blue-500"
30
- />
31
- </div>
32
- <div>
33
- <div class="flex items-center justify-between">
34
- <label
35
- for="max-tokens-range"
36
- class="mb-2 block text-sm font-medium text-gray-900 dark:text-white">Max tokens</label
37
- >
38
- <input
39
- type="number"
40
- class="w-20 rounded border bg-transparent px-1 py-0.5 text-right text-sm dark:border-gray-700"
41
- bind:value={config.maxTokens}
42
- min="0"
43
- max="4096"
44
- step="512"
45
- />
46
- </div>
47
- <input
48
- id="max-tokens-range"
49
- type="range"
50
- bind:value={config.maxTokens}
51
- min="0"
52
- max="4096"
53
- step="512"
54
- class="h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-200 accent-black dark:bg-gray-700 dark:accent-blue-500"
55
- />
56
- </div>
57
  <div class="mt-2">
58
  <label class="flex cursor-pointer items-center justify-between">
59
- <input type="checkbox" bind:checked={config.streaming} class="peer sr-only" />
60
  <span class="text-sm font-medium text-gray-900 dark:text-gray-300">Streaming</span>
61
  <div
62
  class="peer relative h-5 w-9 rounded-full bg-gray-200 after:absolute after:start-[2px] after:top-[2px] after:h-4 after:w-4 after:rounded-full after:border after:border-gray-300 after:bg-white after:transition-all after:content-[''] peer-checked:bg-black peer-checked:after:translate-x-full peer-checked:after:border-white peer-focus:outline-none dark:border-gray-600 dark:bg-gray-700 dark:peer-checked:bg-blue-600"
 
1
  <script lang="ts">
2
+ import { GENERATION_CONFIG_KEYS, GENERATION_CONFIG_SETTINGS } from './generationConfigSettings';
3
+
4
  export let config;
5
+ export let streaming;
6
  export let classNames = '';
7
  </script>
8
 
9
+ <div class="flex flex-col gap-y-5 {classNames}">
10
+ {#each GENERATION_CONFIG_KEYS as key}
11
+ {@const settings = GENERATION_CONFIG_SETTINGS[key]}
12
+ <div>
13
+ <div class="flex items-center justify-between">
14
+ <label
15
+ for="temperature-range"
16
+ class="mb-2 block text-sm font-medium text-gray-900 dark:text-white"
17
+ >{settings.label}</label
18
+ >
19
+ <input
20
+ type="number"
21
+ class="w-16 rounded border bg-transparent px-1 py-0.5 text-right text-sm dark:border-gray-700"
22
+ min={settings.min}
23
+ max={settings.max}
24
+ step={settings.step}
25
+ bind:value={config[key]}
26
+ />
27
+ </div>
28
  <input
29
+ id="temperature-range"
30
+ type="range"
31
+ min={settings.min}
32
+ max={settings.max}
33
+ step={settings.step}
34
+ bind:value={config[key]}
35
+ class="h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-200 accent-black dark:bg-gray-700 dark:accent-blue-500"
36
  />
37
  </div>
38
+ {/each}
39
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  <div class="mt-2">
41
  <label class="flex cursor-pointer items-center justify-between">
42
+ <input type="checkbox" bind:checked={streaming} class="peer sr-only" />
43
  <span class="text-sm font-medium text-gray-900 dark:text-gray-300">Streaming</span>
44
  <div
45
  class="peer relative h-5 w-9 rounded-full bg-gray-200 after:absolute after:start-[2px] after:top-[2px] after:h-4 after:w-4 after:rounded-full after:border after:border-gray-300 after:bg-white after:transition-all after:content-[''] peer-checked:bg-black peer-checked:after:translate-x-full peer-checked:after:border-white peer-focus:outline-none dark:border-gray-600 dark:bg-gray-700 dark:peer-checked:bg-blue-600"
src/lib/components/InferencePlayground/generationConfigSettings.ts ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ interface GenerationKeySettings {
2
+ default: number;
3
+ step: number;
4
+ min: number;
5
+ max: number;
6
+ label: string;
7
+ }
8
+
9
+ export const GENERATION_CONFIG_SETTINGS: Record<string, GenerationKeySettings> = {
10
+ temperature: {
11
+ default: 0.7,
12
+ step: 0.01,
13
+ min: 0,
14
+ max: 2,
15
+ label: 'Temperature'
16
+ },
17
+ max_tokens: {
18
+ default: 0.7,
19
+ step: 512,
20
+ min: 1,
21
+ max: 8192, // changed dynamically based on model
22
+ label: 'Output Length'
23
+ },
24
+ top_p: {
25
+ default: 0.7,
26
+ step: 0.01,
27
+ min: 0,
28
+ max: 1,
29
+ label: 'Top-P'
30
+ },
31
+ top_k: {
32
+ default: 50,
33
+ step: 1,
34
+ min: 1,
35
+ max: 100,
36
+ label: 'Top-K'
37
+ },
38
+ repetition_penalty: {
39
+ default: 1,
40
+ step: 0.01,
41
+ min: 1,
42
+ max: 2,
43
+ label: 'Repetition Penalty'
44
+ }
45
+ };
46
+
47
+ export type GenerationConfigKey = keyof typeof GENERATION_CONFIG_SETTINGS;
48
+
49
+ export const GENERATION_CONFIG_KEYS: GenerationConfigKey[] = Object.keys(
50
+ GENERATION_CONFIG_SETTINGS
51
+ );
52
+
53
+ export type GenerationConfig = Record<GenerationConfigKey, number>;
54
+
55
+ export const defaultGenerationConfig = Object.keys(GENERATION_CONFIG_SETTINGS).reduce(
56
+ (acc, key) => {
57
+ acc[key] = GENERATION_CONFIG_SETTINGS[key].default;
58
+ return acc;
59
+ },
60
+ {} as GenerationConfig
61
+ );
src/lib/types/index.d.ts CHANGED
@@ -1,19 +1,13 @@
 
1
  import type { ModelEntry } from '@huggingface/hub';
2
  import type { ChatCompletionInputMessage } from '@huggingface/tasks';
3
 
4
- type Model = string;
5
-
6
- type GenerationConfig = {
7
- temperature: number;
8
- maxTokens: number;
9
- streaming: boolean;
10
- };
11
-
12
  type Conversation = {
13
  id: string;
14
  model: ModelEntryWithTokenizer;
15
  config: GenerationConfig;
16
  messages: ChatCompletionInputMessage[];
 
17
  };
18
 
19
  interface TokenizerConfig {
 
1
+ import type { GenerationConfig } from '$lib/components/InferencePlayground/generationConfigSettings';
2
  import type { ModelEntry } from '@huggingface/hub';
3
  import type { ChatCompletionInputMessage } from '@huggingface/tasks';
4
 
 
 
 
 
 
 
 
 
5
  type Conversation = {
6
  id: string;
7
  model: ModelEntryWithTokenizer;
8
  config: GenerationConfig;
9
  messages: ChatCompletionInputMessage[];
10
+ streaming: boolean;
11
  };
12
 
13
  interface TokenizerConfig {