Spaces:
Running
Running
File size: 6,766 Bytes
8583cf1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
import type { MessageUpdate, TextStreamUpdate } from "$lib/types/MessageUpdate";
/**
 * Options accepted by `fetchMessageUpdates`.
 * `base` is used to build the request URL; every other field is forwarded
 * in the JSON body of the POST request.
 */
type MessageUpdateRequestOptions = {
	/** URL prefix placed in front of `/conversation/<conversationId>`. */
	base: string;
	/** Sent in the request body as `is_retry`. */
	isRetry: boolean;
	/** Sent in the request body as `is_continue`. */
	isContinue: boolean;
	/** Sent in the request body as `web_search`. */
	webSearch: boolean;
	/** Sent in the request body as `inputs`. */
	inputs?: string;
	/** Sent in the request body as `id`. */
	messageId?: string;
	/** Sent in the request body as `files`. */
	files?: string[];
};
/**
 * POSTs a message request to `${opts.base}/conversation/${conversationId}` and
 * returns an async generator over the updates streamed back by the server.
 * The raw stream is regrouped into full words and smoothed over time before
 * being handed to the caller.
 *
 * @param conversationId - Id inserted into the request URL.
 * @param opts - Request options; see `MessageUpdateRequestOptions`.
 * @param abortSignal - Aborting this signal cancels the request and the stream.
 * @returns An async generator yielding the message updates.
 * @throws Error when the response status is not OK (using the `message` field of
 *   the JSON error body when it is a non-empty string) or when the response has no body.
 */
export async function fetchMessageUpdates(
	conversationId: string,
	opts: MessageUpdateRequestOptions,
	abortSignal: AbortSignal
): Promise<AsyncGenerator<MessageUpdate>> {
	// A dedicated controller is used because the stream reader also aborts it
	// once the response body has been fully consumed.
	const abortController = new AbortController();
	if (abortSignal.aborted) {
		// "abort" listeners never fire for a signal that was aborted beforehand,
		// so an early cancellation has to be forwarded by hand.
		abortController.abort();
	} else {
		abortSignal.addEventListener("abort", () => abortController.abort(), { once: true });
	}

	const response = await fetch(`${opts.base}/conversation/${conversationId}`, {
		method: "POST",
		headers: { "Content-Type": "application/json" },
		body: JSON.stringify({
			inputs: opts.inputs,
			id: opts.messageId,
			is_retry: opts.isRetry,
			is_continue: opts.isContinue,
			web_search: opts.webSearch,
			files: opts.files,
		}),
		signal: abortController.signal,
	});

	if (!response.ok) {
		// Default message, replaced only when the body carries a usable `message`
		let errorMessage = `Request failed with status code ${response.status}: ${response.statusText}`;
		try {
			const payload = (await response.json()) as { message?: unknown } | null;
			if (typeof payload?.message === "string" && payload.message !== "") {
				errorMessage = payload.message;
			}
		} catch {
			// Body was not valid JSON: keep the status-based message
		}
		throw Error(errorMessage);
	}
	if (!response.body) {
		throw Error("Body not defined");
	}

	return smoothAsyncIterator(
		streamMessageUpdatesToFullWords(endpointStreamToIterator(response, abortController))
	);
}
/**
 * Reads the body of `response` as text and yields every message update that
 * `parseMessageUpdates` extracts from it.
 *
 * Text that could not be parsed yet is carried over and prepended to the next
 * chunk, e.g. a chunk ending in `{"type": "stream", "token":` is completed by
 * the following chunk before being parsed.
 *
 * @param response - Response whose body is streamed.
 * @param abortController - Aborted by this function once the stream ends; aborting
 *   it from the outside cancels the reader and stops the iteration.
 * @throws Error when the response has no body.
 */
