coyotte508 HF staff commited on
Commit
cf7ac8d
·
unverified ·
1 Parent(s): 97dc766

✨ Add model id to be able to switch models while keeping conversations valid (#181)

Browse files
src/lib/components/ModelsModal.svelte CHANGED
@@ -13,7 +13,7 @@
13
  export let settings: LayoutData["settings"];
14
  export let models: Array<Model>;
15
 
16
- let selectedModelName = settings.activeModel;
17
 
18
  const dispatch = createEventDispatcher<{ close: void }>();
19
  </script>
@@ -40,7 +40,7 @@
40
  <div class="space-y-4">
41
  {#each models as model}
42
  <div
43
- class="rounded-xl border border-gray-100 {model.name === selectedModelName
44
  ? 'bg-gradient-to-r from-yellow-200/40 via-yellow-500/10'
45
  : ''}"
46
  >
@@ -49,8 +49,8 @@
49
  type="radio"
50
  class="sr-only"
51
  name="activeModel"
52
- value={model.name}
53
- bind:group={selectedModelName}
54
  />
55
  <span>
56
  <span class="text-md block font-semibold leading-tight text-gray-800"
@@ -61,7 +61,7 @@
61
  {/if}
62
  </span>
63
  <CarbonCheckmark
64
- class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {model.name === selectedModelName
65
  ? 'text-yellow-400'
66
  : 'text-transparent group-hover:text-gray-200'}"
67
  />
 
13
  export let settings: LayoutData["settings"];
14
  export let models: Array<Model>;
15
 
16
+ let selectedModelId = settings.activeModel;
17
 
18
  const dispatch = createEventDispatcher<{ close: void }>();
19
  </script>
 
40
  <div class="space-y-4">
41
  {#each models as model}
42
  <div
43
+ class="rounded-xl border border-gray-100 {model.id === selectedModelId
44
  ? 'bg-gradient-to-r from-yellow-200/40 via-yellow-500/10'
45
  : ''}"
46
  >
 
49
  type="radio"
50
  class="sr-only"
51
  name="activeModel"
52
+ value={model.id}
53
+ bind:group={selectedModelId}
54
  />
55
  <span>
56
  <span class="text-md block font-semibold leading-tight text-gray-800"
 
61
  {/if}
62
  </span>
63
  <CarbonCheckmark
64
+ class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {model.id === selectedModelId
65
  ? 'text-yellow-400'
66
  : 'text-transparent group-hover:text-gray-200'}"
67
  />
src/lib/server/modelEndpoint.ts CHANGED
@@ -1,34 +1,23 @@
1
  import { HF_ACCESS_TOKEN } from "$env/static/private";
2
  import { sum } from "$lib/utils/sum";
3
- import { models } from "./models";
4
 
5
  /**
6
  * Find a random load-balanced endpoint
7
  */
8
- export function modelEndpoint(model: string): {
9
  url: string;
10
  authorization: string;
11
  weight: number;
12
  } {
13
- const modelDefinition = models.find((m) => m.name === model);
14
- if (!modelDefinition) {
15
- throw new Error(`Invalid model: ${model}`);
16
- }
17
- if (typeof modelDefinition === "string") {
18
- return {
19
- url: `https://api-inference.huggingface.co/models/${modelDefinition}`,
20
- authorization: `Bearer ${HF_ACCESS_TOKEN}`,
21
- weight: 1,
22
- };
23
- }
24
- if (!modelDefinition.endpoints) {
25
  return {
26
- url: `https://api-inference.huggingface.co/models/${modelDefinition.name}`,
27
  authorization: `Bearer ${HF_ACCESS_TOKEN}`,
28
  weight: 1,
29
  };
30
  }
31
- const endpoints = modelDefinition.endpoints;
32
  const totalWeight = sum(endpoints.map((e) => e.weight));
33
 
34
  let random = Math.random() * totalWeight;
 
1
  import { HF_ACCESS_TOKEN } from "$env/static/private";
2
  import { sum } from "$lib/utils/sum";
3
+ import type { BackendModel } from "./models";
4
 
5
  /**
6
  * Find a random load-balanced endpoint
7
  */
8
+ export function modelEndpoint(model: BackendModel): {
9
  url: string;
10
  authorization: string;
11
  weight: number;
12
  } {
13
+ if (!model.endpoints) {
 
 
 
 
 
 
 
 
 
 
 
14
  return {
15
+ url: `https://api-inference.huggingface.co/models/${model.name}`,
16
  authorization: `Bearer ${HF_ACCESS_TOKEN}`,
17
  weight: 1,
18
  };
19
  }
20
+ const endpoints = model.endpoints;
21
  const totalWeight = sum(endpoints.map((e) => e.weight));
22
 
23
  let random = Math.random() * totalWeight;
src/lib/server/models.ts CHANGED
@@ -4,6 +4,9 @@ import { z } from "zod";
4
  const modelsRaw = z
5
  .array(
6
  z.object({
 
 
 
7
  name: z.string().min(1),
8
  displayName: z.string().min(1).optional(),
9
  description: z.string().min(1).optional(),
@@ -46,6 +49,7 @@ const modelsRaw = z
46
  export const models = await Promise.all(
47
  modelsRaw.map(async (m) => ({
48
  ...m,
 
49
  displayName: m.displayName || m.name,
50
  preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
51
  }))
 
4
  const modelsRaw = z
5
  .array(
6
  z.object({
7
+ /** Used as an identifier in DB */
8
+ id: z.string().optional(),
9
+ /** Used to link to the model page, and for inference */
10
  name: z.string().min(1),
11
  displayName: z.string().min(1).optional(),
12
  description: z.string().min(1).optional(),
 
49
  export const models = await Promise.all(
50
  modelsRaw.map(async (m) => ({
51
  ...m,
52
+ id: m.id || m.name,
53
  displayName: m.displayName || m.name,
54
  preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
55
  }))
src/lib/types/Model.ts CHANGED
@@ -2,6 +2,7 @@ import type { BackendModel } from "$lib/server/models";
2
 
3
  export type Model = Pick<
4
  BackendModel,
 
5
  | "name"
6
  | "displayName"
7
  | "websiteUrl"
 
2
 
3
  export type Model = Pick<
4
  BackendModel,
5
+ | "id"
6
  | "name"
7
  | "displayName"
8
  | "websiteUrl"
src/lib/utils/models.ts CHANGED
@@ -2,9 +2,9 @@ import type { Model } from "$lib/types/Model";
2
  import { z } from "zod";
3
 
4
  export const findCurrentModel = (models: Model[], name?: string) =>
5
- models.find((m) => m.name === name) ?? models[0];
6
 
7
  export const validateModel = (models: Model[]) => {
8
  // Zod enum function requires 2 parameters
9
- return z.enum([models[0].name, ...models.slice(1).map((m) => m.name)]);
10
  };
 
2
  import { z } from "zod";
3
 
4
  export const findCurrentModel = (models: Model[], name?: string) =>
5
+ models.find((m) => m.id === name) ?? models[0];
6
 
7
  export const validateModel = (models: Model[]) => {
8
  // Zod enum function requires 2 parameters
9
+ return z.enum([models[0].id, ...models.slice(1).map((m) => m.id)]);
10
  };
src/routes/+layout.server.ts CHANGED
@@ -50,9 +50,10 @@ export const load: LayoutServerLoad = async ({ locals, depends, url }) => {
50
  settings: {
51
  shareConversationsWithModelAuthors: settings?.shareConversationsWithModelAuthors ?? true,
52
  ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null,
53
- activeModel: settings?.activeModel ?? defaultModel.name,
54
  },
55
  models: models.map((model) => ({
 
56
  name: model.name,
57
  websiteUrl: model.websiteUrl,
58
  datasetName: model.datasetName,
 
50
  settings: {
51
  shareConversationsWithModelAuthors: settings?.shareConversationsWithModelAuthors ?? true,
52
  ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null,
53
+ activeModel: settings?.activeModel ?? defaultModel.id,
54
  },
55
  models: models.map((model) => ({
56
+ id: model.id,
57
  name: model.name,
58
  websiteUrl: model.websiteUrl,
59
  datasetName: model.datasetName,
src/routes/conversation/+server.ts CHANGED
@@ -5,7 +5,7 @@ import { error, redirect } from "@sveltejs/kit";
5
  import { base } from "$app/paths";
6
  import { z } from "zod";
7
  import type { Message } from "$lib/types/Message";
8
- import { defaultModel, models } from "$lib/server/models";
9
  import { validateModel } from "$lib/utils/models";
10
 
11
  export const POST: RequestHandler = async (input) => {
@@ -17,7 +17,7 @@ export const POST: RequestHandler = async (input) => {
17
  const values = z
18
  .object({
19
  fromShare: z.string().optional(),
20
- model: validateModel(models).default(defaultModel.name),
21
  })
22
  .parse(JSON.parse(body));
23
 
 
5
  import { base } from "$app/paths";
6
  import { z } from "zod";
7
  import type { Message } from "$lib/types/Message";
8
+ import { models } from "$lib/server/models";
9
  import { validateModel } from "$lib/utils/models";
10
 
11
  export const POST: RequestHandler = async (input) => {
 
17
  const values = z
18
  .object({
19
  fromShare: z.string().optional(),
20
+ model: validateModel(models),
21
  })
22
  .parse(JSON.parse(body));
23
 
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -36,7 +36,7 @@
36
  model: $page.url.href,
37
  inputs,
38
  parameters: {
39
- ...data.models.find((m) => m.name === data.model)?.parameters,
40
  return_full_text: false,
41
  },
42
  },
 
36
  model: $page.url.href,
37
  inputs,
38
  parameters: {
39
+ ...data.models.find((m) => m.id === data.model)?.parameters,
40
  return_full_text: false,
41
  },
42
  },
src/routes/conversation/[id]/+server.ts CHANGED
@@ -3,7 +3,7 @@ import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken.js";
3
  import { abortedGenerations } from "$lib/server/abortedGenerations.js";
4
  import { collections } from "$lib/server/database.js";
5
  import { modelEndpoint } from "$lib/server/modelEndpoint.js";
6
- import { defaultModel, models } from "$lib/server/models.js";
7
  import type { Message } from "$lib/types/Message.js";
8
  import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
9
  import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
@@ -28,7 +28,11 @@ export async function POST({ request, fetch, locals, params }) {
28
  throw error(404, "Conversation not found");
29
  }
30
 
31
- const model = conv.model ?? defaultModel.name;
 
 
 
 
32
 
33
  const json = await request.json();
34
  const {
@@ -61,20 +65,7 @@ export async function POST({ request, fetch, locals, params }) {
61
  ];
62
  })() satisfies Message[];
63
 
64
- // Todo: on-the-fly migration, remove later
65
- for (const message of messages) {
66
- if (!message.id) {
67
- message.id = crypto.randomUUID();
68
- }
69
- }
70
-
71
- const modelInfo = models.find((m) => m.name === model);
72
-
73
- if (!modelInfo) {
74
- throw error(400, "Model not availalbe anymore");
75
- }
76
-
77
- const prompt = buildPrompt(messages, modelInfo);
78
 
79
  const randomEndpoint = modelEndpoint(model);
80
 
@@ -112,7 +103,7 @@ export async function POST({ request, fetch, locals, params }) {
112
  PUBLIC_SEP_TOKEN
113
  ).trimEnd();
114
 
115
- for (const stop of [...(modelInfo?.parameters?.stop ?? []), "<|endoftext|>"]) {
116
  if (generated_text.endsWith(stop)) {
117
  generated_text = generated_text.slice(0, -stop.length).trimEnd();
118
  }
 
3
  import { abortedGenerations } from "$lib/server/abortedGenerations.js";
4
  import { collections } from "$lib/server/database.js";
5
  import { modelEndpoint } from "$lib/server/modelEndpoint.js";
6
+ import { models } from "$lib/server/models.js";
7
  import type { Message } from "$lib/types/Message.js";
8
  import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
9
  import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
 
28
  throw error(404, "Conversation not found");
29
  }
30
 
31
+ const model = models.find((m) => m.id === conv.model);
32
+
33
+ if (!model) {
34
+ throw error(400, "Model not availalbe anymore");
35
+ }
36
 
37
  const json = await request.json();
38
  const {
 
65
  ];
66
  })() satisfies Message[];
67
 
68
+ const prompt = buildPrompt(messages, model);
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  const randomEndpoint = modelEndpoint(model);
71
 
 
103
  PUBLIC_SEP_TOKEN
104
  ).trimEnd();
105
 
106
+ for (const stop of [...(model?.parameters?.stop ?? []), "<|endoftext|>"]) {
107
  if (generated_text.endsWith(stop)) {
108
  generated_text = generated_text.slice(0, -stop.length).trimEnd();
109
  }
src/routes/conversation/[id]/message/[messageId]/prompt/+server.ts CHANGED
@@ -24,7 +24,7 @@ export async function GET({ params, locals }) {
24
  throw error(404, "Message not found");
25
  }
26
 
27
- const model = models.find((m) => m.name === conv.model);
28
 
29
  if (!model) {
30
  throw error(404, "Conversation model not found");
@@ -37,7 +37,7 @@ export async function GET({ params, locals }) {
37
  {
38
  note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
39
  prompt,
40
- model: model.name,
41
  parameters: {
42
  ...model.parameters,
43
  return_full_text: false,
 
24
  throw error(404, "Message not found");
25
  }
26
 
27
+ const model = models.find((m) => m.id === conv.model);
28
 
29
  if (!model) {
30
  throw error(404, "Conversation model not found");
 
37
  {
38
  note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
39
  prompt,
40
+ model: model.id,
41
  parameters: {
42
  ...model.parameters,
43
  return_full_text: false,
src/routes/conversation/[id]/summarize/+server.ts CHANGED
@@ -34,7 +34,7 @@ export async function POST({ params, locals, fetch }) {
34
  return_full_text: false,
35
  };
36
 
37
- const endpoint = modelEndpoint(defaultModel.name);
38
  let { generated_text } = await textGeneration(
39
  {
40
  model: endpoint.url,
 
34
  return_full_text: false,
35
  };
36
 
37
+ const endpoint = modelEndpoint(defaultModel);
38
  let { generated_text } = await textGeneration(
39
  {
40
  model: endpoint.url,
src/routes/r/[id]/message/[messageId]/prompt/+server.ts CHANGED
@@ -20,7 +20,7 @@ export async function GET({ params }) {
20
  throw error(404, "Message not found");
21
  }
22
 
23
- const model = models.find((m) => m.name === conv.model);
24
 
25
  if (!model) {
26
  throw error(404, "Conversation model not found");
@@ -33,7 +33,7 @@ export async function GET({ params }) {
33
  {
34
  note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
35
  prompt,
36
- model: model.name,
37
  parameters: {
38
  ...model.parameters,
39
  return_full_text: false,
 
20
  throw error(404, "Message not found");
21
  }
22
 
23
+ const model = models.find((m) => m.id === conv.model);
24
 
25
  if (!model) {
26
  throw error(404, "Conversation model not found");
 
33
  {
34
  note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
35
  prompt,
36
+ model: model.id,
37
  parameters: {
38
  ...model.parameters,
39
  return_full_text: false,
src/routes/settings/+page.server.ts CHANGED
@@ -18,7 +18,7 @@ export const actions = {
18
  .parse({
19
  shareConversationsWithModelAuthors: formData.get("shareConversationsWithModelAuthors"),
20
  ethicsModalAccepted: formData.get("ethicsModalAccepted"),
21
- activeModel: formData.get("activeModel") ?? defaultModel.name,
22
  });
23
 
24
  await collections.settings.updateOne(
 
18
  .parse({
19
  shareConversationsWithModelAuthors: formData.get("shareConversationsWithModelAuthors"),
20
  ethicsModalAccepted: formData.get("ethicsModalAccepted"),
21
+ activeModel: formData.get("activeModel") ?? defaultModel.id,
22
  });
23
 
24
  await collections.settings.updateOne(