nsarrazin HF staff committed on
Commit
9960338
·
unverified ·
1 Parent(s): afbf680

Refactor summarization so it gets called from backend (#456)

Browse files

* Refactor summarization

* get rid of debug log

* remove old todo

src/lib/server/summarize.ts ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { buildPrompt } from "$lib/buildPrompt";
2
+ import { generateFromDefaultEndpoint } from "$lib/server/generateFromDefaultEndpoint";
3
+ import { defaultModel } from "$lib/server/models";
4
+
5
/**
 * Produces a very short summary of a user message by prompting the default model.
 *
 * Summarization is best-effort: the caller uses the result as a conversation
 * title and falls back to the existing title on `null`, so a failure while
 * building the prompt or calling the endpoint is logged and reported as `null`
 * rather than thrown.
 *
 * @param prompt - The user message to summarize.
 * @returns The generated summary text, or `null` when nothing was generated
 *          or the generation failed.
 */
export async function summarize(prompt: string) {
	const userPrompt = `Please summarize the following message: \n` + prompt;

	try {
		const summaryPrompt = await buildPrompt({
			messages: [{ from: "user", content: userPrompt }],
			preprompt:
				"You are a summarization AI. Your task is to summarize user requests, in a single sentence of less than 5 words. Do not try to answer questions, just summarize the user's request.",
			model: defaultModel,
		});

		const generated_text = await generateFromDefaultEndpoint(summaryPrompt);

		// Empty or missing output is treated the same as "no summary available".
		return generated_text || null;
	} catch (err) {
		// A failed summary must not abort the caller (e.g. persisting the conversation).
		console.error(err);
		return null;
	}
}
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -34,12 +34,6 @@
34
  let pending = false;
35
  let loginRequired = false;
36
 
37
- async function summarizeTitle(id: string) {
38
- await fetch(`${base}/conversation/${id}/summarize`, {
39
- method: "POST",
40
- });
41
- }
42
-
43
  // this function is used to send new message to the backends
44
  async function writeMessage(message: string, messageId = randomUUID()) {
45
  if (!message.trim()) return;
@@ -146,15 +140,7 @@
146
  // reset the websearchmessages
147
  webSearchMessages = [];
148
 
149
- // do title summarization
150
- // TODO: we should change this to wait until there is an assistant response.
151
- if (messages.filter((m) => m.from === "user").length === 1) {
152
- summarizeTitle($page.params.id)
153
- .then(() => invalidate(UrlDependency.ConversationList))
154
- .catch(console.error);
155
- } else {
156
- await invalidate(UrlDependency.ConversationList);
157
- }
158
  } catch (err) {
159
  if (err instanceof Error && err.message.includes("overloaded")) {
160
  $error = "Too much traffic, please try again.";
 
34
  let pending = false;
35
  let loginRequired = false;
36
 
 
 
 
 
 
 
37
  // this function is used to send new message to the backends
38
  async function writeMessage(message: string, messageId = randomUUID()) {
39
  if (!message.trim()) return;
 
140
  // reset the websearchmessages
141
  webSearchMessages = [];
142
 
143
+ await invalidate(UrlDependency.ConversationList);
 
 
 
 
 
 
 
 
144
  } catch (err) {
145
  if (err instanceof Error && err.message.includes("overloaded")) {
146
  $error = "Too much traffic, please try again.";
src/routes/conversation/[id]/+server.ts CHANGED
@@ -17,7 +17,8 @@ import { AwsClient } from "aws4fetch";
17
  import type { MessageUpdate } from "$lib/types/MessageUpdate";
18
  import { runWebSearch } from "$lib/server/websearch/runWebSearch";
19
  import type { WebSearch } from "$lib/types/WebSearch";
20
- import { abortedGenerations } from "$lib/server/abortedGenerations.js";
 
21
 
22
  export async function POST({ request, fetch, locals, params, getClientAddress }) {
23
  const id = z.string().parse(params.id);
@@ -167,6 +168,10 @@ export async function POST({ request, fetch, locals, params, getClientAddress })
167
  }
168
 
169
  async function saveLast(generated_text: string) {
 
 
 
 
170
  const lastMessage = messages[messages.length - 1];
171
 
172
  if (lastMessage) {
@@ -195,6 +200,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress })
195
  {
196
  $set: {
197
  messages,
 
198
  updatedAt: new Date(),
199
  },
200
  }
@@ -277,6 +283,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress })
277
  {
278
  $set: {
279
  messages,
 
280
  updatedAt: new Date(),
281
  },
282
  }
 
17
  import type { MessageUpdate } from "$lib/types/MessageUpdate";
18
  import { runWebSearch } from "$lib/server/websearch/runWebSearch";
19
  import type { WebSearch } from "$lib/types/WebSearch";
20
+ import { abortedGenerations } from "$lib/server/abortedGenerations";
21
+ import { summarize } from "$lib/server/summarize";
22
 
23
  export async function POST({ request, fetch, locals, params, getClientAddress }) {
24
  const id = z.string().parse(params.id);
 
168
  }
169
 
170
  async function saveLast(generated_text: string) {
171
+ if (!conv) {
172
+ throw new Error("Conversation not found");
173
+ }
174
+
175
  const lastMessage = messages[messages.length - 1];
176
 
177
  if (lastMessage) {
 
200
  {
201
  $set: {
202
  messages,
203
+ title: (await summarize(newPrompt)) ?? conv.title,
204
  updatedAt: new Date(),
205
  },
206
  }
 
283
  {
284
  $set: {
285
  messages,
286
+ title: (await summarize(newPrompt)) ?? conv.title,
287
  updatedAt: new Date(),
288
  },
289
  }
src/routes/conversation/[id]/summarize/+server.ts DELETED
@@ -1,74 +0,0 @@
1
- import { RATE_LIMIT } from "$env/static/private";
2
- import { buildPrompt } from "$lib/buildPrompt";
3
- import { authCondition } from "$lib/server/auth";
4
- import { collections } from "$lib/server/database";
5
- import { generateFromDefaultEndpoint } from "$lib/server/generateFromDefaultEndpoint";
6
- import { defaultModel } from "$lib/server/models";
7
- import { ERROR_MESSAGES } from "$lib/stores/errors";
8
- import { error } from "@sveltejs/kit";
9
- import { ObjectId } from "mongodb";
10
-
11
- export async function POST({ params, locals, getClientAddress }) {
12
- const convId = new ObjectId(params.id);
13
-
14
- const conversation = await collections.conversations.findOne({
15
- _id: convId,
16
- ...authCondition(locals),
17
- });
18
-
19
- if (!conversation) {
20
- throw error(404, "Conversation not found");
21
- }
22
-
23
- const userId = locals.user?._id ?? locals.sessionId;
24
-
25
- await collections.messageEvents.insertOne({
26
- userId: userId,
27
- createdAt: new Date(),
28
- ip: getClientAddress(),
29
- });
30
-
31
- const nEvents = Math.max(
32
- await collections.messageEvents.countDocuments({ userId }),
33
- await collections.messageEvents.countDocuments({ ip: getClientAddress() })
34
- );
35
-
36
- if (RATE_LIMIT != "" && nEvents > parseInt(RATE_LIMIT)) {
37
- throw error(429, ERROR_MESSAGES.rateLimited);
38
- }
39
-
40
- const firstMessage = conversation.messages.find((m) => m.from === "user");
41
-
42
- const userPrompt =
43
- `Please summarize the following message as a single sentence of less than 5 words:\n` +
44
- firstMessage?.content;
45
-
46
- const prompt = await buildPrompt({
47
- messages: [{ from: "user", content: userPrompt }],
48
- model: defaultModel,
49
- });
50
- const generated_text = await generateFromDefaultEndpoint(prompt);
51
-
52
- if (generated_text) {
53
- await collections.conversations.updateOne(
54
- {
55
- _id: convId,
56
- ...authCondition(locals),
57
- },
58
- {
59
- $set: { title: generated_text },
60
- }
61
- );
62
- }
63
-
64
- return new Response(
65
- JSON.stringify(
66
- generated_text
67
- ? {
68
- title: generated_text,
69
- }
70
- : {}
71
- ),
72
- { headers: { "Content-Type": "application/json" } }
73
- );
74
- }