Liam Dyer
committed on
feat: smooth and combine token output (#936)
Browse files
* feat: smooth and combine token output
* fix: stop generating button not triggering message updates abort
- src/lib/utils/messageUpdates.ts +214 -0
- src/routes/conversation/[id]/+page.svelte +50 -103
src/lib/utils/messageUpdates.ts
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import type { MessageUpdate, TextStreamUpdate } from "$lib/types/MessageUpdate";
|
2 |
+
|
3 |
+
/**
 * Options for one message-update request to the conversation endpoint.
 *
 * `base` is used to build the request URL; every other field is serialized
 * into the JSON body of the POST issued by `fetchMessageUpdates`.
 */
type MessageUpdateRequestOptions = {
  /** Prefix placed in front of `/conversation/{conversationId}` in the request URL. */
  base: string;
  /** Sent as the `inputs` body field. */
  inputs?: string;
  /** Sent as the `id` body field. */
  messageId?: string;
  /** Sent as the `is_retry` body field. */
  isRetry: boolean;
  /** Sent as the `is_continue` body field. */
  isContinue: boolean;
  /** Sent as the `web_search` body field. */
  webSearch: boolean;
  /** Sent as the `files` body field. */
  files?: string[];
};
|
12 |
+
/**
 * POSTs a message request to `${opts.base}/conversation/${conversationId}` and
 * returns the streamed response as an async generator of message updates.
 *
 * The response body is parsed by `endpointStreamToIterator`, adjacent "stream"
 * tokens are merged by `streamMessageUpdatesToFullWords`, and the result is
 * paced by `smoothAsyncIterator`.
 *
 * @param conversationId - Conversation id interpolated into the request URL.
 * @param opts - Request options serialized into the JSON body.
 * @param abortSignal - Aborting this signal aborts the request and the stream
 *   reader. A signal that is already aborted aborts the request immediately.
 * @returns An async generator yielding the message updates.
 * @throws Error when the response status is not OK (the message is the
 *   `message` string of the JSON error body when there is one, otherwise a
 *   description of the HTTP status) or when the response has no body.
 */
export async function fetchMessageUpdates(
  conversationId: string,
  opts: MessageUpdateRequestOptions,
  abortSignal: AbortSignal
): Promise<AsyncGenerator<MessageUpdate>> {
  // Internal controller: it is also aborted by the stream reader once the
  // stream ends, which must not affect the caller's own signal.
  const abortController = new AbortController();
  if (abortSignal.aborted) {
    // An "abort" listener would never fire for a signal that is already aborted.
    abortController.abort();
  } else {
    abortSignal.addEventListener("abort", () => abortController.abort(), { once: true });
  }

  const response = await fetch(`${opts.base}/conversation/${conversationId}`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      inputs: opts.inputs,
      id: opts.messageId,
      is_retry: opts.isRetry,
      is_continue: opts.isContinue,
      web_search: opts.webSearch,
      files: opts.files,
    }),
    signal: abortController.signal,
  });

  if (!response.ok) {
    let errorMessage = `Request failed with status code ${response.status}: ${response.statusText}`;
    try {
      const payload = (await response.json()) as { message?: unknown } | null;
      if (typeof payload?.message === "string") errorMessage = payload.message;
    } catch {
      // The error body is not JSON; keep the status-based message.
    }
    throw Error(errorMessage);
  }
  if (!response.body) {
    throw Error("Body not defined");
  }
  return smoothAsyncIterator(
    streamMessageUpdatesToFullWords(endpointStreamToIterator(response, abortController))
  );
}
|
48 |
+
|
49 |
+
/**
 * Reads the response body as text and yields each message update parsed from
 * it by `parseMessageUpdates`.
 *
 * The reader and `abortController` are tied together in both directions: the
 * controller is aborted when the stream closes, errors or reports `done`, and
 * aborting the controller cancels the reader.
 *
 * @param response - Fetch response whose body carries the updates.
 * @param abortController - Controller used to stop reading.
 * @throws Error when the response has no body.
 */
