julien-c HF staff commited on
Commit
3e9f86e
1 Parent(s): 000ed50

Auto-naming convos by summarizing them (#58)

Browse files

* Auto-naming convos by summarizing them

* Ok let's do it only on first message then

* revamp

* location.reload for now

src/routes/+layout.svelte CHANGED
@@ -102,7 +102,7 @@
102
  ? 'bg-gray-100 dark:bg-gray-700'
103
  : ''}"
104
  >
105
- <div class="flex-1 truncate">azeza aze a ea zeazeazazeae</div>
106
 
107
  <button
108
  type="button"
 
102
  ? 'bg-gray-100 dark:bg-gray-700'
103
  : ''}"
104
  >
105
+ <div class="flex-1 truncate">{conv.title}</div>
106
 
107
  <button
108
  type="button"
src/routes/conversation/[id]/+page.server.ts CHANGED
@@ -1,6 +1,5 @@
1
  import type { PageServerLoad } from "./$types";
2
  import { collections } from "$lib/server/database";
3
- import type { Conversation } from "$lib/types/Conversation";
4
  import { ObjectId } from "mongodb";
5
  import { error } from "@sveltejs/kit";
6
 
 
1
  import type { PageServerLoad } from "./$types";
2
  import { collections } from "$lib/server/database";
 
3
  import { ObjectId } from "mongodb";
4
  import { error } from "@sveltejs/kit";
5
 
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -6,6 +6,7 @@
6
  import { page } from "$app/stores";
7
  import { HfInference } from "@huggingface/inference";
8
  import { invalidate } from "$app/navigation";
 
9
 
10
  export let data: PageData;
11
 
@@ -58,6 +59,18 @@
58
  }
59
  }
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  async function writeMessage(message: string) {
62
  if (!message.trim()) return;
63
 
@@ -69,6 +82,10 @@
69
 
70
  await getTextGenerationStream(message);
71
 
 
 
 
 
72
  // Reload conversation order - doesn't seem to work
73
  // await invalidate('/');
74
  } finally {
 
6
  import { page } from "$app/stores";
7
  import { HfInference } from "@huggingface/inference";
8
  import { invalidate } from "$app/navigation";
9
+ import { base } from "$app/paths";
10
 
11
  export let data: PageData;
12
 
 
59
  }
60
  }
61
 
62
+ async function summarizeTitle(id: string) {
63
+ const response = await fetch(`${base}/conversation/${id}/summarize`, {
64
+ method: "POST",
65
+ });
66
+ if (response.ok) {
67
+ /// TODO(actually invalidate)
68
+ await invalidate("/");
69
+ await invalidate((url) => url.pathname === "/" || url.pathname === base);
70
+ location.reload();
71
+ }
72
+ }
73
+
74
  async function writeMessage(message: string) {
75
  if (!message.trim()) return;
76
 
 
82
 
83
  await getTextGenerationStream(message);
84
 
85
+ if (messages.filter((m) => m.from === "user").length === 1) {
86
+ summarizeTitle($page.params.id).catch(console.error);
87
+ }
88
+
89
  // Reload conversation order - doesn't seem to work
90
  // await invalidate('/');
91
  } finally {
src/routes/conversation/[id]/summarize/+server.ts ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { HF_TOKEN } from "$env/static/private";
2
+ import { PUBLIC_MODEL_ENDPOINT } from "$env/static/public";
3
+ import { buildPrompt } from "$lib/buildPrompt";
4
+ import { collections } from "$lib/server/database.js";
5
+ import { error } from "@sveltejs/kit";
6
+ import { ObjectId } from "mongodb";
7
+
8
+ export async function POST({ params, locals }) {
9
+ const convId = new ObjectId(params.id);
10
+
11
+ const conversation = await collections.conversations.findOne({
12
+ _id: convId,
13
+ sessionId: locals.sessionId,
14
+ });
15
+
16
+ if (!conversation) {
17
+ throw error(404, "Conversation not found");
18
+ }
19
+
20
+ const firstMessage = conversation.messages.find((m) => m.from === "user");
21
+
22
+ const userPrompt =
23
+ `You are a summarizing assistant. Please summarize the following message as a single sentence of less than 5 words:\n` +
24
+ firstMessage?.content;
25
+
26
+ const prompt = buildPrompt([{ from: "user", content: userPrompt }]);
27
+
28
+ const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
29
+ headers: {
30
+ "Content-Type": "application/json",
31
+ Authorization: `Basic ${HF_TOKEN}`,
32
+ },
33
+ method: "POST",
34
+ body: JSON.stringify({
35
+ inputs: prompt,
36
+ parameters: {
37
+ temperature: 0.9,
38
+ top_p: 0.95,
39
+ repetition_penalty: 1.2,
40
+ top_k: 50,
41
+ watermark: false,
42
+ max_new_tokens: 1024,
43
+ stop: ["<|endoftext|>"],
44
+ return_full_text: false,
45
+ },
46
+ }),
47
+ });
48
+
49
+ const response = await resp.json();
50
+ let generatedTitle: string | undefined;
51
+ try {
52
+ if (typeof response[0].generated_text === "string") {
53
+ generatedTitle = response[0].generated_text;
54
+ }
55
+ } catch {
56
+ console.error("summarization failed");
57
+ }
58
+
59
+ if (generatedTitle) {
60
+ await collections.conversations.updateOne(
61
+ {
62
+ _id: convId,
63
+ sessionId: locals.sessionId,
64
+ },
65
+ {
66
+ $set: { title: generatedTitle },
67
+ }
68
+ );
69
+ }
70
+
71
+ return new Response(
72
+ JSON.stringify(
73
+ generatedTitle
74
+ ? {
75
+ title: generatedTitle,
76
+ }
77
+ : {}
78
+ ),
79
+ { headers: { "Content-Type": "application/json" } }
80
+ );
81
+ }