nsarrazin HF staff victor HF staff Mishig commited on
Commit
d4016bc
1 Parent(s): e9ad67e

Expose sampling controls in assistants (#955) (#959)

Browse files

* Expose sampling controls in assistants (#955)

* Make sure all labels have the same font size

* styling

* Add better tooltips

* better padding & wrapping

* Revert "better padding & wrapping"

This reverts commit 1b44086465040f2cb6bc906983cfc8d95820d6fe.

* ui update

* tooltip on mobile

* lint

* Update src/lib/components/AssistantSettings.svelte

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

---------

Co-authored-by: Victor Mustar <victor.mustar@gmail.com>
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

src/lib/components/AssistantSettings.svelte CHANGED
@@ -9,11 +9,14 @@
9
  import { base } from "$app/paths";
10
  import CarbonPen from "~icons/carbon/pen";
11
  import CarbonUpload from "~icons/carbon/upload";
 
 
12
 
13
  import { useSettingsStore } from "$lib/stores/settings";
14
  import { isHuggingChat } from "$lib/utils/isHuggingChat";
15
  import IconInternet from "./icons/IconInternet.svelte";
16
  import TokensCounter from "./TokensCounter.svelte";
 
17
 
18
  type ActionData = {
19
  error: boolean;
@@ -31,16 +34,22 @@
31
 
32
  let files: FileList | null = null;
33
  const settings = useSettingsStore();
34
- let modelId =
35
- assistant?.modelId ?? models.find((_model) => _model.id === $settings.activeModel)?.name;
36
  let systemPrompt = assistant?.preprompt ?? "";
37
  let dynamicPrompt = assistant?.dynamicPrompt ?? false;
 
38
 
39
  let compress: typeof readAndCompressImage | null = null;
40
 
41
  onMount(async () => {
42
  const module = await import("browser-image-resizer");
43
  compress = module.readAndCompressImage;
 
 
 
 
 
 
44
  });
45
 
46
  let inputMessage1 = assistant?.exampleInputs[0] ?? "";
@@ -89,11 +98,12 @@
89
 
90
  const regex = /{{\s?url=(.+?)\s?}}/g;
91
  $: templateVariables = [...systemPrompt.matchAll(regex)].map((match) => match[1]);
 
92
  </script>
93
 
94
  <form
95
  method="POST"
96
- class="flex h-full flex-col overflow-y-auto p-4 md:p-8"
97
  enctype="multipart/form-data"
98
  use:enhance={async ({ formData }) => {
99
  loading = true;
@@ -246,21 +256,122 @@
246
 
247
  <label>
248
  <div class="mb-1 font-semibold">Model</div>
249
- <select
250
- name="modelId"
251
- class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
252
- bind:value={modelId}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  >
254
- {#each models.filter((model) => !model.unlisted) as model}
255
- <option
256
- value={model.id}
257
- selected={assistant
258
- ? assistant?.modelId === model.id
259
- : $settings.activeModel === model.id}>{model.displayName}</option
260
- >
261
- {/each}
262
- <p class="text-xs text-red-500">{getError("modelId", form)}</p>
263
- </select>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  </label>
265
 
266
  <label>
 
9
  import { base } from "$app/paths";
10
  import CarbonPen from "~icons/carbon/pen";
11
  import CarbonUpload from "~icons/carbon/upload";
12
+ import CarbonHelpFilled from "~icons/carbon/help";
13
+ import CarbonSettingsAdjust from "~icons/carbon/settings-adjust";
14
 
15
  import { useSettingsStore } from "$lib/stores/settings";
16
  import { isHuggingChat } from "$lib/utils/isHuggingChat";
17
  import IconInternet from "./icons/IconInternet.svelte";
18
  import TokensCounter from "./TokensCounter.svelte";
19
+ import HoverTooltip from "./HoverTooltip.svelte";
20
 
21
  type ActionData = {
22
  error: boolean;
 
34
 
35
  let files: FileList | null = null;
36
  const settings = useSettingsStore();
37
+ let modelId = "";
 
38
  let systemPrompt = assistant?.preprompt ?? "";
39
  let dynamicPrompt = assistant?.dynamicPrompt ?? false;
40
+ let showModelSettings = Object.values(assistant?.generateSettings ?? {}).some((v) => !!v);
41
 
42
  let compress: typeof readAndCompressImage | null = null;
43
 
44
  onMount(async () => {
45
  const module = await import("browser-image-resizer");
46
  compress = module.readAndCompressImage;
47
+
48
+ if (assistant) {
49
+ modelId = assistant.modelId;
50
+ } else {
51
+ modelId = models.find((model) => model.id === $settings.activeModel)?.id ?? models[0].id;
52
+ }
53
  });
54
 
55
  let inputMessage1 = assistant?.exampleInputs[0] ?? "";
 
98
 
99
  const regex = /{{\s?url=(.+?)\s?}}/g;
100
  $: templateVariables = [...systemPrompt.matchAll(regex)].map((match) => match[1]);
101
+ $: selectedModel = models.find((m) => m.id === modelId);
102
  </script>
103
 
104
  <form
105
  method="POST"
106
+ class="relative flex h-full flex-col overflow-y-auto p-4 md:p-8"
107
  enctype="multipart/form-data"
108
  use:enhance={async ({ formData }) => {
109
  loading = true;
 
256
 
257
  <label>
258
  <div class="mb-1 font-semibold">Model</div>
259
+ <div class="flex gap-2">
260
+ <select
261
+ name="modelId"
262
+ class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
263
+ bind:value={modelId}
264
+ >
265
+ {#each models.filter((model) => !model.unlisted) as model}
266
+ <option value={model.id}>{model.displayName}</option>
267
+ {/each}
268
+ <p class="text-xs text-red-500">{getError("modelId", form)}</p>
269
+ </select>
270
+ <button
271
+ type="button"
272
+ class="flex aspect-square items-center gap-2 whitespace-nowrap rounded-lg border px-3 {showModelSettings
273
+ ? 'border-blue-500/20 bg-blue-50 text-blue-600'
274
+ : ''}"
275
+ on:click={() => (showModelSettings = !showModelSettings)}
276
+ ><CarbonSettingsAdjust class="text-xs" /></button
277
+ >
278
+ </div>
279
+ <div
280
+ class="mt-2 rounded-lg border border-blue-500/20 bg-blue-500/5 px-2 py-0.5"
281
+ class:hidden={!showModelSettings}
282
  >
283
+ <p class="text-xs text-red-500">{getError("inputMessage1", form)}</p>
284
+ <div class="my-2 grid grid-cols-1 gap-2.5 sm:grid-cols-2 sm:grid-rows-2">
285
+ <label for="temperature" class="flex justify-between">
286
+ <span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
287
+ Temperature
288
+
289
+ <HoverTooltip
290
+ label="Temperature: Controls creativity, higher values allow more variety."
291
+ >
292
+ <CarbonHelpFilled
293
+ class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
294
+ />
295
+ </HoverTooltip>
296
+ </span>
297
+ <input
298
+ type="number"
299
+ name="temperature"
300
+ min="0.1"
301
+ max="2"
302
+ step="0.1"
303
+ class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
304
+ placeholder={selectedModel?.parameters?.temperature?.toString() ?? "1"}
305
+ value={assistant?.generateSettings?.temperature ?? ""}
306
+ />
307
+ </label>
308
+ <label for="top_p" class="flex justify-between">
309
+ <span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
310
+ Top P
311
+ <HoverTooltip
312
+ label="Top P: Sets word choice boundaries, lower values tighten focus."
313
+ >
314
+ <CarbonHelpFilled
315
+ class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
316
+ />
317
+ </HoverTooltip>
318
+ </span>
319
+
320
+ <input
321
+ type="number"
322
+ name="top_p"
323
+ class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
324
+ min="0.05"
325
+ max="1"
326
+ step="0.05"
327
+ placeholder={selectedModel?.parameters?.top_p?.toString() ?? "1"}
328
+ value={assistant?.generateSettings?.top_p ?? ""}
329
+ />
330
+ </label>
331
+ <label for="repetition_penalty" class="flex justify-between">
332
+ <span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
333
+ Repetition penalty
334
+ <HoverTooltip
335
+ label="Repetition penalty: Prevents reuse, higher values decrease repetition."
336
+ >
337
+ <CarbonHelpFilled
338
+ class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
339
+ />
340
+ </HoverTooltip>
341
+ </span>
342
+ <input
343
+ type="number"
344
+ name="repetition_penalty"
345
+ min="0.1"
346
+ max="2"
347
+ class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
348
+ placeholder={selectedModel?.parameters?.repetition_penalty?.toString() ?? "1.0"}
349
+ value={assistant?.generateSettings?.repetition_penalty ?? ""}
350
+ />
351
+ </label>
352
+ <label for="top_k" class="flex justify-between">
353
+ <span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
354
+ Top K <HoverTooltip
355
+ label="Top K: Restricts word options, lower values for predictability."
356
+ >
357
+ <CarbonHelpFilled
358
+ class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
359
+ />
360
+ </HoverTooltip>
361
+ </span>
362
+ <input
363
+ type="number"
364
+ name="top_k"
365
+ min="5"
366
+ max="100"
367
+ step="5"
368
+ class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
369
+ placeholder={selectedModel?.parameters?.top_k?.toString() ?? "50"}
370
+ value={assistant?.generateSettings?.top_k ?? ""}
371
+ />
372
+ </label>
373
+ </div>
374
+ </div>
375
  </label>
376
 
377
  <label>
src/lib/components/HoverTooltip.svelte ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ export let label = "";
3
+ </script>
4
+
5
+ <div class="group/tooltip md:relative">
6
+ <slot />
7
+ <div
8
+ class="invisible absolute z-10 w-64 whitespace-normal rounded-md bg-black p-2 text-center text-white group-hover/tooltip:visible group-active/tooltip:visible max-sm:left-1/2 max-sm:-translate-x-1/2"
9
+ >
10
+ {label}
11
+ </div>
12
+ </div>
src/lib/server/endpoints/anthropic/endpointAnthropic.ts CHANGED
@@ -32,7 +32,7 @@ export async function endpointAnthropic(
32
  defaultQuery,
33
  });
34
 
35
- return async ({ messages, preprompt }) => {
36
  let system = preprompt;
37
  if (messages?.[0]?.from === "system") {
38
  system = messages[0].content;
@@ -49,15 +49,18 @@ export async function endpointAnthropic(
49
  }[];
50
 
51
  let tokenId = 0;
 
 
 
52
  return (async function* () {
53
  const stream = anthropic.messages.stream({
54
  model: model.id ?? model.name,
55
  messages: messagesFormatted,
56
- max_tokens: model.parameters?.max_new_tokens,
57
- temperature: model.parameters?.temperature,
58
- top_p: model.parameters?.top_p,
59
- top_k: model.parameters?.top_k,
60
- stop_sequences: model.parameters?.stop,
61
  system,
62
  });
63
  while (true) {
 
32
  defaultQuery,
33
  });
34
 
35
+ return async ({ messages, preprompt, generateSettings }) => {
36
  let system = preprompt;
37
  if (messages?.[0]?.from === "system") {
38
  system = messages[0].content;
 
49
  }[];
50
 
51
  let tokenId = 0;
52
+
53
+ const parameters = { ...model.parameters, ...generateSettings };
54
+
55
  return (async function* () {
56
  const stream = anthropic.messages.stream({
57
  model: model.id ?? model.name,
58
  messages: messagesFormatted,
59
+ max_tokens: parameters?.max_new_tokens,
60
+ temperature: parameters?.temperature,
61
+ top_p: parameters?.top_p,
62
+ top_k: parameters?.top_k,
63
+ stop_sequences: parameters?.stop,
64
  system,
65
  });
66
  while (true) {
src/lib/server/endpoints/aws/endpointAws.ts CHANGED
@@ -36,7 +36,7 @@ export async function endpointAws(
36
  region,
37
  });
38
 
39
- return async ({ messages, preprompt, continueMessage }) => {
40
  const prompt = await buildPrompt({
41
  messages,
42
  continueMessage,
@@ -46,7 +46,7 @@ export async function endpointAws(
46
 
47
  return textGenerationStream(
48
  {
49
- parameters: { ...model.parameters, return_full_text: false },
50
  model: url,
51
  inputs: prompt,
52
  },
 
36
  region,
37
  });
38
 
39
+ return async ({ messages, preprompt, continueMessage, generateSettings }) => {
40
  const prompt = await buildPrompt({
41
  messages,
42
  continueMessage,
 
46
 
47
  return textGenerationStream(
48
  {
49
+ parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
50
  model: url,
51
  inputs: prompt,
52
  },
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -10,12 +10,14 @@ import {
10
  endpointAnthropic,
11
  endpointAnthropicParametersSchema,
12
  } from "./anthropic/endpointAnthropic";
 
13
 
14
  // parameters passed when generating text
15
  export interface EndpointParameters {
16
  messages: Omit<Conversation["messages"][0], "id">[];
17
  preprompt?: Conversation["preprompt"];
18
  continueMessage?: boolean; // used to signal that the last message will be extended
 
19
  }
20
 
21
  interface CommonEndpoint {
 
10
  endpointAnthropic,
11
  endpointAnthropicParametersSchema,
12
  } from "./anthropic/endpointAnthropic";
13
+ import type { Model } from "$lib/types/Model";
14
 
15
  // parameters passed when generating text
16
  export interface EndpointParameters {
17
  messages: Omit<Conversation["messages"][0], "id">[];
18
  preprompt?: Conversation["preprompt"];
19
  continueMessage?: boolean; // used to signal that the last message will be extended
20
+ generateSettings?: Partial<Model["parameters"]>;
21
  }
22
 
23
  interface CommonEndpoint {
src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts CHANGED
@@ -19,7 +19,7 @@ export function endpointLlamacpp(
19
  input: z.input<typeof endpointLlamacppParametersSchema>
20
  ): Endpoint {
21
  const { url, model } = endpointLlamacppParametersSchema.parse(input);
22
- return async ({ messages, preprompt, continueMessage }) => {
23
  const prompt = await buildPrompt({
24
  messages,
25
  continueMessage,
@@ -27,6 +27,8 @@ export function endpointLlamacpp(
27
  model,
28
  });
29
 
 
 
30
  const r = await fetch(`${url}/completion`, {
31
  method: "POST",
32
  headers: {
@@ -35,12 +37,12 @@ export function endpointLlamacpp(
35
  body: JSON.stringify({
36
  prompt,
37
  stream: true,
38
- temperature: model.parameters.temperature,
39
- top_p: model.parameters.top_p,
40
- top_k: model.parameters.top_k,
41
- stop: model.parameters.stop,
42
- repeat_penalty: model.parameters.repetition_penalty,
43
- n_predict: model.parameters.max_new_tokens,
44
  cache_prompt: true,
45
  }),
46
  });
 
19
  input: z.input<typeof endpointLlamacppParametersSchema>
20
  ): Endpoint {
21
  const { url, model } = endpointLlamacppParametersSchema.parse(input);
22
+ return async ({ messages, preprompt, continueMessage, generateSettings }) => {
23
  const prompt = await buildPrompt({
24
  messages,
25
  continueMessage,
 
27
  model,
28
  });
29
 
30
+ const parameters = { ...model.parameters, ...generateSettings };
31
+
32
  const r = await fetch(`${url}/completion`, {
33
  method: "POST",
34
  headers: {
 
37
  body: JSON.stringify({
38
  prompt,
39
  stream: true,
40
+ temperature: parameters.temperature,
41
+ top_p: parameters.top_p,
42
+ top_k: parameters.top_k,
43
+ stop: parameters.stop,
44
+ repeat_penalty: parameters.repetition_penalty,
45
+ n_predict: parameters.max_new_tokens,
46
  cache_prompt: true,
47
  }),
48
  });
src/lib/server/endpoints/ollama/endpointOllama.ts CHANGED
@@ -14,7 +14,7 @@ export const endpointOllamaParametersSchema = z.object({
14
  export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
15
  const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
16
 
17
- return async ({ messages, preprompt, continueMessage }) => {
18
  const prompt = await buildPrompt({
19
  messages,
20
  continueMessage,
@@ -22,6 +22,8 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
22
  model,
23
  });
24
 
 
 
25
  const r = await fetch(`${url}/api/generate`, {
26
  method: "POST",
27
  headers: {
@@ -32,12 +34,12 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
32
  model: ollamaName ?? model.name,
33
  raw: true,
34
  options: {
35
- top_p: model.parameters.top_p,
36
- top_k: model.parameters.top_k,
37
- temperature: model.parameters.temperature,
38
- repeat_penalty: model.parameters.repetition_penalty,
39
- stop: model.parameters.stop,
40
- num_predict: model.parameters.max_new_tokens,
41
  },
42
  }),
43
  });
 
14
  export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
15
  const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
16
 
17
+ return async ({ messages, preprompt, continueMessage, generateSettings }) => {
18
  const prompt = await buildPrompt({
19
  messages,
20
  continueMessage,
 
22
  model,
23
  });
24
 
25
+ const parameters = { ...model.parameters, ...generateSettings };
26
+
27
  const r = await fetch(`${url}/api/generate`, {
28
  method: "POST",
29
  headers: {
 
34
  model: ollamaName ?? model.name,
35
  raw: true,
36
  options: {
37
+ top_p: parameters.top_p,
38
+ top_k: parameters.top_k,
39
+ temperature: parameters.temperature,
40
+ repeat_penalty: parameters.repetition_penalty,
41
+ stop: parameters.stop,
42
+ num_predict: parameters.max_new_tokens,
43
  },
44
  }),
45
  });
src/lib/server/endpoints/openai/endpointOai.ts CHANGED
@@ -38,7 +38,7 @@ export async function endpointOai(
38
  });
39
 
40
  if (completion === "completions") {
41
- return async ({ messages, preprompt, continueMessage }) => {
42
  const prompt = await buildPrompt({
43
  messages,
44
  continueMessage,
@@ -46,21 +46,23 @@ export async function endpointOai(
46
  model,
47
  });
48
 
 
 
49
  return openAICompletionToTextGenerationStream(
50
  await openai.completions.create({
51
  model: model.id ?? model.name,
52
  prompt,
53
  stream: true,
54
- max_tokens: model.parameters?.max_new_tokens,
55
- stop: model.parameters?.stop,
56
- temperature: model.parameters?.temperature,
57
- top_p: model.parameters?.top_p,
58
- frequency_penalty: model.parameters?.repetition_penalty,
59
  })
60
  );
61
  };
62
  } else if (completion === "chat_completions") {
63
- return async ({ messages, preprompt }) => {
64
  let messagesOpenAI = messages.map((message) => ({
65
  role: message.from,
66
  content: message.content,
@@ -74,16 +76,18 @@ export async function endpointOai(
74
  messagesOpenAI[0].content = preprompt ?? "";
75
  }
76
 
 
 
77
  return openAIChatToTextGenerationStream(
78
  await openai.chat.completions.create({
79
  model: model.id ?? model.name,
80
  messages: messagesOpenAI,
81
  stream: true,
82
- max_tokens: model.parameters?.max_new_tokens,
83
- stop: model.parameters?.stop,
84
- temperature: model.parameters?.temperature,
85
- top_p: model.parameters?.top_p,
86
- frequency_penalty: model.parameters?.repetition_penalty,
87
  })
88
  );
89
  };
 
38
  });
39
 
40
  if (completion === "completions") {
41
+ return async ({ messages, preprompt, continueMessage, generateSettings }) => {
42
  const prompt = await buildPrompt({
43
  messages,
44
  continueMessage,
 
46
  model,
47
  });
48
 
49
+ const parameters = { ...model.parameters, ...generateSettings };
50
+
51
  return openAICompletionToTextGenerationStream(
52
  await openai.completions.create({
53
  model: model.id ?? model.name,
54
  prompt,
55
  stream: true,
56
+ max_tokens: parameters?.max_new_tokens,
57
+ stop: parameters?.stop,
58
+ temperature: parameters?.temperature,
59
+ top_p: parameters?.top_p,
60
+ frequency_penalty: parameters?.repetition_penalty,
61
  })
62
  );
63
  };
64
  } else if (completion === "chat_completions") {
65
+ return async ({ messages, preprompt, generateSettings }) => {
66
  let messagesOpenAI = messages.map((message) => ({
67
  role: message.from,
68
  content: message.content,
 
76
  messagesOpenAI[0].content = preprompt ?? "";
77
  }
78
 
79
+ const parameters = { ...model.parameters, ...generateSettings };
80
+
81
  return openAIChatToTextGenerationStream(
82
  await openai.chat.completions.create({
83
  model: model.id ?? model.name,
84
  messages: messagesOpenAI,
85
  stream: true,
86
+ max_tokens: parameters?.max_new_tokens,
87
+ stop: parameters?.stop,
88
+ temperature: parameters?.temperature,
89
+ top_p: parameters?.top_p,
90
+ frequency_penalty: parameters?.repetition_penalty,
91
  })
92
  );
93
  };
src/lib/server/endpoints/tgi/endpointTgi.ts CHANGED
@@ -16,7 +16,7 @@ export const endpointTgiParametersSchema = z.object({
16
  export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
17
  const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
18
 
19
- return async ({ messages, preprompt, continueMessage }) => {
20
  const prompt = await buildPrompt({
21
  messages,
22
  preprompt,
@@ -26,7 +26,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
26
 
27
  return textGenerationStream(
28
  {
29
- parameters: { ...model.parameters, return_full_text: false },
30
  model: url,
31
  inputs: prompt,
32
  accessToken,
 
16
  export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
17
  const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
18
 
19
+ return async ({ messages, preprompt, continueMessage, generateSettings }) => {
20
  const prompt = await buildPrompt({
21
  messages,
22
  preprompt,
 
26
 
27
  return textGenerationStream(
28
  {
29
+ parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
30
  model: url,
31
  inputs: prompt,
32
  accessToken,
src/lib/types/Assistant.ts CHANGED
@@ -19,6 +19,12 @@ export interface Assistant extends Timestamps {
19
  allowedDomains: string[];
20
  allowedLinks: string[];
21
  };
 
 
 
 
 
 
22
  dynamicPrompt?: boolean;
23
  searchTokens: string[];
24
  }
 
19
  allowedDomains: string[];
20
  allowedLinks: string[];
21
  };
22
+ generateSettings?: {
23
+ temperature?: number;
24
+ top_p?: number;
25
+ repetition_penalty?: number;
26
+ top_k?: number;
27
+ };
28
  dynamicPrompt?: boolean;
29
  searchTokens: string[];
30
  }
src/routes/conversation/[id]/+server.ts CHANGED
@@ -338,8 +338,11 @@ export async function POST({ request, locals, params, getClientAddress }) {
338
 
339
  // check if assistant has a rag
340
  const assistant = await collections.assistants.findOne<
341
- Pick<Assistant, "rag" | "dynamicPrompt">
342
- >({ _id: conv.assistantId }, { projection: { rag: 1, dynamicPrompt: 1 } });
 
 
 
343
 
344
  const assistantHasRAG =
345
  ENABLE_ASSISTANTS_RAG === "true" &&
@@ -403,12 +406,15 @@ export async function POST({ request, locals, params, getClientAddress }) {
403
 
404
  const previousText = messageToWriteTo.content;
405
 
 
 
406
  try {
407
  const endpoint = await model.getEndpoint();
408
  for await (const output of await endpoint({
409
  messages: processedMessages,
410
  preprompt,
411
  continueMessage: isContinue,
 
412
  })) {
413
  // if not generated_text is here it means the generation is not done
414
  if (!output.generated_text) {
@@ -448,10 +454,11 @@ export async function POST({ request, locals, params, getClientAddress }) {
448
  }
449
  }
450
  } catch (e) {
 
451
  update({ type: "status", status: "error", message: (e as Error).message });
452
  } finally {
453
  // check if no output was generated
454
- if (messageToWriteTo.content === previousText) {
455
  update({
456
  type: "status",
457
  status: "error",
 
338
 
339
  // check if assistant has a rag
340
  const assistant = await collections.assistants.findOne<
341
+ Pick<Assistant, "rag" | "dynamicPrompt" | "generateSettings">
342
+ >(
343
+ { _id: conv.assistantId },
344
+ { projection: { rag: 1, dynamicPrompt: 1, generateSettings: 1 } }
345
+ );
346
 
347
  const assistantHasRAG =
348
  ENABLE_ASSISTANTS_RAG === "true" &&
 
406
 
407
  const previousText = messageToWriteTo.content;
408
 
409
+ let hasError = false;
410
+
411
  try {
412
  const endpoint = await model.getEndpoint();
413
  for await (const output of await endpoint({
414
  messages: processedMessages,
415
  preprompt,
416
  continueMessage: isContinue,
417
+ generateSettings: assistant?.generateSettings,
418
  })) {
419
  // if not generated_text is here it means the generation is not done
420
  if (!output.generated_text) {
 
454
  }
455
  }
456
  } catch (e) {
457
+ hasError = true;
458
  update({ type: "status", status: "error", message: (e as Error).message });
459
  } finally {
460
  // check if no output was generated
461
+ if (!hasError && messageToWriteTo.content === previousText) {
462
  update({
463
  type: "status",
464
  status: "error",
src/routes/settings/(nav)/assistants/[assistantId]/edit/+page.server.ts CHANGED
@@ -25,6 +25,20 @@ const newAsssistantSchema = z.object({
25
  ragDomainList: z.preprocess(parseStringToList, z.string().array()),
26
  ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
27
  dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  });
29
 
30
  const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
@@ -143,6 +157,12 @@ export const actions: Actions = {
143
  },
144
  dynamicPrompt: parse.data.dynamicPrompt,
145
  searchTokens: generateSearchTokens(parse.data.name),
 
 
 
 
 
 
146
  },
147
  }
148
  );
 
25
  ragDomainList: z.preprocess(parseStringToList, z.string().array()),
26
  ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
27
  dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
28
+ temperature: z
29
+ .union([z.literal(""), z.coerce.number().min(0.1).max(2)])
30
+ .transform((v) => (v === "" ? undefined : v)),
31
+ top_p: z
32
+ .union([z.literal(""), z.coerce.number().min(0.05).max(1)])
33
+ .transform((v) => (v === "" ? undefined : v)),
34
+
35
+ repetition_penalty: z
36
+ .union([z.literal(""), z.coerce.number().min(0.1).max(2)])
37
+ .transform((v) => (v === "" ? undefined : v)),
38
+
39
+ top_k: z
40
+ .union([z.literal(""), z.coerce.number().min(5).max(100)])
41
+ .transform((v) => (v === "" ? undefined : v)),
42
  });
43
 
44
  const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
 
157
  },
158
  dynamicPrompt: parse.data.dynamicPrompt,
159
  searchTokens: generateSearchTokens(parse.data.name),
160
+ generateSettings: {
161
+ temperature: parse.data.temperature,
162
+ top_p: parse.data.top_p,
163
+ repetition_penalty: parse.data.repetition_penalty,
164
+ top_k: parse.data.top_k,
165
+ },
166
  },
167
  }
168
  );
src/routes/settings/(nav)/assistants/new/+page.server.ts CHANGED
@@ -25,6 +25,20 @@ const newAsssistantSchema = z.object({
25
  ragDomainList: z.preprocess(parseStringToList, z.string().array()),
26
  ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
27
  dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  });
29
 
30
  const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
@@ -125,6 +139,12 @@ export const actions: Actions = {
125
  },
126
  dynamicPrompt: parse.data.dynamicPrompt,
127
  searchTokens: generateSearchTokens(parse.data.name),
 
 
 
 
 
 
128
  });
129
 
130
  // add insertedId to user settings
 
25
  ragDomainList: z.preprocess(parseStringToList, z.string().array()),
26
  ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
27
  dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
28
+ temperature: z
29
+ .union([z.literal(""), z.coerce.number().min(0.1).max(2)])
30
+ .transform((v) => (v === "" ? undefined : v)),
31
+ top_p: z
32
+ .union([z.literal(""), z.coerce.number().min(0.05).max(1)])
33
+ .transform((v) => (v === "" ? undefined : v)),
34
+
35
+ repetition_penalty: z
36
+ .union([z.literal(""), z.coerce.number().min(0.1).max(2)])
37
+ .transform((v) => (v === "" ? undefined : v)),
38
+
39
+ top_k: z
40
+ .union([z.literal(""), z.coerce.number().min(5).max(100)])
41
+ .transform((v) => (v === "" ? undefined : v)),
42
  });
43
 
44
  const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
 
139
  },
140
  dynamicPrompt: parse.data.dynamicPrompt,
141
  searchTokens: generateSearchTokens(parse.data.name),
142
+ generateSettings: {
143
+ temperature: parse.data.temperature,
144
+ top_p: parse.data.top_p,
145
+ repetition_penalty: parse.data.repetition_penalty,
146
+ top_k: parse.data.top_k,
147
+ },
148
  });
149
 
150
  // add insertedId to user settings