async function* endpointStreamToIterator(
  response: Response,
  abortController: AbortController
): AsyncGenerator<MessageUpdate> {
  const reader = response.body?.pipeThrough(new TextDecoderStream()).getReader();
  if (!reader) throw Error("Response for endpoint had no body");

  // Abort when the stream closes. `closed` rejects when the stream errors; the
  // rejection handler keeps that from becoming an unhandled rejection here.
  const abort = () => abortController.abort();
  reader.closed.then(abort, abort);

  // Aborting cancels the reader. A failed cancel leaves nothing to clean up,
  // so its rejection is deliberately ignored.
  const cancelReader = () => {
    reader.cancel().catch(() => undefined);
  };
  if (abortController.signal.aborted) {
    cancelReader();
  } else {
    abortController.signal.addEventListener("abort", cancelReader, { once: true });
  }

  // A network chunk can end in the middle of a JSON record, e.g.
  //   {"type": "stream", "token":
  // The unparsed tail is kept and prepended to the next chunk so that the
  // record is parsed once its remainder (e.g. ` "Hello"}`) has arrived.
  let prevChunk = "";
  while (!abortController.signal.aborted) {
    const { done, value } = await reader.read();
    if (done) {
      abortController.abort();
      break;
    }
    if (!value) continue;

    const { messageUpdates, remainingText } = parseMessageUpdates(prevChunk + value);
    prevChunk = remainingText;
    for (const messageUpdate of messageUpdates) yield messageUpdate;
  }
}
|
78 |
+
|
79 |
+
/**
 * Parses newline-delimited JSON into message updates.
 *
 * Only the text after the final "\n" can be an incomplete record. If it does
 * not parse it is returned as `remainingText` so the caller can prepend it to
 * the next chunk. Blank lines are ignored, and a newline-terminated line that
 * is not valid JSON is skipped (it can never be completed by later input), so
 * neither causes the complete updates that follow it to be lost.
 *
 * @param value - Text received so far, including any previous remainder.
 * @returns The parsed updates in input order and the unparsed tail ("" if none).
 */
function parseMessageUpdates(value: string): {
  messageUpdates: MessageUpdate[];
  remainingText: string;
} {
  const lines = value.split("\n");
  // `split` always returns at least one element, so the fallback is never used.
  const tail = lines.pop() ?? "";
  const messageUpdates: MessageUpdate[] = [];

  for (const line of lines) {
    if (line.trim() === "") continue;
    try {
      messageUpdates.push(JSON.parse(line) as MessageUpdate);
    } catch (error) {
      if (!(error instanceof SyntaxError)) throw error;
      // Malformed complete line: skip it and keep parsing the rest.
    }
  }

  if (tail.trim() === "") return { messageUpdates, remainingText: "" };

  try {
    // A chunk may end with a complete record that has no trailing newline.
    messageUpdates.push(JSON.parse(tail) as MessageUpdate);
    return { messageUpdates, remainingText: "" };
  } catch (error) {
    if (!(error instanceof SyntaxError)) throw error;
    return { messageUpdates, remainingText: tail };
  }
}
|
100 |
+
|
101 |
+
/**
 * Re-chunks "stream" updates so that tokens are emitted as whole words.
 *
 * - Updates whose type is not "stream" are yielded immediately, ahead of any
 *   stream tokens that are still buffered.
 * - Consecutive "stream" tokens are merged while the previous token ends with,
 *   and the current token begins with, a character from the set
 *   `a-z A-Z 0-9 À-ž ' \``. At most five tokens are merged into one update.
 *   Example: "what" " don" "'t" => "what" " don't"
 * - Tokens still buffered when the source ends are emitted as one merged update.
 *
 * Only Latin-script text is merged; tokens made of other characters never
 * match and are passed through one by one.
 */
