nsarrazin HF staff commited on
Commit
cd6894d
1 Parent(s): 447c0ca

Support custom system prompts from the user (#399)

Browse files

* Support custom system prompts from the user

* linter

* types & lint

src/lib/buildPrompt.ts CHANGED
@@ -11,7 +11,8 @@ import { ObjectId } from "mongodb";
11
  export async function buildPrompt(
12
  messages: Pick<Message, "from" | "content">[],
13
  model: BackendModel,
14
- webSearchId?: string
 
15
  ): Promise<string> {
16
  if (webSearchId) {
17
  const webSearch = await collections.webSearches.findOne({
@@ -33,7 +34,7 @@ export async function buildPrompt(
33
 
34
  return (
35
  model
36
- .chatPromptRender({ messages })
37
  // Not super precise, but it's truncated in the model's backend anyway
38
  .split(" ")
39
  .slice(-(model.parameters?.truncate ?? 0))
 
11
  export async function buildPrompt(
12
  messages: Pick<Message, "from" | "content">[],
13
  model: BackendModel,
14
+ webSearchId?: string,
15
+ preprompt?: string
16
  ): Promise<string> {
17
  if (webSearchId) {
18
  const webSearch = await collections.webSearches.findOne({
 
34
 
35
  return (
36
  model
37
+ .chatPromptRender({ messages, preprompt })
38
  // Not super precise, but it's truncated in the model's backend anyway
39
  .split(" ")
40
  .slice(-(model.parameters?.truncate ?? 0))
src/lib/components/ModelsModal.svelte CHANGED
@@ -10,26 +10,56 @@
10
  import { enhance } from "$app/forms";
11
  import { base } from "$app/paths";
12
 
 
 
 
 
13
  export let settings: LayoutData["settings"];
14
  export let models: Array<Model>;
15
 
16
  let selectedModelId = settings.activeModel;
17
 
18
  const dispatch = createEventDispatcher<{ close: void }>();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  </script>
20
 
21
  <Modal width="max-w-lg" on:close>
22
  <form
23
  action="{base}/settings"
24
  method="post"
 
 
 
 
 
25
  use:enhance={() => {
26
  dispatch("close");
27
  }}
28
  class="flex w-full flex-col gap-5 p-6"
29
  >
30
- {#each Object.entries(settings).filter(([k]) => k !== "activeModel") as [key, val]}
31
  <input type="hidden" name={key} value={val} />
32
  {/each}
 
33
  <div class="flex items-start justify-between text-xl font-semibold text-gray-800">
34
  <h2>Models</h2>
35
  <button type="button" class="group" on:click={() => dispatch("close")}>
@@ -39,8 +69,9 @@
39
 
40
  <div class="space-y-4">
41
  {#each models as model}
 
42
  <div
43
- class="rounded-xl border border-gray-100 {model.id === selectedModelId
44
  ? 'bg-gradient-to-r from-primary-200/40 via-primary-500/10'
45
  : ''}"
46
  >
@@ -61,11 +92,49 @@
61
  {/if}
62
  </span>
63
  <CarbonCheckmark
64
- class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {model.id === selectedModelId
65
  ? 'text-primary-400'
66
  : 'text-transparent group-hover:text-gray-200'}"
67
  />
68
  </label>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  <ModelCardMetadata {model} />
70
  </div>
71
  {/each}
 
10
  import { enhance } from "$app/forms";
11
  import { base } from "$app/paths";
12
 
13
+ import CarbonEdit from "~icons/carbon/edit";
14
+ import CarbonSave from "~icons/carbon/save";
15
+ import CarbonRestart from "~icons/carbon/restart";
16
+
17
  export let settings: LayoutData["settings"];
18
  export let models: Array<Model>;
19
 
20
  let selectedModelId = settings.activeModel;
21
 
22
  const dispatch = createEventDispatcher<{ close: void }>();
23
+
24
+ let expanded = false;
25
+
26
+ function onToggle() {
27
+ if (expanded) {
28
+ settings.customPrompts[selectedModelId] = value;
29
+ }
30
+ expanded = !expanded;
31
+ }
32
+
33
+ let value = "";
34
+
35
+ function onModelChange() {
36
+ value =
37
+ settings.customPrompts[selectedModelId] ??
38
+ models.filter((el) => el.id === selectedModelId)[0].preprompt ??
39
+ "";
40
+ }
41
+
42
+ $: selectedModelId, onModelChange();
43
  </script>
44
 
45
  <Modal width="max-w-lg" on:close>
46
  <form
47
  action="{base}/settings"
48
  method="post"
49
+ on:submit={() => {
50
+ if (expanded) {
51
+ onToggle();
52
+ }
53
+ }}
54
  use:enhance={() => {
55
  dispatch("close");
56
  }}
57
  class="flex w-full flex-col gap-5 p-6"
58
  >
59
+ {#each Object.entries(settings).filter(([k]) => !(k == "activeModel" || k === "customPrompts")) as [key, val]}
60
  <input type="hidden" name={key} value={val} />
61
  {/each}
62
+ <input type="hidden" name="customPrompts" value={JSON.stringify(settings.customPrompts)} />
63
  <div class="flex items-start justify-between text-xl font-semibold text-gray-800">
64
  <h2>Models</h2>
65
  <button type="button" class="group" on:click={() => dispatch("close")}>
 
69
 
70
  <div class="space-y-4">
71
  {#each models as model}
72
+ {@const active = model.id === selectedModelId}
73
  <div
74
+ class="rounded-xl border border-gray-100 {active
75
  ? 'bg-gradient-to-r from-primary-200/40 via-primary-500/10'
76
  : ''}"
77
  >
 
92
  {/if}
93
  </span>
94
  <CarbonCheckmark
95
+ class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {active
96
  ? 'text-primary-400'
97
  : 'text-transparent group-hover:text-gray-200'}"
98
  />
99
  </label>
100
+ {#if active}
101
+ <div class=" overflow-hidden rounded-xl px-3 pb-2">
102
+ <div class="flex flex-row flex-nowrap gap-2 pb-1">
103
+ <div class="text-xs font-semibold text-gray-500">System Prompt</div>
104
+ {#if expanded}
105
+ <button
106
+ class="text-gray-500 hover:text-gray-900"
107
+ on:click|preventDefault={onToggle}
108
+ >
109
+ <CarbonSave class="text-sm " />
110
+ </button>
111
+ <button
112
+ class="text-gray-500 hover:text-gray-900"
113
+ on:click|preventDefault={() => {
114
+ value = model.preprompt ?? "";
115
+ }}
116
+ >
117
+ <CarbonRestart class="text-sm " />
118
+ </button>
119
+ {:else}
120
+ <button
121
+ class=" text-gray-500 hover:text-gray-900"
122
+ on:click|preventDefault={onToggle}
123
+ >
124
+ <CarbonEdit class="text-sm " />
125
+ </button>
126
+ {/if}
127
+ </div>
128
+ <textarea
129
+ enterkeyhint="send"
130
+ tabindex="0"
131
+ rows="1"
132
+ class="h-20 w-full resize-none scroll-p-3 overflow-x-hidden overflow-y-scroll rounded-md border border-gray-300 bg-transparent p-1 text-xs outline-none focus:ring-0 focus-visible:ring-0"
133
+ bind:value
134
+ hidden={!expanded}
135
+ />
136
+ </div>
137
+ {/if}
138
  <ModelCardMetadata {model} />
139
  </div>
140
  {/each}
src/lib/components/chat/ChatIntroduction.svelte CHANGED
@@ -78,7 +78,7 @@
78
  </div>
79
  </div>
80
  {#if currentModelMetadata.promptExamples}
81
- <div class="lg:col-span-3 lg:mt-12">
82
  <p class="mb-3 text-gray-600 dark:text-gray-300">Examples</p>
83
  <div class="grid gap-3 lg:grid-cols-3 lg:gap-5">
84
  {#each currentModelMetadata.promptExamples as example}
 
78
  </div>
79
  </div>
80
  {#if currentModelMetadata.promptExamples}
81
+ <div class="lg:col-span-3 lg:mt-6">
82
  <p class="mb-3 text-gray-600 dark:text-gray-300">Examples</p>
83
  <div class="grid gap-3 lg:grid-cols-3 lg:gap-5">
84
  {#each currentModelMetadata.promptExamples as example}
src/lib/server/models.ts CHANGED
@@ -7,6 +7,8 @@ import type {
7
  import { compileTemplate } from "$lib/utils/template";
8
  import { z } from "zod";
9
 
 
 
10
  const sagemakerEndpoint = z.object({
11
  host: z.literal("sagemaker"),
12
  url: z.string().url(),
@@ -57,7 +59,7 @@ const modelsRaw = z
57
  assistantMessageToken: z.string().default(""),
58
  assistantMessageEndToken: z.string().default(""),
59
  messageEndToken: z.string().default(""),
60
- preprompt: z.string().default(""),
61
  prepromptUrl: z.string().url().optional(),
62
  chatPromptTemplate: z
63
  .string()
@@ -148,7 +150,7 @@ export const oldModels = OLD_MODELS
148
  .map((m) => ({ ...m, id: m.id || m.name, displayName: m.displayName || m.name }))
149
  : [];
150
 
151
- export type BackendModel = (typeof models)[0];
152
  export type Endpoint = z.infer<typeof endpoint>;
153
 
154
  export const defaultModel = models[0];
 
7
  import { compileTemplate } from "$lib/utils/template";
8
  import { z } from "zod";
9
 
10
+ type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;
11
+
12
  const sagemakerEndpoint = z.object({
13
  host: z.literal("sagemaker"),
14
  url: z.string().url(),
 
59
  assistantMessageToken: z.string().default(""),
60
  assistantMessageEndToken: z.string().default(""),
61
  messageEndToken: z.string().default(""),
62
+ preprompt: z.string().min(1).optional(),
63
  prepromptUrl: z.string().url().optional(),
64
  chatPromptTemplate: z
65
  .string()
 
150
  .map((m) => ({ ...m, id: m.id || m.name, displayName: m.displayName || m.name }))
151
  : [];
152
 
153
+ export type BackendModel = Optional<(typeof models)[0], "preprompt">;
154
  export type Endpoint = z.infer<typeof endpoint>;
155
 
156
  export const defaultModel = models[0];
src/lib/types/Model.ts CHANGED
@@ -12,4 +12,5 @@ export type Model = Pick<
12
  | "description"
13
  | "modelUrl"
14
  | "datasetUrl"
 
15
  >;
 
12
  | "description"
13
  | "modelUrl"
14
  | "datasetUrl"
15
+ | "preprompt"
16
  >;
src/lib/types/Settings.ts CHANGED
@@ -14,6 +14,9 @@ export interface Settings extends Timestamps {
14
  shareConversationsWithModelAuthors: boolean;
15
  ethicsModalAcceptedAt: Date | null;
16
  activeModel: string;
 
 
 
17
  }
18
 
19
  // TODO: move this to a constant file along with other constants
 
14
  shareConversationsWithModelAuthors: boolean;
15
  ethicsModalAcceptedAt: Date | null;
16
  activeModel: string;
17
+
18
+ // model name and system prompts
19
+ customPrompts?: Record<string, string>;
20
  }
21
 
22
  // TODO: move this to a constant file along with other constants
src/lib/types/Template.ts CHANGED
@@ -1,7 +1,7 @@
1
  import type { Message } from "./Message";
2
 
3
  export type LegacyParamatersTemplateInput = {
4
- preprompt: string;
5
  userMessageToken: string;
6
  userMessageEndToken: string;
7
  assistantMessageToken: string;
@@ -10,6 +10,7 @@ export type LegacyParamatersTemplateInput = {
10
 
11
  export type ChatTemplateInput = {
12
  messages: Pick<Message, "from" | "content">[];
 
13
  };
14
 
15
  export type WebSearchSummaryTemplateInput = {
 
1
  import type { Message } from "./Message";
2
 
3
  export type LegacyParamatersTemplateInput = {
4
+ preprompt?: string;
5
  userMessageToken: string;
6
  userMessageEndToken: string;
7
  assistantMessageToken: string;
 
10
 
11
  export type ChatTemplateInput = {
12
  messages: Pick<Message, "from" | "content">[];
13
+ preprompt?: string;
14
  };
15
 
16
  export type WebSearchSummaryTemplateInput = {
src/routes/+layout.server.ts CHANGED
@@ -62,6 +62,7 @@ export const load: LayoutServerLoad = async ({ locals, depends, url }) => {
62
  ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null,
63
  activeModel: settings?.activeModel ?? DEFAULT_SETTINGS.activeModel,
64
  searchEnabled: !!(SERPAPI_KEY || SERPER_API_KEY),
 
65
  },
66
  models: models.map((model) => ({
67
  id: model.id,
@@ -74,6 +75,7 @@ export const load: LayoutServerLoad = async ({ locals, depends, url }) => {
74
  description: model.description,
75
  promptExamples: model.promptExamples,
76
  parameters: model.parameters,
 
77
  })),
78
  oldModels,
79
  user: locals.user && {
 
62
  ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null,
63
  activeModel: settings?.activeModel ?? DEFAULT_SETTINGS.activeModel,
64
  searchEnabled: !!(SERPAPI_KEY || SERPER_API_KEY),
65
+ customPrompts: settings?.customPrompts ?? {},
66
  },
67
  models: models.map((model) => ({
68
  id: model.id,
 
75
  description: model.description,
76
  promptExamples: model.promptExamples,
77
  parameters: model.parameters,
78
+ preprompt: model.preprompt,
79
  })),
80
  oldModels,
81
  user: locals.user && {
src/routes/conversation/[id]/+server.ts CHANGED
@@ -53,6 +53,7 @@ export async function POST({ request, fetch, locals, params }) {
53
  }
54
 
55
  const model = models.find((m) => m.id === conv.model);
 
56
 
57
  if (!model) {
58
  throw error(410, "Model not available anymore");
@@ -97,7 +98,13 @@ export async function POST({ request, fetch, locals, params }) {
97
  ];
98
  })() satisfies Message[];
99
 
100
- const prompt = await buildPrompt(messages, model, web_search_id);
 
 
 
 
 
 
101
  const randomEndpoint = modelEndpoint(model);
102
 
103
  const abortController = new AbortController();
 
53
  }
54
 
55
  const model = models.find((m) => m.id === conv.model);
56
+ const settings = await collections.settings.findOne(authCondition(locals));
57
 
58
  if (!model) {
59
  throw error(410, "Model not available anymore");
 
98
  ];
99
  })() satisfies Message[];
100
 
101
+ const prompt = await buildPrompt(
102
+ messages,
103
+ model,
104
+ web_search_id,
105
+ settings?.customPrompts?.[model.id]
106
+ );
107
+
108
  const randomEndpoint = modelEndpoint(model);
109
 
110
  const abortController = new AbortController();
src/routes/settings/+page.server.ts CHANGED
@@ -17,11 +17,13 @@ export const actions = {
17
  .default(DEFAULT_SETTINGS.shareConversationsWithModelAuthors),
18
  ethicsModalAccepted: z.boolean({ coerce: true }).optional(),
19
  activeModel: validateModel(models),
 
20
  })
21
  .parse({
22
  shareConversationsWithModelAuthors: formData.get("shareConversationsWithModelAuthors"),
23
  ethicsModalAccepted: formData.get("ethicsModalAccepted"),
24
  activeModel: formData.get("activeModel") ?? DEFAULT_SETTINGS.activeModel,
 
25
  });
26
 
27
  await collections.settings.updateOne(
@@ -40,7 +42,6 @@ export const actions = {
40
  upsert: true,
41
  }
42
  );
43
-
44
  throw redirect(303, request.headers.get("referer") || `${base}/`);
45
  },
46
  };
 
17
  .default(DEFAULT_SETTINGS.shareConversationsWithModelAuthors),
18
  ethicsModalAccepted: z.boolean({ coerce: true }).optional(),
19
  activeModel: validateModel(models),
20
+ customPrompts: z.record(z.string()).default({}),
21
  })
22
  .parse({
23
  shareConversationsWithModelAuthors: formData.get("shareConversationsWithModelAuthors"),
24
  ethicsModalAccepted: formData.get("ethicsModalAccepted"),
25
  activeModel: formData.get("activeModel") ?? DEFAULT_SETTINGS.activeModel,
26
+ customPrompts: JSON.parse(formData.get("customPrompts")?.toString() ?? "{}"),
27
  });
28
 
29
  await collections.settings.updateOne(
 
42
  upsert: true,
43
  }
44
  );
 
45
  throw redirect(303, request.headers.get("referer") || `${base}/`);
46
  },
47
  };