Continue generation feature (#707)
* Initial work on continue feature
* Move continue button
* Fix websearch with continue
* Make it work with every model
* Update src/routes/conversation/[id]/+server.ts
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
* fixes
* async all the things
* add reduce comment
* remove log
* Only show loading indicator if not continuing
---------
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
- .env.template +4 -2
- src/lib/buildPrompt.ts +18 -16
- src/lib/components/ContinueBtn.svelte +13 -0
- src/lib/components/chat/ChatMessage.svelte +1 -0
- src/lib/components/chat/ChatMessages.svelte +2 -1
- src/lib/components/chat/ChatWindow.svelte +17 -2
- src/lib/server/endpoints/endpoints.ts +1 -0
- src/lib/server/endpoints/tgi/endpointTgi.ts +13 -2
- src/lib/types/Message.ts +1 -0
- src/routes/conversation/[id]/+page.svelte +76 -25
- src/routes/conversation/[id]/+server.ts +65 -35
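Taken together, the changes below make the client send a "continue" request for the last assistant message instead of a new prompt. A minimal sketch of that request, assuming only the field names visible in the +page.svelte and +server.ts diffs further down (the fetch URL and helper function are illustrative, not part of the PR):

// Sketch of the continue request (assumed shape; field names from the diffs below).
type ContinueRequest = {
	inputs?: string; // omitted for a continue, required text for a new message
	id?: string; // for a continue this must be the id of the last message
	is_retry?: boolean;
	is_continue?: boolean;
	web_search?: boolean;
	files?: string[];
};

// Continue the last (interrupted) assistant message: no new inputs, only its id.
// URL shape assumed from the route path src/routes/conversation/[id]/+server.ts.
async function continueLastMessage(conversationId: string, lastMessageId: string) {
	const body: ContinueRequest = { id: lastMessageId, is_continue: true };
	return fetch(`/conversation/${conversationId}`, {
		method: "POST",
		headers: { "Content-Type": "application/json" },
		body: JSON.stringify(body),
	});
}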
.env.template
CHANGED
@@ -57,7 +57,8 @@ MODELS=`[
        "repetition_penalty": 1.2,
        "top_k": 50,
        "truncate": 3072,
-       "max_new_tokens": 1024
+       "max_new_tokens": 1024,
+       "stop" : ["</s>", " </s><s>[INST] "]
      }
    },
    {
@@ -116,7 +117,8 @@ MODELS=`[
        "repetition_penalty": 1.2,
        "top_k": 50,
        "truncate": 4096,
-       "max_new_tokens": 4096
+       "max_new_tokens": 4096,
+       "stop": [" </s><s>[INST] "]
      }
    },
    {
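The stop sequences declared here do double duty: TGI stops generating when one is emitted, and the continue path in endpointTgi.ts (below) strips a trailing stop sequence from the rendered prompt before resuming. A sketch of the resulting parameters object, limited to the fields visible in this diff (values copied from the first model entry):

// Sketch: the "stop" list travels with the rest of the sampling parameters.
const parameters = {
	repetition_penalty: 1.2,
	top_k: 50,
	truncate: 3072,
	max_new_tokens: 1024,
	stop: ["</s>", " </s><s>[INST] "],
};
// endpointTgi later spreads these into the TGI request:
// parameters: { ...model.parameters, return_full_text: false }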
src/lib/buildPrompt.ts
CHANGED
@@ -13,6 +13,7 @@ interface buildPromptOptions {
 	webSearch?: WebSearch;
 	preprompt?: string;
 	files?: File[];
+	continue?: boolean;
 }
 
 export async function buildPrompt({
@@ -22,37 +23,38 @@ export async function buildPrompt({
 	preprompt,
 	id,
 }: buildPromptOptions): Promise<string> {
+	let modifiedMessages = [...messages];
+
 	if (webSearch && webSearch.context) {
-		const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
+		// find index of the last user message
+		const lastUsrMsgIndex = modifiedMessages.map((el) => el.from).lastIndexOf("user");
 
+		// combine all the other previous questions into one string
+		const previousUserMessages = modifiedMessages.filter((el) => el.from === "user").slice(0, -1);
 		const previousQuestions =
 			previousUserMessages.length > 0
 				? `Previous questions: \n${previousUserMessages
						.map(({ content }) => `- ${content}`)
						.join("\n")}`
 				: "";
+
 		const currentDate = format(new Date(), "MMMM d, yyyy");
+
+		// update the last user message directly (that way if the last message is an assistant partial answer, we keep the beginning of that answer)
+		modifiedMessages[lastUsrMsgIndex] = {
+			from: "user",
+			content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
 =====================
 ${webSearch.context}
 =====================
 ${previousQuestions}
-Answer the question: ${
-			},
-		];
+Answer the question: ${messages[lastUsrMsgIndex].content} `,
+		};
 	}
-
 	// section to handle potential files input
 	if (model.multimodal) {
-		messages = await Promise.all(
-			messages.map(async (el) => {
+		modifiedMessages = await Promise.all(
+			modifiedMessages.map(async (el) => {
 				let content = el.content;
 
 				if (el.from === "user") {
@@ -83,7 +85,7 @@ export async function buildPrompt({
 
 	return (
 		model
-			.chatPromptRender({ messages, preprompt })
+			.chatPromptRender({ messages: modifiedMessages, preprompt })
 			// Not super precise, but it's truncated in the model's backend anyway
 			.split(" ")
 			.slice(-(model.parameters?.truncate ?? 0))
src/lib/components/ContinueBtn.svelte
ADDED
@@ -0,0 +1,13 @@
+<script lang="ts">
+	import CarbonContinue from "~icons/carbon/continue";
+
+	export let classNames = "";
+</script>
+
+<button
+	type="button"
+	on:click
+	class="btn flex h-8 rounded-lg border bg-white px-3 py-1 text-gray-500 shadow-sm transition-all hover:bg-gray-100 dark:border-gray-600 dark:bg-gray-700 dark:text-gray-300 dark:hover:bg-gray-600 {classNames}"
+>
+	<CarbonContinue class="mr-2 text-xs " /> Continue
+</button>
src/lib/components/chat/ChatMessage.svelte
CHANGED
@@ -13,6 +13,7 @@
 	import CarbonDownload from "~icons/carbon/download";
 	import CarbonThumbsUp from "~icons/carbon/thumbs-up";
 	import CarbonThumbsDown from "~icons/carbon/thumbs-down";
+
 	import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
 	import type { Model } from "$lib/types/Model";
 
src/lib/components/chat/ChatMessages.svelte
CHANGED
@@ -54,11 +54,12 @@
 				webSearchMessages={i === messages.length - 1 ? webSearchMessages : []}
 				on:retry
 				on:vote
+				on:continue
 			/>
 		{:else}
 			<ChatIntroduction {models} {currentModel} on:message />
 		{/each}
-	{#if pending}
+	{#if pending && messages[messages.length - 1]?.from === "user"}
 		<ChatMessage
 			message={{ from: "assistant", content: "", id: randomUUID() }}
 			model={currentModel}
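This matches the "Only show loading indicator if not continuing" commit above: while continuing, the last message is already an assistant message being extended, so the empty placeholder bubble would be redundant. A one-function sketch of the condition (simplified types, not the component itself):

// The pending assistant bubble is only shown when the latest message is from the user.
function showPendingBubble(pending: boolean, lastFrom: "user" | "assistant" | undefined): boolean {
	return pending && lastFrom === "user";
}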
src/lib/components/chat/ChatWindow.svelte
CHANGED
@@ -24,6 +24,7 @@
 	import UploadBtn from "../UploadBtn.svelte";
 	import file2base64 from "$lib/utils/file2base64";
 	import { useSettingsStore } from "$lib/stores/settings";
+	import ContinueBtn from "../ContinueBtn.svelte";
 
 	export let messages: Message[] = [];
 	export let loading = false;
@@ -48,6 +49,7 @@
 		share: void;
 		stop: void;
 		retry: { id: Message["id"]; content: string };
+		continue: { id: Message["id"] };
 	}>();
 
 	const handleSubmit = () => {
@@ -124,6 +126,7 @@
 						}
 					}}
 					on:vote
+					on:continue
 					on:retry={(ev) => {
 						if (!loading) dispatch("retry", ev.detail);
 					}}
@@ -173,8 +176,20 @@
 						content: messages[messages.length - 1].content,
 					})}
 			/>
-		{:else if currentModel.multimodal}
-			<UploadBtn bind:files classNames="ml-auto" />
+		{:else}
+			<div class="ml-auto gap-2">
+				{#if currentModel.multimodal}
+					<UploadBtn bind:files classNames="ml-auto" />
+				{/if}
+				{#if messages && messages[messages.length - 1]?.interrupted && !isReadOnly}
+					<ContinueBtn
+						on:click={() =>
+							dispatch("continue", {
+								id: messages[messages.length - 1].id,
+							})}
+					/>
+				{/if}
+			</div>
 		{/if}
 	</div>
 	<form
src/lib/server/endpoints/endpoints.ts
CHANGED
@@ -14,6 +14,7 @@ interface EndpointParameters {
 		preprompt?: Conversation["preprompt"];
 		_id?: Conversation["_id"];
 	};
+	continue?: boolean;
 }
 
 interface CommonEndpoint {
src/lib/server/endpoints/tgi/endpointTgi.ts
CHANGED
@@ -15,8 +15,9 @@ export const endpointTgiParametersSchema = z.object({
 
 export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
 	const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
-	return async ({ conversation }) => {
-		const prompt = await buildPrompt({
+
+	return async ({ conversation, continue: messageContinue }) => {
+		let prompt = await buildPrompt({
 			messages: conversation.messages,
 			webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
 			preprompt: conversation.preprompt,
@@ -24,6 +25,16 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
 			id: conversation._id,
 		});
 
+		if (messageContinue) {
+			// start with the full prompt, and for each stop token, try to remove it from the end of the prompt
+			prompt = model.parameters.stop.reduce((acc: string, curr: string) => {
+				if (acc.endsWith(curr)) {
+					return acc.slice(0, acc.length - curr.length);
+				}
+				return acc;
+			}, prompt.trimEnd());
+		}
+
 		return textGenerationStream(
 			{
 				parameters: { ...model.parameters, return_full_text: false },
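To make the trimming above concrete, here is a small standalone sketch. The stop strings come from the .env.template diff; the prompt string is made up for illustration:

// Standalone sketch of the stop-token trimming performed when continuing.
const stop = ["</s>", " </s><s>[INST] "];

function trimStopSuffix(prompt: string): string {
	// start with the full prompt, and for each stop token, try to remove it from the end
	return stop.reduce((acc, curr) => {
		if (acc.endsWith(curr)) {
			return acc.slice(0, acc.length - curr.length);
		}
		return acc;
	}, prompt.trimEnd());
}

// "...partial answer</s>" becomes "...partial answer", so the model resumes the
// open assistant turn instead of seeing a closed one.
console.log(trimStopSuffix("[INST] question [/INST] partial answer</s>"));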
src/lib/types/Message.ts
CHANGED
@@ -11,4 +11,5 @@ export type Message = Partial<Timestamps> & {
 	webSearch?: WebSearch;
 	score?: -1 | 0 | 1;
 	files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
+	interrupted?: boolean;
 };
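The new interrupted flag is what ties the server and the UI together: +server.ts (below) sets it from the last streamed token, and ChatWindow.svelte only renders the Continue button when the last message carries it. A condensed, self-contained sketch of both sides, with simplified stand-in types (the expressions themselves come from the diffs):

// Server side: a message is interrupted when generation did not end on a special
// (end-of-sequence) token.
type Msg = { id: string; from: "user" | "assistant"; content: string; interrupted?: boolean };

function markInterrupted(msg: Msg, lastTokenSpecial: boolean): Msg {
	return { ...msg, interrupted: !lastTokenSpecial };
}

// Client side: the Continue button is only shown for an interrupted last message
// in a conversation the user can write to.
function showContinueButton(messages: Msg[], isReadOnly: boolean): boolean {
	return Boolean(messages[messages.length - 1]?.interrupted) && !isReadOnly;
}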
src/routes/conversation/[id]/+page.svelte
CHANGED
@@ -64,9 +64,17 @@
 		}
 	}
 	// this function is used to send new message to the backends
-	async function writeMessage(
+	async function writeMessage({
+		prompt,
+		messageId = randomUUID(),
+		isRetry = false,
+		isContinue = false,
+	}: {
+		prompt?: string;
+		messageId?: ReturnType<typeof randomUUID>;
+		isRetry?: boolean;
+		isContinue?: boolean;
+	}): Promise<void> {
 		try {
 			$isAborted = false;
 			loading = true;
@@ -74,13 +82,21 @@
 
 			// first we check if the messageId already exists, indicating a retry
 
-			let
+			let msgIndex = messages.findIndex((msg) => msg.id === messageId);
+
+			if (msgIndex === -1) {
+				msgIndex = messages.length - 1;
+			}
+			if (isRetry && messages[msgIndex].from === "assistant") {
+				throw new Error("Trying to retry a message that is not from user");
+			}
+
+			if (isContinue && messages[msgIndex].from === "user") {
+				throw new Error("Trying to continue a message that is not from assistant");
 			}
 
+			// const isNewMessage = !isRetry && !isContinue;
+
 			const module = await import("browser-image-resizer");
 
 			// currently, only IDEFICS is supported by TGI
@@ -99,15 +115,31 @@
 			);
 
 			// slice up to the point of the retry
+			if (isRetry) {
+				messages = [
+					...messages.slice(0, msgIndex),
+					{
+						from: "user",
+						content: messages[msgIndex].content,
+						id: messageId,
+						files: messages[msgIndex].files,
+					},
+				];
+			} else if (!isContinue) {
+				// or add a new message if its not a continue request
+				if (!prompt) {
+					throw new Error("Prompt is undefined");
+				}
+				messages = [
+					...messages,
+					{
+						from: "user",
+						content: prompt ?? "",
+						id: messageId,
+						files: resizedImages,
+					},
+				];
+			}
 
 			files = [];
 
@@ -115,9 +147,10 @@
 				method: "POST",
 				headers: { "Content-Type": "application/json" },
 				body: JSON.stringify({
-					inputs:
+					inputs: prompt,
 					id: messageId,
 					is_retry: isRetry,
+					is_continue: isContinue,
 					web_search: $webSearchParameters.useSearch,
 					files: isRetry ? undefined : resizedImages,
 				}),
@@ -282,37 +315,54 @@
 		// only used in case of creating new conversations (from the parent POST endpoint)
 		if ($pendingMessage) {
 			files = $pendingMessage.files;
-			await writeMessage($pendingMessage.content);
+			await writeMessage({ prompt: $pendingMessage.content });
 			$pendingMessage = undefined;
 		}
 	});
 
 	async function onMessage(event: CustomEvent<string>) {
 		if (!data.shared) {
-			writeMessage(event.detail);
+			await writeMessage({ prompt: event.detail });
 		} else {
-			convFromShared()
+			await convFromShared()
 				.then(async (convId) => {
 					await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
 				})
-				.then(() => writeMessage(event.detail))
+				.then(async () => await writeMessage({ prompt: event.detail }))
 				.finally(() => (loading = false));
 		}
 	}
 
 	async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) {
 		if (!data.shared) {
-			writeMessage(
+			await writeMessage({
+				prompt: event.detail.content,
+				messageId: event.detail.id,
+				isRetry: true,
+			});
 		} else {
-			convFromShared()
+			await convFromShared()
 				.then(async (convId) => {
 					await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
 				})
-				.then(
+				.then(
+					async () =>
+						await writeMessage({
+							prompt: event.detail.content,
+							messageId: event.detail.id,
+							isRetry: true,
+						})
+				)
 				.finally(() => (loading = false));
 		}
 	}
 
+	async function onContinue(event: CustomEvent<{ id: Message["id"] }>) {
+		if (!data.shared) {
+			writeMessage({ messageId: event.detail.id, isContinue: true });
+		}
+	}
+
 	$: $page.params.id, (($isAborted = true), (loading = false));
 	$: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title;
 </script>
@@ -337,6 +387,7 @@
 	bind:files
 	on:message={onMessage}
 	on:retry={onRetry}
+	on:continue={onContinue}
 	on:vote={(event) => voteMessage(event.detail.score, event.detail.id)}
 	on:share={() => shareConversation($page.params.id, data.title)}
 	on:stop={() => (($isAborted = true), (loading = false))}
src/routes/conversation/[id]/+server.ts
CHANGED
@@ -91,14 +91,16 @@ export async function POST({ request, locals, params, getClientAddress }) {
 	const {
 		inputs: newPrompt,
 		id: messageId,
-		is_retry,
+		is_retry: isRetry,
+		is_continue: isContinue,
 		web_search: webSearch,
 		files: b64files,
 	} = z
 		.object({
-			inputs: z.string().trim().min(1),
+			inputs: z.optional(z.string().trim().min(1)),
 			id: z.optional(z.string().uuid()),
 			is_retry: z.optional(z.boolean()),
+			is_continue: z.optional(z.boolean()),
 			web_search: z.optional(z.boolean()),
 			files: z.optional(z.array(z.string())),
 		})
@@ -136,38 +138,50 @@ export async function POST({ request, locals, params, getClientAddress }) {
 		hashes = await Promise.all(files.map(async (file) => await uploadFile(file, conv)));
 	}
 
+	// can only call isContinue on the last message id
+	if (isContinue && conv.messages[conv.messages.length - 1].id !== messageId) {
+		throw error(400, "Can only continue the last message");
+	}
+
 	// get the list of messages
 	// while checking for retries
 	let messages = (() => {
+		// for retries we slice and rewrite the last user message
+		if (isRetry && messageId) {
 			// if the message is a retry, replace the message and remove the messages after it
 			let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
+
 			if (retryMessageIdx === -1) {
 				retryMessageIdx = conv.messages.length;
 			}
+
 			return [
 				...conv.messages.slice(0, retryMessageIdx),
 				{
-					content:
+					content: conv.messages[retryMessageIdx]?.content,
 					from: "user",
 					id: messageId as Message["id"],
 					updatedAt: new Date(),
 					files: conv.messages[retryMessageIdx]?.files,
 				},
 			];
+		} else if (isContinue && messageId) {
+			// for continue we do nothing and expand the last assistant message
+			return conv.messages;
+		} else {
+			// in normal conversation we add an extra user message
+			return [
+				...conv.messages,
+				{
+					content: newPrompt ?? "",
+					from: "user",
+					id: (messageId as Message["id"]) || crypto.randomUUID(),
+					createdAt: new Date(),
+					updatedAt: new Date(),
+					files: hashes,
+				},
+			];
		} // else append the message at the bottom
-
-		return [
-			...conv.messages,
-			{
-				content: newPrompt,
-				from: "user",
-				id: (messageId as Message["id"]) || crypto.randomUUID(),
-				createdAt: new Date(),
-				updatedAt: new Date(),
-				files: hashes,
-			},
-		];
 	})() satisfies Message[];
 
 	await collections.conversations.updateOne(
@@ -183,10 +197,14 @@ export async function POST({ request, locals, params, getClientAddress }) {
 		}
 	);
 
+	let doneStreaming = false;
+
 	// we now build the stream
 	const stream = new ReadableStream({
 		async start(controller) {
-			const updates: MessageUpdate[] =
+			const updates: MessageUpdate[] = isContinue
+				? conv.messages[conv.messages.length - 1].updates ?? []
+				: [];
 
 			function update(newUpdate: MessageUpdate) {
 				if (newUpdate.type !== "stream") {
@@ -209,7 +227,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
 			const summarizeIfNeeded = (async () => {
 				if (conv.title === "New Chat" && messages.length === 1) {
 					try {
-						conv.title = (await summarize(
+						conv.title = (await summarize(messages[0].content)) ?? conv.title;
 						update({ type: "status", status: "title", message: conv.title });
 					} catch (e) {
 						console.error(e);
@@ -232,17 +250,22 @@ export async function POST({ request, locals, params, getClientAddress }) {
 
 			let webSearchResults: WebSearch | undefined;
 
-			if (webSearch) {
-				webSearchResults = await runWebSearch(conv,
+			if (webSearch && !isContinue) {
+				webSearchResults = await runWebSearch(conv, messages[messages.length - 1].content, update);
+				messages[messages.length - 1].webSearch = webSearchResults;
+			} else if (isContinue) {
+				webSearchResults = messages[messages.length - 1].webSearch;
 			}
 
-			messages[messages.length - 1].webSearch = webSearchResults;
-
 			conv.messages = messages;
 
+			const previousContent = isContinue
+				? conv.messages.find((message) => message.id === messageId)?.content ?? ""
+				: "";
+
 			try {
 				const endpoint = await model.getEndpoint();
-				for await (const output of await endpoint({ conversation: conv })) {
+				for await (const output of await endpoint({ conversation: conv, continue: isContinue })) {
 					// if not generated_text is here it means the generation is not done
 					if (!output.generated_text) {
 						// else we get the next token
@@ -292,7 +315,8 @@ export async function POST({ request, locals, params, getClientAddress }) {
 						...messages.slice(0, -1),
 						{
 							...messages[messages.length - 1],
-							content: output.generated_text,
+							content: previousContent + output.generated_text,
+							interrupted: !output.token.special, // if its a special token it finished on its own, else it was interrupted
 							updates,
 							updatedAt: new Date(),
 						},
@@ -302,6 +326,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
 			} catch (e) {
 				update({ type: "status", status: "error", message: (e as Error).message });
 			}
+
 			await collections.conversations.updateOne(
 				{
 					_id: convId,
@@ -315,6 +340,9 @@ export async function POST({ request, locals, params, getClientAddress }) {
 				}
 			);
 
+			// used to detect if cancel() is called bc of interrupt or just because the connection closes
+			doneStreaming = true;
+
 			update({
 				type: "finalAnswer",
 				text: messages[messages.length - 1].content,
@@ -324,18 +352,20 @@ export async function POST({ request, locals, params, getClientAddress }) {
 			return;
 		},
 		async cancel() {
-			await collections.conversations.updateOne(
-				{
-					_id: convId,
-				},
-				{
-					$set: {
-						messages,
-						title: conv.title,
-						updatedAt: new Date(),
-					},
-				}
-			);
+			if (!doneStreaming) {
+				await collections.conversations.updateOne(
+					{
+						_id: convId,
+					},
+					{
+						$set: {
+							messages,
+							title: conv.title,
+							updatedAt: new Date(),
+						},
+					}
+				);
+			}
 		},
 	});
 