async function* endpointStreamToIterator(
	response: Response,
	abortController: AbortController
): AsyncGenerator<MessageUpdate> {
	const reader = response.body?.pipeThrough(new TextDecoderStream()).getReader();
	if (!reader) throw Error("Response for endpoint had no body");

	const ignoreRejection = () => undefined;
	// cancel() rejects on an errored stream; that rejection carries no new information
	const cancelReader = () => {
		reader.cancel().catch(ignoreRejection);
	};

	// Abort once the stream closes. When the stream errors instead, `closed` rejects:
	// swallow it here so it is not reported as unhandled. The loop keeps reading, so
	// the error reaches the consumer through reader.read().
	reader.closed.then(() => abortController.abort(), ignoreRejection);

	// Cancelling the reader also settles a read() that is currently pending
	abortController.signal.addEventListener("abort", cancelReader, { once: true });

	let prevChunk = "";
	try {
		while (!abortController.signal.aborted) {
			const { done, value } = await reader.read();
			if (done) {
				abortController.abort();
				break;
			}
			if (!value) continue;

			const { messageUpdates, remainingText } = parseMessageUpdates(prevChunk + value);
			prevChunk = remainingText;
			for (const messageUpdate of messageUpdates) yield messageUpdate;
		}
	} finally {
		// Also covers a controller aborted before the listener above was attached
		// and a consumer that stops iterating early.
		cancelReader();
	}
}
/**
 * Parses newline-delimited JSON into message updates.
 *
 * Every segment followed by a newline is a complete line: blank lines are
 * ignored and lines that are not valid JSON are skipped, so later lines are
 * still parsed. Only the last segment (after the final newline) can be
 * an incomplete message; when it does not parse it is returned as
 * `remainingText` so the caller can prepend it to the next chunk.
 *
 * @param value - Text received so far, including any previous remainder.
 * @returns The parsed updates in order, and the unparsed tail ("" if none).
 */
function parseMessageUpdates(value: string): {
	messageUpdates: MessageUpdate[];
	remainingText: string;
} {
	const lines = value.split("\n");
	// split() always returns at least one element, so the fallback is never used
	const tail = lines.pop() ?? "";
	const messageUpdates: MessageUpdate[] = [];

	for (const line of lines) {
		if (line.trim() === "") continue;
		try {
			messageUpdates.push(JSON.parse(line) as MessageUpdate);
		} catch {
			// A complete line that is not valid JSON cannot be repaired by
			// later chunks: skip it instead of discarding everything after it.
		}
	}

	try {
		messageUpdates.push(JSON.parse(tail) as MessageUpdate);
	} catch {
		// Empty or still incomplete: hand it back to be completed by the next chunk
		return { messageUpdates, remainingText: tail };
	}
	return { messageUpdates, remainingText: "" };
}
/**
 * Regroups "stream" tokens into whole words before passing them on.
 * - Updates whose type is not "stream" are forwarded as soon as they arrive.
 * - Consecutive "stream" tokens are concatenated while one ends and the next
 *   begins with a word character, with at most five tokens per emitted update.
 * - Tokens still held back when the source ends are emitted one by one.
 * Example: "what" " don" "'t" => "what" " don't"
 * Word detection only covers latin characters; other scripts are never joined.
 */
