import { PUBLIC_SEP_TOKEN } from "$env/static/public";
import { buildPrompt } from "$lib/buildPrompt.js";
import { abortedGenerations } from "$lib/server/abortedGenerations.js";
import { collections } from "$lib/server/database.js";
import { modelEndpoint } from "$lib/server/modelEndpoint.js";
import { defaultModel, modelNames } from "$lib/server/models.js";
import type { Message } from "$lib/types/Message.js";
import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable.js";
import { trimPrefix } from "$lib/utils/trimPrefix.js";
import { trimSuffix } from "$lib/utils/trimSuffix.js";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { error } from "@sveltejs/kit";
import { ObjectId } from "mongodb";
import { z } from "zod";
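
// Handles a new user prompt: validates the request body, rebuilds the message list,
// forwards the prompt to the model endpoint and streams the generated answer back.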
export async function POST({ request, fetch, locals, params }) {
// todo: add validation on params.id
const convId = new ObjectId(params.id);
const date = new Date();
const conv = await collections.conversations.findOne({
_id: convId,
sessionId: locals.sessionId,
});
if (!conv) {
throw error(404, "Conversation not found");
}
const json = await request.json();
const {
inputs: newPrompt,
model,
options: { id: messageId, is_retry },
} = z
.object({
inputs: z.string().trim().min(1),
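// z.enum requires a non-empty tuple of literals, hence the explicit first element before spreading the rest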
model: z
.enum([modelNames[0].name, ...modelNames.slice(1).map((m) => m.name)])
.default(defaultModel.name),
options: z.object({
id: z.optional(z.string().uuid()),
is_retry: z.optional(z.boolean()),
}),
})
.parse(json);
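
// Build the message list: on a retry, drop the retried message and everything after it
// and re-add the prompt with the same id; otherwise append the new user message.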
const messages = (() => {
if (is_retry && messageId) {
let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
if (retryMessageIdx === -1) {
retryMessageIdx = conv.messages.length;
}
return [
...conv.messages.slice(0, retryMessageIdx),
{ content: newPrompt, from: "user", id: messageId as Message["id"] },
];
}
return [
...conv.messages,
{ content: newPrompt, from: "user", id: (messageId as Message["id"]) || crypto.randomUUID() },
];
})() satisfies Message[];
// Todo: on-the-fly migration, remove later
for (const message of messages) {
if (!message.id) {
message.id = crypto.randomUUID();
}
}
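
// Build a single prompt from the conversation and forward it to the configured model endpoint.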
const prompt = buildPrompt(messages);
const randomEndpoint = modelEndpoint(model);
const abortController = new AbortController();
const resp = await fetch(randomEndpoint.url, {
headers: {
"Content-Type": request.headers.get("Content-Type") ?? "application/json",
Authorization: randomEndpoint.authorization,
},
method: "POST",
body: JSON.stringify({
...json,
inputs: prompt,
}),
signal: abortController.signal,
});
if (!resp.body) {
throw new Error("Response body is empty");
}
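
// Duplicate the stream: one copy is returned to the client, the other is parsed to persist the assistant message.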
const [stream1, stream2] = resp.body.tee();
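
// Wait for the generation to finish, strip the prompt and special tokens from the output,
// then save the assistant message and update the conversation.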
async function saveMessage() {
let generated_text = await parseGeneratedText(stream2, convId, date, abortController);
// We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
if (generated_text.startsWith(prompt)) {
generated_text = generated_text.slice(prompt.length);
}
generated_text = trimSuffix(trimPrefix(generated_text, "<|startoftext|>"), PUBLIC_SEP_TOKEN);
messages.push({ from: "assistant", content: generated_text, id: crypto.randomUUID(), model });
await collections.conversations.updateOne(
{
_id: convId,
},
{
$set: {
messages,
updatedAt: new Date(),
},
}
);
}
saveMessage().catch(console.error);
// Todo: maybe we should wait for the message to be saved before ending the response - in case of errors
return new Response(stream1, {
headers: Object.fromEntries(resp.headers.entries()),
status: resp.status,
statusText: resp.statusText,
});
}
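
// Deletes a conversation belonging to the current session.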
export async function DELETE({ locals, params }) {
const convId = new ObjectId(params.id);
const conv = await collections.conversations.findOne({
_id: convId,
sessionId: locals.sessionId,
});
if (!conv) {
throw error(404, "Conversation not found");
}
await collections.conversations.deleteOne({ _id: conv._id });
return new Response();
}
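
// Consumes the text-generation event stream and returns the final generated_text.
// If the user aborts the generation, the upstream request is cancelled and the tokens
// received so far are returned instead.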
async function parseGeneratedText(
stream: ReadableStream,
conversationId: ObjectId,
promptedAt: Date,
abortController: AbortController
): Promise<string> {
const inputs: Uint8Array[] = [];
for await (const input of streamToAsyncIterable(stream)) {
inputs.push(input);
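// If the user aborted this conversation after the prompt was sent, cancel the upstream
// request and return whatever tokens have been received so far.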
const date = abortedGenerations.get(conversationId.toString());
if (date && date > promptedAt) {
abortController.abort("Cancelled by user");
const completeInput = concatUint8Arrays(inputs);
const lines = new TextDecoder()
.decode(completeInput)
.split("\n")
.filter((line) => line.startsWith("data:"));
const tokens = lines.map((line) => {
try {
const json: TextGenerationStreamOutput = JSON.parse(line.slice("data:".length));
return json.token.text;
} catch {
return "";
}
});
return tokens.join("");
}
}
// Merge inputs into a single Uint8Array
const completeInput = concatUint8Arrays(inputs);
// Get last line starting with "data:" and parse it as JSON to get the generated text
const message = new TextDecoder().decode(completeInput);
let lastIndex = message.lastIndexOf("\ndata:");
if (lastIndex === -1) {
lastIndex = message.indexOf("data:");
}
if (lastIndex === -1) {
console.error("Could not parse last message");
}
let lastMessage = message.slice(lastIndex).trim().slice("data:".length);
if (lastMessage.includes("\n")) {
lastMessage = lastMessage.slice(0, lastMessage.indexOf("\n"));
}
const lastMessageJSON = JSON.parse(lastMessage);
if (lastMessageJSON.error) {
throw new Error(lastMessageJSON.error);
}
const res = lastMessageJSON.generated_text;
if (typeof res !== "string") {
throw new Error("Could not parse generated text");
}
return res;
}
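
// Updates the title of a conversation belonging to the current session.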
export async function PATCH({ request, locals, params }) {
const { title } = z
.object({ title: z.string().trim().min(1).max(100) })
.parse(await request.json());
const convId = new ObjectId(params.id);
const conv = await collections.conversations.findOne({
_id: convId,
sessionId: locals.sessionId,
});
if (!conv) {
throw error(404, "Conversation not found");
}
await collections.conversations.updateOne(
{
_id: convId,
},
{
$set: {
title,
},
}
);
return new Response();
}