Spaces:
Running
Running
✨ Add model id to be able to switch models while keeping conversations valid (#181)
Browse files- src/lib/components/ModelsModal.svelte +5 -5
- src/lib/server/modelEndpoint.ts +5 -16
- src/lib/server/models.ts +4 -0
- src/lib/types/Model.ts +1 -0
- src/lib/utils/models.ts +2 -2
- src/routes/+layout.server.ts +2 -1
- src/routes/conversation/+server.ts +2 -2
- src/routes/conversation/[id]/+page.svelte +1 -1
- src/routes/conversation/[id]/+server.ts +8 -17
- src/routes/conversation/[id]/message/[messageId]/prompt/+server.ts +2 -2
- src/routes/conversation/[id]/summarize/+server.ts +1 -1
- src/routes/r/[id]/message/[messageId]/prompt/+server.ts +2 -2
- src/routes/settings/+page.server.ts +1 -1
src/lib/components/ModelsModal.svelte
CHANGED
@@ -13,7 +13,7 @@
|
|
13 |
export let settings: LayoutData["settings"];
|
14 |
export let models: Array<Model>;
|
15 |
|
16 |
-
let
|
17 |
|
18 |
const dispatch = createEventDispatcher<{ close: void }>();
|
19 |
</script>
|
@@ -40,7 +40,7 @@
|
|
40 |
<div class="space-y-4">
|
41 |
{#each models as model}
|
42 |
<div
|
43 |
-
class="rounded-xl border border-gray-100 {model.
|
44 |
? 'bg-gradient-to-r from-yellow-200/40 via-yellow-500/10'
|
45 |
: ''}"
|
46 |
>
|
@@ -49,8 +49,8 @@
|
|
49 |
type="radio"
|
50 |
class="sr-only"
|
51 |
name="activeModel"
|
52 |
-
value={model.
|
53 |
-
bind:group={
|
54 |
/>
|
55 |
<span>
|
56 |
<span class="text-md block font-semibold leading-tight text-gray-800"
|
@@ -61,7 +61,7 @@
|
|
61 |
{/if}
|
62 |
</span>
|
63 |
<CarbonCheckmark
|
64 |
-
class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {model.
|
65 |
? 'text-yellow-400'
|
66 |
: 'text-transparent group-hover:text-gray-200'}"
|
67 |
/>
|
|
|
13 |
export let settings: LayoutData["settings"];
|
14 |
export let models: Array<Model>;
|
15 |
|
16 |
+
let selectedModelId = settings.activeModel;
|
17 |
|
18 |
const dispatch = createEventDispatcher<{ close: void }>();
|
19 |
</script>
|
|
|
40 |
<div class="space-y-4">
|
41 |
{#each models as model}
|
42 |
<div
|
43 |
+
class="rounded-xl border border-gray-100 {model.id === selectedModelId
|
44 |
? 'bg-gradient-to-r from-yellow-200/40 via-yellow-500/10'
|
45 |
: ''}"
|
46 |
>
|
|
|
49 |
type="radio"
|
50 |
class="sr-only"
|
51 |
name="activeModel"
|
52 |
+
value={model.id}
|
53 |
+
bind:group={selectedModelId}
|
54 |
/>
|
55 |
<span>
|
56 |
<span class="text-md block font-semibold leading-tight text-gray-800"
|
|
|
61 |
{/if}
|
62 |
</span>
|
63 |
<CarbonCheckmark
|
64 |
+
class="-mr-1 -mt-1 ml-auto shrink-0 text-xl {model.id === selectedModelId
|
65 |
? 'text-yellow-400'
|
66 |
: 'text-transparent group-hover:text-gray-200'}"
|
67 |
/>
|
src/lib/server/modelEndpoint.ts
CHANGED
@@ -1,34 +1,23 @@
|
|
1 |
import { HF_ACCESS_TOKEN } from "$env/static/private";
|
2 |
import { sum } from "$lib/utils/sum";
|
3 |
-
import {
|
4 |
|
5 |
/**
|
6 |
* Find a random load-balanced endpoint
|
7 |
*/
|
8 |
-
export function modelEndpoint(model:
|
9 |
url: string;
|
10 |
authorization: string;
|
11 |
weight: number;
|
12 |
} {
|
13 |
-
|
14 |
-
if (!modelDefinition) {
|
15 |
-
throw new Error(`Invalid model: ${model}`);
|
16 |
-
}
|
17 |
-
if (typeof modelDefinition === "string") {
|
18 |
-
return {
|
19 |
-
url: `https://api-inference.huggingface.co/models/${modelDefinition}`,
|
20 |
-
authorization: `Bearer ${HF_ACCESS_TOKEN}`,
|
21 |
-
weight: 1,
|
22 |
-
};
|
23 |
-
}
|
24 |
-
if (!modelDefinition.endpoints) {
|
25 |
return {
|
26 |
-
url: `https://api-inference.huggingface.co/models/${
|
27 |
authorization: `Bearer ${HF_ACCESS_TOKEN}`,
|
28 |
weight: 1,
|
29 |
};
|
30 |
}
|
31 |
-
const endpoints =
|
32 |
const totalWeight = sum(endpoints.map((e) => e.weight));
|
33 |
|
34 |
let random = Math.random() * totalWeight;
|
|
|
1 |
import { HF_ACCESS_TOKEN } from "$env/static/private";
|
2 |
import { sum } from "$lib/utils/sum";
|
3 |
+
import type { BackendModel } from "./models";
|
4 |
|
5 |
/**
|
6 |
* Find a random load-balanced endpoint
|
7 |
*/
|
8 |
+
export function modelEndpoint(model: BackendModel): {
|
9 |
url: string;
|
10 |
authorization: string;
|
11 |
weight: number;
|
12 |
} {
|
13 |
+
if (!model.endpoints) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
return {
|
15 |
+
url: `https://api-inference.huggingface.co/models/${model.name}`,
|
16 |
authorization: `Bearer ${HF_ACCESS_TOKEN}`,
|
17 |
weight: 1,
|
18 |
};
|
19 |
}
|
20 |
+
const endpoints = model.endpoints;
|
21 |
const totalWeight = sum(endpoints.map((e) => e.weight));
|
22 |
|
23 |
let random = Math.random() * totalWeight;
|
src/lib/server/models.ts
CHANGED
@@ -4,6 +4,9 @@ import { z } from "zod";
|
|
4 |
const modelsRaw = z
|
5 |
.array(
|
6 |
z.object({
|
|
|
|
|
|
|
7 |
name: z.string().min(1),
|
8 |
displayName: z.string().min(1).optional(),
|
9 |
description: z.string().min(1).optional(),
|
@@ -46,6 +49,7 @@ const modelsRaw = z
|
|
46 |
export const models = await Promise.all(
|
47 |
modelsRaw.map(async (m) => ({
|
48 |
...m,
|
|
|
49 |
displayName: m.displayName || m.name,
|
50 |
preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
|
51 |
}))
|
|
|
4 |
const modelsRaw = z
|
5 |
.array(
|
6 |
z.object({
|
7 |
+
/** Used as an identifier in DB */
|
8 |
+
id: z.string().optional(),
|
9 |
+
/** Used to link to the model page, and for inference */
|
10 |
name: z.string().min(1),
|
11 |
displayName: z.string().min(1).optional(),
|
12 |
description: z.string().min(1).optional(),
|
|
|
49 |
export const models = await Promise.all(
|
50 |
modelsRaw.map(async (m) => ({
|
51 |
...m,
|
52 |
+
id: m.id || m.name,
|
53 |
displayName: m.displayName || m.name,
|
54 |
preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
|
55 |
}))
|
src/lib/types/Model.ts
CHANGED
@@ -2,6 +2,7 @@ import type { BackendModel } from "$lib/server/models";
|
|
2 |
|
3 |
export type Model = Pick<
|
4 |
BackendModel,
|
|
|
5 |
| "name"
|
6 |
| "displayName"
|
7 |
| "websiteUrl"
|
|
|
2 |
|
3 |
export type Model = Pick<
|
4 |
BackendModel,
|
5 |
+
| "id"
|
6 |
| "name"
|
7 |
| "displayName"
|
8 |
| "websiteUrl"
|
src/lib/utils/models.ts
CHANGED
@@ -2,9 +2,9 @@ import type { Model } from "$lib/types/Model";
|
|
2 |
import { z } from "zod";
|
3 |
|
4 |
export const findCurrentModel = (models: Model[], name?: string) =>
|
5 |
-
models.find((m) => m.
|
6 |
|
7 |
export const validateModel = (models: Model[]) => {
|
8 |
// Zod enum function requires 2 parameters
|
9 |
-
return z.enum([models[0].
|
10 |
};
|
|
|
2 |
import { z } from "zod";
|
3 |
|
4 |
export const findCurrentModel = (models: Model[], name?: string) =>
|
5 |
+
models.find((m) => m.id === name) ?? models[0];
|
6 |
|
7 |
export const validateModel = (models: Model[]) => {
|
8 |
// Zod enum function requires 2 parameters
|
9 |
+
return z.enum([models[0].id, ...models.slice(1).map((m) => m.id)]);
|
10 |
};
|
src/routes/+layout.server.ts
CHANGED
@@ -50,9 +50,10 @@ export const load: LayoutServerLoad = async ({ locals, depends, url }) => {
|
|
50 |
settings: {
|
51 |
shareConversationsWithModelAuthors: settings?.shareConversationsWithModelAuthors ?? true,
|
52 |
ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null,
|
53 |
-
activeModel: settings?.activeModel ?? defaultModel.
|
54 |
},
|
55 |
models: models.map((model) => ({
|
|
|
56 |
name: model.name,
|
57 |
websiteUrl: model.websiteUrl,
|
58 |
datasetName: model.datasetName,
|
|
|
50 |
settings: {
|
51 |
shareConversationsWithModelAuthors: settings?.shareConversationsWithModelAuthors ?? true,
|
52 |
ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null,
|
53 |
+
activeModel: settings?.activeModel ?? defaultModel.id,
|
54 |
},
|
55 |
models: models.map((model) => ({
|
56 |
+
id: model.id,
|
57 |
name: model.name,
|
58 |
websiteUrl: model.websiteUrl,
|
59 |
datasetName: model.datasetName,
|
src/routes/conversation/+server.ts
CHANGED
@@ -5,7 +5,7 @@ import { error, redirect } from "@sveltejs/kit";
|
|
5 |
import { base } from "$app/paths";
|
6 |
import { z } from "zod";
|
7 |
import type { Message } from "$lib/types/Message";
|
8 |
-
import {
|
9 |
import { validateModel } from "$lib/utils/models";
|
10 |
|
11 |
export const POST: RequestHandler = async (input) => {
|
@@ -17,7 +17,7 @@ export const POST: RequestHandler = async (input) => {
|
|
17 |
const values = z
|
18 |
.object({
|
19 |
fromShare: z.string().optional(),
|
20 |
-
model: validateModel(models)
|
21 |
})
|
22 |
.parse(JSON.parse(body));
|
23 |
|
|
|
5 |
import { base } from "$app/paths";
|
6 |
import { z } from "zod";
|
7 |
import type { Message } from "$lib/types/Message";
|
8 |
+
import { models } from "$lib/server/models";
|
9 |
import { validateModel } from "$lib/utils/models";
|
10 |
|
11 |
export const POST: RequestHandler = async (input) => {
|
|
|
17 |
const values = z
|
18 |
.object({
|
19 |
fromShare: z.string().optional(),
|
20 |
+
model: validateModel(models),
|
21 |
})
|
22 |
.parse(JSON.parse(body));
|
23 |
|
src/routes/conversation/[id]/+page.svelte
CHANGED
@@ -36,7 +36,7 @@
|
|
36 |
model: $page.url.href,
|
37 |
inputs,
|
38 |
parameters: {
|
39 |
-
...data.models.find((m) => m.
|
40 |
return_full_text: false,
|
41 |
},
|
42 |
},
|
|
|
36 |
model: $page.url.href,
|
37 |
inputs,
|
38 |
parameters: {
|
39 |
+
...data.models.find((m) => m.id === data.model)?.parameters,
|
40 |
return_full_text: false,
|
41 |
},
|
42 |
},
|
src/routes/conversation/[id]/+server.ts
CHANGED
@@ -3,7 +3,7 @@ import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken.js";
|
|
3 |
import { abortedGenerations } from "$lib/server/abortedGenerations.js";
|
4 |
import { collections } from "$lib/server/database.js";
|
5 |
import { modelEndpoint } from "$lib/server/modelEndpoint.js";
|
6 |
-
import {
|
7 |
import type { Message } from "$lib/types/Message.js";
|
8 |
import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
|
9 |
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
|
@@ -28,7 +28,11 @@ export async function POST({ request, fetch, locals, params }) {
|
|
28 |
throw error(404, "Conversation not found");
|
29 |
}
|
30 |
|
31 |
-
const model =
|
|
|
|
|
|
|
|
|
32 |
|
33 |
const json = await request.json();
|
34 |
const {
|
@@ -61,20 +65,7 @@ export async function POST({ request, fetch, locals, params }) {
|
|
61 |
];
|
62 |
})() satisfies Message[];
|
63 |
|
64 |
-
|
65 |
-
for (const message of messages) {
|
66 |
-
if (!message.id) {
|
67 |
-
message.id = crypto.randomUUID();
|
68 |
-
}
|
69 |
-
}
|
70 |
-
|
71 |
-
const modelInfo = models.find((m) => m.name === model);
|
72 |
-
|
73 |
-
if (!modelInfo) {
|
74 |
-
throw error(400, "Model not availalbe anymore");
|
75 |
-
}
|
76 |
-
|
77 |
-
const prompt = buildPrompt(messages, modelInfo);
|
78 |
|
79 |
const randomEndpoint = modelEndpoint(model);
|
80 |
|
@@ -112,7 +103,7 @@ export async function POST({ request, fetch, locals, params }) {
|
|
112 |
PUBLIC_SEP_TOKEN
|
113 |
).trimEnd();
|
114 |
|
115 |
-
for (const stop of [...(
|
116 |
if (generated_text.endsWith(stop)) {
|
117 |
generated_text = generated_text.slice(0, -stop.length).trimEnd();
|
118 |
}
|
|
|
3 |
import { abortedGenerations } from "$lib/server/abortedGenerations.js";
|
4 |
import { collections } from "$lib/server/database.js";
|
5 |
import { modelEndpoint } from "$lib/server/modelEndpoint.js";
|
6 |
+
import { models } from "$lib/server/models.js";
|
7 |
import type { Message } from "$lib/types/Message.js";
|
8 |
import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
|
9 |
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
|
|
|
28 |
throw error(404, "Conversation not found");
|
29 |
}
|
30 |
|
31 |
+
const model = models.find((m) => m.id === conv.model);
|
32 |
+
|
33 |
+
if (!model) {
|
34 |
+
throw error(400, "Model not availalbe anymore");
|
35 |
+
}
|
36 |
|
37 |
const json = await request.json();
|
38 |
const {
|
|
|
65 |
];
|
66 |
})() satisfies Message[];
|
67 |
|
68 |
+
const prompt = buildPrompt(messages, model);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
const randomEndpoint = modelEndpoint(model);
|
71 |
|
|
|
103 |
PUBLIC_SEP_TOKEN
|
104 |
).trimEnd();
|
105 |
|
106 |
+
for (const stop of [...(model?.parameters?.stop ?? []), "<|endoftext|>"]) {
|
107 |
if (generated_text.endsWith(stop)) {
|
108 |
generated_text = generated_text.slice(0, -stop.length).trimEnd();
|
109 |
}
|
src/routes/conversation/[id]/message/[messageId]/prompt/+server.ts
CHANGED
@@ -24,7 +24,7 @@ export async function GET({ params, locals }) {
|
|
24 |
throw error(404, "Message not found");
|
25 |
}
|
26 |
|
27 |
-
const model = models.find((m) => m.
|
28 |
|
29 |
if (!model) {
|
30 |
throw error(404, "Conversation model not found");
|
@@ -37,7 +37,7 @@ export async function GET({ params, locals }) {
|
|
37 |
{
|
38 |
note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
|
39 |
prompt,
|
40 |
-
model: model.
|
41 |
parameters: {
|
42 |
...model.parameters,
|
43 |
return_full_text: false,
|
|
|
24 |
throw error(404, "Message not found");
|
25 |
}
|
26 |
|
27 |
+
const model = models.find((m) => m.id === conv.model);
|
28 |
|
29 |
if (!model) {
|
30 |
throw error(404, "Conversation model not found");
|
|
|
37 |
{
|
38 |
note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
|
39 |
prompt,
|
40 |
+
model: model.id,
|
41 |
parameters: {
|
42 |
...model.parameters,
|
43 |
return_full_text: false,
|
src/routes/conversation/[id]/summarize/+server.ts
CHANGED
@@ -34,7 +34,7 @@ export async function POST({ params, locals, fetch }) {
|
|
34 |
return_full_text: false,
|
35 |
};
|
36 |
|
37 |
-
const endpoint = modelEndpoint(defaultModel
|
38 |
let { generated_text } = await textGeneration(
|
39 |
{
|
40 |
model: endpoint.url,
|
|
|
34 |
return_full_text: false,
|
35 |
};
|
36 |
|
37 |
+
const endpoint = modelEndpoint(defaultModel);
|
38 |
let { generated_text } = await textGeneration(
|
39 |
{
|
40 |
model: endpoint.url,
|
src/routes/r/[id]/message/[messageId]/prompt/+server.ts
CHANGED
@@ -20,7 +20,7 @@ export async function GET({ params }) {
|
|
20 |
throw error(404, "Message not found");
|
21 |
}
|
22 |
|
23 |
-
const model = models.find((m) => m.
|
24 |
|
25 |
if (!model) {
|
26 |
throw error(404, "Conversation model not found");
|
@@ -33,7 +33,7 @@ export async function GET({ params }) {
|
|
33 |
{
|
34 |
note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
|
35 |
prompt,
|
36 |
-
model: model.
|
37 |
parameters: {
|
38 |
...model.parameters,
|
39 |
return_full_text: false,
|
|
|
20 |
throw error(404, "Message not found");
|
21 |
}
|
22 |
|
23 |
+
const model = models.find((m) => m.id === conv.model);
|
24 |
|
25 |
if (!model) {
|
26 |
throw error(404, "Conversation model not found");
|
|
|
33 |
{
|
34 |
note: "This is a preview of the prompt that will be sent to the model when retrying the message. It may differ from what was sent in the past if the parameters have been updated since",
|
35 |
prompt,
|
36 |
+
model: model.id,
|
37 |
parameters: {
|
38 |
...model.parameters,
|
39 |
return_full_text: false,
|
src/routes/settings/+page.server.ts
CHANGED
@@ -18,7 +18,7 @@ export const actions = {
|
|
18 |
.parse({
|
19 |
shareConversationsWithModelAuthors: formData.get("shareConversationsWithModelAuthors"),
|
20 |
ethicsModalAccepted: formData.get("ethicsModalAccepted"),
|
21 |
-
activeModel: formData.get("activeModel") ?? defaultModel.
|
22 |
});
|
23 |
|
24 |
await collections.settings.updateOne(
|
|
|
18 |
.parse({
|
19 |
shareConversationsWithModelAuthors: formData.get("shareConversationsWithModelAuthors"),
|
20 |
ethicsModalAccepted: formData.get("ethicsModalAccepted"),
|
21 |
+
activeModel: formData.get("activeModel") ?? defaultModel.id,
|
22 |
});
|
23 |
|
24 |
await collections.settings.updateOne(
|