async function* streamMessageUpdatesToFullWords(
	iterator: AsyncGenerator<MessageUpdate>
): AsyncGenerator<MessageUpdate> {
	const wordCharsAtEnd = /[a-zA-Z0-9À-ž'`]+$/;
	const wordCharsAtStart = /^[a-zA-Z0-9À-ž'`]+/;
	let pending: TextStreamUpdate[] = [];

	for await (const update of iterator) {
		if (update.type === "stream") {
			pending.push(update);

			// Walk the boundaries between neighbouring tokens and close a group
			// whenever the word ends or the group has become too long
			let groupStart = 0;
			let cursor = 1;
			while (cursor < pending.length) {
				const continuesWord =
					wordCharsAtEnd.test(pending[cursor - 1].token) &&
					wordCharsAtStart.test(pending[cursor].token);
				const groupIsFull = cursor - groupStart >= 5;

				if (!continuesWord || groupIsFull) {
					const parts = pending.slice(groupStart, cursor).map((part) => part.token);
					yield { type: "stream", token: parts.join("") };
					groupStart = cursor;
				}
				cursor += 1;
			}

			// Keep only the tokens that have not been emitted yet
			pending = pending.slice(groupStart);
		} else {
			yield update;
		}
	}

	// Source is exhausted: release whatever is still buffered
	yield* pending;
}
/**
 * Attempts to smooth out the time between values emitted by an async iterator
 * by waiting for the average time between recently received values before
 * emitting the next one.
 *
 * The source iterator is drained in the background into a buffer. Values are
 * released from that buffer at a pace derived from the arrival times of the
 * last 30 values, ignoring gaps longer than 2 seconds and capped at 200 ms.
 *
 * @param iterator - Source iterator; it is consumed eagerly.
 * @returns An async generator yielding the same values in the same order.
 * @throws Whatever the source iterator throws, after the values received
 *   before the failure have been emitted.
 */
async function* smoothAsyncIterator<T>(iterator: AsyncGenerator<T>): AsyncGenerator<T> {
	const eventTarget = new EventTarget();
	let done = false;
	const valuesBuffer: T[] = [];
	const valueTimesMS: number[] = [];
	// Holds at most one entry: the error that terminated the source iterator
	const sourceErrors: unknown[] = [];

	// Background pump: pulls from the source and signals "next" after every
	// state change (new value, end of source, or failure)
	const next = async () => {
		try {
			const obj = await iterator.next();
			if (obj.done) {
				done = true;
			} else {
				valuesBuffer.push(obj.value);
				valueTimesMS.push(performance.now());
				void next();
			}
		} catch (error) {
			// Without this the rejection would be unhandled and the loop below
			// would wait forever for a `done` that never comes
			sourceErrors.push(error);
			done = true;
		}
		eventTarget.dispatchEvent(new Event("next"));
	};
	void next();

	let timeOfLastEmitMS = performance.now();
	while (!done || valuesBuffer.length > 0) {
		if (valuesBuffer.length === 0) {
			// Nothing to pace yet: wait for the pump instead of polling on a timer.
			// No await sits between the check above and the listener registration,
			// so the pump's next signal cannot be missed.
			await waitForEvent(eventTarget, "next");
			continue;
		}

		// Only consider the last X times between tokens
		const sampledTimesMS = valueTimesMS.slice(-30);

		// Get the total time spent in abnormal periods
		const anomalyThresholdMS = 2000;
		let anomalyDurationMS = 0;
		for (let i = 1; i < sampledTimesMS.length; i++) {
			const gapMS = sampledTimesMS[i] - sampledTimesMS[i - 1];
			if (gapMS > anomalyThresholdMS) anomalyDurationMS += gapMS;
		}

		// With a single sample there is no interval to average; using 0 avoids a
		// NaN delay, which would bypass the 5 ms floor below
		const intervalCount = sampledTimesMS.length - 1;
		const averageTimeMSBetweenValues =
			intervalCount > 0
				? Math.min(
						200,
						(sampledTimesMS[intervalCount] - sampledTimesMS[0] - anomalyDurationMS) /
							intervalCount
				  )
				: 0;
		const timeSinceLastEmitMS = performance.now() - timeOfLastEmitMS;

		// Emit after waiting duration or cancel if "next" event is emitted
		const gotNext = await Promise.race([
			sleep(Math.max(5, averageTimeMSBetweenValues - timeSinceLastEmitMS)),
			waitForEvent(eventTarget, "next"),
		]);

		// Go to next iteration so we can re-calculate when to emit
		if (gotNext) continue;

		// Emit; the buffer is non-empty here because only this loop removes values
		timeOfLastEmitMS = performance.now();
		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
		yield valuesBuffer.shift()!;
	}

	// Surface a source failure once everything received before it was emitted
	if (sourceErrors.length > 0) throw sourceErrors[0];
}
/** Returns a promise that resolves (with no value) once a timer of `ms` milliseconds fires. */
const sleep = (ms: number) =>
	new Promise((resolve) => {
		setTimeout(resolve, ms);
	});
/**
 * Returns a promise that resolves to `true` the next time `eventName` is
 * dispatched on `eventTarget`. The listener is registered with `once`, so it
 * is removed after that first dispatch.
 */
const waitForEvent = (eventTarget: EventTarget, eventName: string) =>
	new Promise<boolean>((resolve) => {
		const onEvent = () => resolve(true);
		eventTarget.addEventListener(eventName, onEvent, { once: true });
	});
|