async function* streamMessageUpdatesToFullWords(
  iterator: AsyncGenerator<MessageUpdate>
): AsyncGenerator<MessageUpdate> {
  let bufferedStreamUpdates: TextStreamUpdate[] = [];

  const endAlphanumeric = /[a-zA-Z0-9À-ž'`]+$/;
  const beginningAlphanumeric = /^[a-zA-Z0-9À-ž'`]+/;
  // Upper bound on merged tokens so a long unbroken run cannot stall the output.
  const maxCombinedTokens = 5;

  const combine = (updates: TextStreamUpdate[]) => ({
    type: "stream" as const,
    token: updates.map((update) => update.token).join(""),
  });

  for await (const messageUpdate of iterator) {
    if (messageUpdate.type !== "stream") {
      yield messageUpdate;
      continue;
    }
    bufferedStreamUpdates.push(messageUpdate);

    let lastIndexEmitted = 0;
    for (let i = 1; i < bufferedStreamUpdates.length; i++) {
      const prevEndsAlphanumeric = endAlphanumeric.test(bufferedStreamUpdates[i - 1].token);
      const currBeginsAlphanumeric = beginningAlphanumeric.test(bufferedStreamUpdates[i].token);
      const shouldCombine = prevEndsAlphanumeric && currBeginsAlphanumeric;
      const combinedTooMany = i - lastIndexEmitted >= maxCombinedTokens;
      if (shouldCombine && !combinedTooMany) continue;

      // Word boundary (or cap) reached: emit everything before token `i`.
      yield combine(bufferedStreamUpdates.slice(lastIndexEmitted, i));
      lastIndexEmitted = i;
    }
    // Keep only the tokens that may still be extended by the next update.
    bufferedStreamUpdates = bufferedStreamUpdates.slice(lastIndexEmitted);
  }

  // The remaining tokens form one mergeable run; emit them as one update.
  if (bufferedStreamUpdates.length > 0) yield combine(bufferedStreamUpdates);
}
|
144 |
+
|
145 |
+
/**
 * Smooths the pacing of values produced by an async iterator.
 *
 * The source is drained in the background into a buffer. Buffered values are
 * re-emitted one at a time, waiting roughly the average interval observed
 * between the most recent arrivals (gaps above two seconds are excluded from
 * the average; the interval is capped at 200 ms and each wait lasts at least
 * 5 ms). Whenever the source produces a value or finishes, the current wait is
 * interrupted and the delay is recomputed.
 *
 * If the source throws, the values already buffered are still emitted and the
 * error is then rethrown to the consumer.
 */
async function* smoothAsyncIterator<T>(iterator: AsyncGenerator<T>): AsyncGenerator<T> {
  const sampleWindow = 30;
  const anomalyThresholdMS = 2000;
  const maxAverageMS = 200;
  const minWaitMS = 5;

  const eventTarget = new EventTarget();
  let done = false;
  let failure: { error: unknown } | undefined;
  const valuesBuffer: T[] = [];
  const valueTimesMS: number[] = [];

  // Drain the source in the background, recording the arrival time of each value.
  const pump = async (): Promise<void> => {
    try {
      for await (const value of iterator) {
        valuesBuffer.push(value);
        valueTimesMS.push(performance.now());
        // Only the most recent arrival times are ever sampled.
        if (valueTimesMS.length > sampleWindow) valueTimesMS.shift();
        eventTarget.dispatchEvent(new Event("next"));
      }
    } catch (error) {
      // Remember the failure; it is rethrown once the buffer has been emitted.
      failure = { error };
    } finally {
      done = true;
      eventTarget.dispatchEvent(new Event("next"));
    }
  };
  void pump();

  let timeOfLastEmitMS = performance.now();
  while (!done || valuesBuffer.length > 0) {
    // Only consider the last `sampleWindow` arrival times
    const sampledTimesMS = valueTimesMS.slice(-sampleWindow);

    // With fewer than two samples there is no interval to average yet.
    let averageTimeMSBetweenValues = 0;
    if (sampledTimesMS.length >= 2) {
      // Total time spent in abnormally long gaps
      let anomalyDurationMS = 0;
      for (let i = 1; i < sampledTimesMS.length; i++) {
        const gapMS = sampledTimesMS[i] - sampledTimesMS[i - 1];
        if (gapMS > anomalyThresholdMS) anomalyDurationMS += gapMS;
      }
      const totalTimeMSBetweenValues =
        sampledTimesMS[sampledTimesMS.length - 1] - sampledTimesMS[0];
      const timeMSBetweenValues = totalTimeMSBetweenValues - anomalyDurationMS;
      averageTimeMSBetweenValues = Math.min(
        maxAverageMS,
        timeMSBetweenValues / (sampledTimesMS.length - 1)
      );
    }
    const timeSinceLastEmitMS = performance.now() - timeOfLastEmitMS;

    // Emit after waiting duration or cancel if "next" event is emitted
    const gotNext = await Promise.race([
      sleep(Math.max(minWaitMS, averageTimeMSBetweenValues - timeSinceLastEmitMS)),
      waitForEvent(eventTarget, "next"),
    ]);

    // Go to next iteration so we can re-calculate when to emit
    if (gotNext) continue;

    // Nothing in buffer to emit
    if (valuesBuffer.length === 0) continue;

    // Emit
    timeOfLastEmitMS = performance.now();
    // The buffer is non-empty here, so `shift()` returns a buffered value.
    yield valuesBuffer.shift() as T;
  }

  if (failure) throw failure.error;
}
|
209 |
+
|
210 |
+
/** Returns a promise that resolves, with no value, once a `setTimeout` of `ms` milliseconds fires. */
const sleep = (ms: number) =>
  new Promise((wake) => {
    setTimeout(wake, ms);
  });
|
211 |
+
/**
 * Returns a promise that resolves to `true` the next time `eventName` is
 * dispatched on `eventTarget`.
 *
 * The listener is registered with `once: true`, so it is removed after it
 * fires. The promise never rejects, and it stays pending if the event is
 * never dispatched.
 */
const waitForEvent = (eventTarget: EventTarget, eventName: string): Promise<boolean> =>
  new Promise<boolean>((resolve) => {
    eventTarget.addEventListener(eventName, () => resolve(true), { once: true });
  });
|
src/routes/conversation/[id]/+page.svelte
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
import file2base64 from "$lib/utils/file2base64";
|
17 |
import { addChildren } from "$lib/utils/tree/addChildren";
|
18 |
import { addSibling } from "$lib/utils/tree/addSibling";
|
|
|
19 |
import { createConvTreeStore } from "$lib/stores/convTree";
|
20 |
import type { v4 } from "uuid";
|
21 |
|
@@ -181,125 +182,71 @@
|
|
181 |
|
182 |
messages = [...messages];
|
183 |
const messageToWriteTo = messages.find((message) => message.id === messageToWriteToId);
|
184 |
-
|
185 |
if (!messageToWriteTo) {
|
186 |
throw new Error("Message to write to not found");
|
187 |
}
|
|
|
188 |
// disable websearch if assistant is present
|
189 |
const hasAssistant = !!$page.data.assistant;
|
190 |
-
|
191 |
-
const
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
inputs: prompt,
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
files: isRetry ? undefined : resizedImages,
|
201 |
-
}
|
|
|
|
|
|
|
202 |
});
|
|
|
203 |
|
204 |
files = [];
|
205 |
-
if (!response.body) {
|
206 |
-
throw new Error("Body not defined");
|
207 |
-
}
|
208 |
-
|
209 |
-
if (!response.ok) {
|
210 |
-
error.set((await response.json())?.message);
|
211 |
-
return;
|
212 |
-
}
|
213 |
|
214 |
-
// eslint-disable-next-line no-undef
|
215 |
-
const encoder = new TextDecoderStream();
|
216 |
-
const reader = response?.body?.pipeThrough(encoder).getReader();
|
217 |
-
let finalAnswer = "";
|
218 |
const messageUpdates: MessageUpdate[] = [];
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
let readerClosed = false;
|
229 |
-
|
230 |
-
reader.closed.then(() => {
|
231 |
-
readerClosed = true;
|
232 |
-
});
|
233 |
-
|
234 |
-
while (finalAnswer === "") {
|
235 |
-
// check for abort
|
236 |
-
if ($isAborted || $error || readerClosed) {
|
237 |
-
reader?.cancel();
|
238 |
break;
|
239 |
}
|
240 |
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
if (update.type !== "stream") {
|
261 |
-
messageUpdates.push(update);
|
262 |
-
}
|
263 |
-
|
264 |
-
if (update.type === "finalAnswer") {
|
265 |
-
finalAnswer = update.text;
|
266 |
-
loading = false;
|
267 |
-
pending = false;
|
268 |
-
} else if (update.type === "stream") {
|
269 |
-
pending = false;
|
270 |
-
messageToWriteTo.content += update.token;
|
271 |
-
messages = [...messages];
|
272 |
-
} else if (update.type === "webSearch") {
|
273 |
-
messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
|
274 |
-
messages = [...messages];
|
275 |
-
} else if (update.type === "status") {
|
276 |
-
if (update.status === "title" && update.message) {
|
277 |
-
const convInData = data.conversations.find(({ id }) => id === $page.params.id);
|
278 |
-
if (convInData) {
|
279 |
-
convInData.title = update.message;
|
280 |
-
|
281 |
-
$titleUpdate = {
|
282 |
-
title: update.message,
|
283 |
-
convId: $page.params.id,
|
284 |
-
};
|
285 |
-
}
|
286 |
-
} else if (update.status === "error") {
|
287 |
-
$error = update.message ?? "An error has occurred";
|
288 |
-
}
|
289 |
-
} else if (update.type === "error") {
|
290 |
-
error.set(update.message);
|
291 |
-
reader.cancel();
|
292 |
-
}
|
293 |
-
} catch (parseError) {
|
294 |
-
// in case of parsing error we wait for the next message
|
295 |
-
|
296 |
-
if (el === inputs[inputs.length - 1]) {
|
297 |
-
prev_input_chunk.push(el);
|
298 |
-
}
|
299 |
-
return;
|
300 |
}
|
301 |
-
})
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
303 |
}
|
304 |
|
305 |
messageToWriteTo.updates = messageUpdates;
|
|
|
16 |
import file2base64 from "$lib/utils/file2base64";
|
17 |
import { addChildren } from "$lib/utils/tree/addChildren";
|
18 |
import { addSibling } from "$lib/utils/tree/addSibling";
|
19 |
+
import { fetchMessageUpdates } from "$lib/utils/messageUpdates";
|
20 |
import { createConvTreeStore } from "$lib/stores/convTree";
|
21 |
import type { v4 } from "uuid";
|
22 |
|
|
|
182 |
|
183 |
messages = [...messages];
|
184 |
const messageToWriteTo = messages.find((message) => message.id === messageToWriteToId);
|
|
|
185 |
if (!messageToWriteTo) {
|
186 |
throw new Error("Message to write to not found");
|
187 |
}
|
188 |
+
|
189 |
// disable websearch if assistant is present
|
190 |
const hasAssistant = !!$page.data.assistant;
|
191 |
+
const messageUpdatesAbortController = new AbortController();
|
192 |
+
const messageUpdatesIterator = await fetchMessageUpdates(
|
193 |
+
$page.params.id,
|
194 |
+
{
|
195 |
+
base,
|
196 |
inputs: prompt,
|
197 |
+
messageId,
|
198 |
+
isRetry,
|
199 |
+
isContinue,
|
200 |
+
webSearch: !hasAssistant && $webSearchParameters.useSearch,
|
201 |
files: isRetry ? undefined : resizedImages,
|
202 |
+
},
|
203 |
+
messageUpdatesAbortController.signal
|
204 |
+
).catch((err) => {
|
205 |
+
error.set(err.message);
|
206 |
});
|
207 |
+
if (messageUpdatesIterator === undefined) return;
|
208 |
|
209 |
files = [];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
|
|
|
|
|
|
|
|
211 |
const messageUpdates: MessageUpdate[] = [];
|
212 |
+
for await (const update of messageUpdatesIterator) {
|
213 |
+
if ($isAborted) {
|
214 |
+
messageUpdatesAbortController.abort();
|
215 |
+
return;
|
216 |
+
}
|
217 |
+
if (update.type === "finalAnswer") {
|
218 |
+
loading = false;
|
219 |
+
pending = false;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
break;
|
221 |
}
|
222 |
|
223 |
+
messageUpdates.push(update);
|
224 |
+
|
225 |
+
if (update.type === "stream") {
|
226 |
+
pending = false;
|
227 |
+
messageToWriteTo.content += update.token;
|
228 |
+
messages = [...messages];
|
229 |
+
} else if (update.type === "webSearch") {
|
230 |
+
messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
|
231 |
+
messages = [...messages];
|
232 |
+
} else if (update.type === "status") {
|
233 |
+
if (update.status === "title" && update.message) {
|
234 |
+
const convInData = data.conversations.find(({ id }) => id === $page.params.id);
|
235 |
+
if (convInData) {
|
236 |
+
convInData.title = update.message;
|
237 |
+
|
238 |
+
$titleUpdate = {
|
239 |
+
title: update.message,
|
240 |
+
convId: $page.params.id,
|
241 |
+
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
}
|
243 |
+
} else if (update.status === "error") {
|
244 |
+
$error = update.message ?? "An error has occurred";
|
245 |
+
}
|
246 |
+
} else if (update.type === "error") {
|
247 |
+
error.set(update.message);
|
248 |
+
messageUpdatesAbortController.abort();
|
249 |
+
}
|
250 |
}
|
251 |
|
252 |
messageToWriteTo.updates = messageUpdates;
|