Adrien Denat coyotte508 HF staff commited on
Commit
2772555
1 Parent(s): 74815cb

Stop generation button (closes #86) (#88)

Browse files
src/lib/components/StopGeneratingBtn.svelte ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import CarbonPause from "~icons/carbon/pause-filled";
3
+
4
+ export let visible: boolean = false;
5
+ export let className = "";
6
+ </script>
7
+
8
+ <button
9
+ type="button"
10
+ on:click
11
+ class="absolute btn flex rounded-lg border py-1 px-3 shadow-sm bg-white dark:bg-gray-700 hover:bg-gray-100 dark:hover:bg-gray-600 dark:border-gray-600 transition-all
12
+ {className}
13
+ {visible ? 'opacity-100 visible' : 'opacity-0 invisible'}
14
+ "
15
+ >
16
+ <CarbonPause class="mr-1 -ml-1 w-[1.1875rem] h-[1.25rem] text-gray-400" /> Stop generating
17
+ </button>
src/lib/components/chat/ChatWindow.svelte CHANGED
@@ -3,10 +3,11 @@
3
  import { createEventDispatcher } from "svelte";
4
 
5
  import CarbonSendAltFilled from "~icons/carbon/send-alt-filled";
 
6
 
7
  import ChatMessages from "./ChatMessages.svelte";
8
  import ChatInput from "./ChatInput.svelte";
9
- import CarbonExport from "~icons/carbon/export";
10
  import { PUBLIC_MODEL_ID, PUBLIC_MODEL_NAME } from "$env/static/public";
11
 
12
  export let messages: Message[] = [];
@@ -16,7 +17,7 @@
16
 
17
  let message: string;
18
 
19
- const dispatch = createEventDispatcher<{ message: string; share: void }>();
20
 
21
  const handleSubmit = () => {
22
  if (loading) return;
@@ -28,8 +29,13 @@
28
  <div class="relative min-h-0 min-w-0">
29
  <ChatMessages {loading} {pending} {messages} on:message />
30
  <div
31
- class="flex flex-col pointer-events-none [&>*]:pointer-events-auto max-md:border-t dark:border-gray-800 items-center max-md:dark:bg-gray-900 max-md:bg-white bg-gradient-to-t from-white via-white/80 to-white/0 dark:from-gray-900 dark:via-gray-80 dark:to-gray-900/0 justify-center absolute inset-x-0 max-w-3xl xl:max-w-4xl mx-auto px-3.5 sm:px-5 bottom-0 py-4 md:py-8 w-full"
32
  >
 
 
 
 
 
33
  <form
34
  on:submit|preventDefault={handleSubmit}
35
  class="w-full relative flex items-center rounded-xl flex-1 max-w-4xl border bg-gray-100 focus-within:border-gray-300 dark:bg-gray-700 dark:border-gray-600 dark:focus-within:border-gray-500 "
 
3
  import { createEventDispatcher } from "svelte";
4
 
5
  import CarbonSendAltFilled from "~icons/carbon/send-alt-filled";
6
+ import CarbonExport from "~icons/carbon/export";
7
 
8
  import ChatMessages from "./ChatMessages.svelte";
9
  import ChatInput from "./ChatInput.svelte";
10
+ import StopGeneratingBtn from "../StopGeneratingBtn.svelte";
11
  import { PUBLIC_MODEL_ID, PUBLIC_MODEL_NAME } from "$env/static/public";
12
 
13
  export let messages: Message[] = [];
 
17
 
18
  let message: string;
19
 
20
+ const dispatch = createEventDispatcher<{ message: string; share: void; stop: void }>();
21
 
22
  const handleSubmit = () => {
23
  if (loading) return;
 
29
  <div class="relative min-h-0 min-w-0">
30
  <ChatMessages {loading} {pending} {messages} on:message />
31
  <div
32
+ class="flex flex-col pointer-events-none [&>*]:pointer-events-auto max-md:border-t dark:border-gray-800 items-center max-md:dark:bg-gray-900 max-md:bg-white bg-gradient-to-t from-white via-white/80 to-white/0 dark:from-gray-900 dark:via-gray-80 dark:to-gray-900/0 justify-center absolute inset-x-0 max-w-3xl xl:max-w-4xl mx-auto px-3.5 sm:px-5 bottom-0 py-4 md:py-8 w-full z-0"
33
  >
34
+ <StopGeneratingBtn
35
+ visible={loading}
36
+ className="right-5 mr-[1px] md:mr-0 md:right-7 top-6 md:top-10 z-10"
37
+ on:click={() => dispatch("stop")}
38
+ />
39
  <form
40
  on:submit|preventDefault={handleSubmit}
41
  class="w-full relative flex items-center rounded-xl flex-1 max-w-4xl border bg-gray-100 focus-within:border-gray-300 dark:bg-gray-700 dark:border-gray-600 dark:focus-within:border-gray-500 "
src/lib/server/abortedGenerations.ts ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Shouldn't be needed if we dove into sveltekit internals, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850
2
+
3
+ import { setTimeout } from "node:timers/promises";
4
+ import { collections } from "./database";
5
+
6
+ let closed = false;
7
+ process.on("SIGINT", () => {
8
+ closed = true;
9
+ });
10
+
11
+ export let abortedGenerations: Map<string, Date> = new Map();
12
+
13
+ async function maintainAbortedGenerations() {
14
+ while (!closed) {
15
+ await setTimeout(1000);
16
+
17
+ try {
18
+ const aborts = await collections.abortedGenerations.find({}).sort({ createdAt: 1 }).toArray();
19
+
20
+ abortedGenerations = new Map(
21
+ aborts.map(({ conversationId, createdAt }) => [conversationId.toString(), createdAt])
22
+ );
23
+ } catch (err) {
24
+ console.error(err);
25
+ }
26
+ }
27
+ }
28
+
29
+ maintainAbortedGenerations();
src/lib/server/database.ts CHANGED
@@ -2,6 +2,7 @@ import { MONGODB_URL, MONGODB_DB_NAME } from "$env/static/private";
2
  import { MongoClient } from "mongodb";
3
  import type { Conversation } from "$lib/types/Conversation";
4
  import type { SharedConversation } from "$lib/types/SharedConversation";
 
5
 
6
  const client = new MongoClient(MONGODB_URL, {
7
  // directConnection: true
@@ -13,11 +14,14 @@ const db = client.db(MONGODB_DB_NAME);
13
 
14
  const conversations = db.collection<Conversation>("conversations");
15
  const sharedConversations = db.collection<SharedConversation>("sharedConversations");
 
16
 
17
  export { client, db };
18
- export const collections = { conversations, sharedConversations };
19
 
20
  client.on("open", () => {
21
  conversations.createIndex({ sessionId: 1, updatedAt: -1 });
 
 
22
  sharedConversations.createIndex({ hash: 1 }, { unique: true });
23
  });
 
2
  import { MongoClient } from "mongodb";
3
  import type { Conversation } from "$lib/types/Conversation";
4
  import type { SharedConversation } from "$lib/types/SharedConversation";
5
+ import type { AbortedGeneration } from "$lib/types/AbortedGeneration";
6
 
7
  const client = new MongoClient(MONGODB_URL, {
8
  // directConnection: true
 
14
 
15
  const conversations = db.collection<Conversation>("conversations");
16
  const sharedConversations = db.collection<SharedConversation>("sharedConversations");
17
+ const abortedGenerations = db.collection<AbortedGeneration>("abortedGenerations");
18
 
19
  export { client, db };
20
+ export const collections = { conversations, sharedConversations, abortedGenerations };
21
 
22
  client.on("open", () => {
23
  conversations.createIndex({ sessionId: 1, updatedAt: -1 });
24
+ abortedGenerations.createIndex({ updatedAt: 1 }, { expireAfterSeconds: 30 });
25
+ abortedGenerations.createIndex({ conversationId: 1 }, { unique: true });
26
  sharedConversations.createIndex({ hash: 1 }, { unique: true });
27
  });
src/lib/types/AbortedGeneration.ts ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ // Ideally shouldn't be needed, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850
2
+
3
+ import type { Conversation } from "./Conversation";
4
+
5
+ export interface AbortedGeneration {
6
+ createdAt: Date;
7
+ updatedAt: Date;
8
+ conversationId: Conversation["_id"];
9
+ }
src/lib/utils/concatUint8Arrays.ts ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { sum } from "./sum";
2
+
3
+ export function concatUint8Arrays(arrays: Uint8Array[]): Uint8Array {
4
+ const totalLength = sum(arrays.map((a) => a.length));
5
+ const result = new Uint8Array(totalLength);
6
+ let offset = 0;
7
+ for (const array of arrays) {
8
+ result.set(array, offset);
9
+ offset += array.length;
10
+ }
11
+ return result;
12
+ }
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -16,6 +16,7 @@
16
 
17
  let messages = data.messages;
18
  let lastLoadedMessages = data.messages;
 
19
 
20
  // Since we modify the messages array locally, we don't want to reset it if an old version is passed
21
  $: if (data.messages !== lastLoadedMessages) {
@@ -55,7 +56,24 @@
55
  for await (const data of response) {
56
  pending = false;
57
 
58
- if (!data || conversationId !== $page.params.id) break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  // final message
61
  if (data.generated_text) {
@@ -91,6 +109,7 @@
91
  if (!message.trim()) return;
92
 
93
  try {
 
94
  loading = true;
95
  pending = true;
96
 
@@ -130,4 +149,5 @@
130
  {messages}
131
  on:message={(message) => writeMessage(message.detail)}
132
  on:share={() => shareConversation($page.params.id, data.title)}
 
133
  />
 
16
 
17
  let messages = data.messages;
18
  let lastLoadedMessages = data.messages;
19
+ let isAborted = false;
20
 
21
  // Since we modify the messages array locally, we don't want to reset it if an old version is passed
22
  $: if (data.messages !== lastLoadedMessages) {
 
56
  for await (const data of response) {
57
  pending = false;
58
 
59
+ if (!data) {
60
+ break;
61
+ }
62
+
63
+ if (conversationId !== $page.params.id) {
64
+ fetch(`${base}/conversation/${conversationId}/stop-generating`, {
65
+ method: "POST",
66
+ }).catch(console.error);
67
+ break;
68
+ }
69
+
70
+ if (isAborted) {
71
+ isAborted = false;
72
+ fetch(`${base}/conversation/${conversationId}/stop-generating`, {
73
+ method: "POST",
74
+ }).catch(console.error);
75
+ break;
76
+ }
77
 
78
  // final message
79
  if (data.generated_text) {
 
109
  if (!message.trim()) return;
110
 
111
  try {
112
+ isAborted = false;
113
  loading = true;
114
  pending = true;
115
 
 
149
  {messages}
150
  on:message={(message) => writeMessage(message.detail)}
151
  on:share={() => shareConversation($page.params.id, data.title)}
152
+ on:stop={() => (isAborted = true)}
153
  />
src/routes/conversation/[id]/+server.ts CHANGED
@@ -1,18 +1,21 @@
1
  import { PUBLIC_SEP_TOKEN } from "$env/static/public";
2
  import { buildPrompt } from "$lib/buildPrompt.js";
 
3
  import { collections } from "$lib/server/database.js";
4
  import { modelEndpoint } from "$lib/server/modelEndpoint.js";
5
  import type { Message } from "$lib/types/Message.js";
 
6
  import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
7
- import { sum } from "$lib/utils/sum";
8
  import { trimPrefix } from "$lib/utils/trimPrefix.js";
9
  import { trimSuffix } from "$lib/utils/trimSuffix.js";
 
10
  import { error } from "@sveltejs/kit";
11
  import { ObjectId } from "mongodb";
12
 
13
  export async function POST({ request, fetch, locals, params }) {
14
  // todo: add validation on params.id
15
  const convId = new ObjectId(params.id);
 
16
 
17
  const conv = await collections.conversations.findOne({
18
  _id: convId,
@@ -31,6 +34,8 @@ export async function POST({ request, fetch, locals, params }) {
31
 
32
  const randomEndpoint = modelEndpoint();
33
 
 
 
34
  const resp = await fetch(randomEndpoint.endpoint, {
35
  headers: {
36
  "Content-Type": request.headers.get("Content-Type") ?? "application/json",
@@ -41,12 +46,13 @@ export async function POST({ request, fetch, locals, params }) {
41
  ...json,
42
  inputs: prompt,
43
  }),
 
44
  });
45
 
46
  const [stream1, stream2] = resp.body!.tee();
47
 
48
  async function saveMessage() {
49
- let generated_text = await parseGeneratedText(stream2);
50
 
51
  // We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
52
  if (generated_text.startsWith(prompt)) {
@@ -97,19 +103,41 @@ export async function DELETE({ locals, params }) {
97
  return new Response();
98
  }
99
 
100
- async function parseGeneratedText(stream: ReadableStream): Promise<string> {
 
 
 
 
 
101
  const inputs: Uint8Array[] = [];
102
  for await (const input of streamToAsyncIterable(stream)) {
103
  inputs.push(input);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  }
105
 
106
  // Merge inputs into a single Uint8Array
107
- const completeInput = new Uint8Array(sum(inputs.map((input) => input.length)));
108
- let offset = 0;
109
- for (const input of inputs) {
110
- completeInput.set(input, offset);
111
- offset += input.length;
112
- }
113
 
114
  // Get last line starting with "data:" and parse it as JSON to get the generated text
115
  const message = new TextDecoder().decode(completeInput);
 
1
  import { PUBLIC_SEP_TOKEN } from "$env/static/public";
2
  import { buildPrompt } from "$lib/buildPrompt.js";
3
+ import { abortedGenerations } from "$lib/server/abortedGenerations.js";
4
  import { collections } from "$lib/server/database.js";
5
  import { modelEndpoint } from "$lib/server/modelEndpoint.js";
6
  import type { Message } from "$lib/types/Message.js";
7
+ import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
8
  import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
 
9
  import { trimPrefix } from "$lib/utils/trimPrefix.js";
10
  import { trimSuffix } from "$lib/utils/trimSuffix.js";
11
+ import type { TextGenerationStreamOutput } from "@huggingface/inference";
12
  import { error } from "@sveltejs/kit";
13
  import { ObjectId } from "mongodb";
14
 
15
  export async function POST({ request, fetch, locals, params }) {
16
  // todo: add validation on params.id
17
  const convId = new ObjectId(params.id);
18
+ const date = new Date();
19
 
20
  const conv = await collections.conversations.findOne({
21
  _id: convId,
 
34
 
35
  const randomEndpoint = modelEndpoint();
36
 
37
+ const abortController = new AbortController();
38
+
39
  const resp = await fetch(randomEndpoint.endpoint, {
40
  headers: {
41
  "Content-Type": request.headers.get("Content-Type") ?? "application/json",
 
46
  ...json,
47
  inputs: prompt,
48
  }),
49
+ signal: abortController.signal,
50
  });
51
 
52
  const [stream1, stream2] = resp.body!.tee();
53
 
54
  async function saveMessage() {
55
+ let generated_text = await parseGeneratedText(stream2, convId, date, abortController);
56
 
57
  // We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
58
  if (generated_text.startsWith(prompt)) {
 
103
  return new Response();
104
  }
105
 
106
+ async function parseGeneratedText(
107
+ stream: ReadableStream,
108
+ conversationId: ObjectId,
109
+ promptedAt: Date,
110
+ abortController: AbortController
111
+ ): Promise<string> {
112
  const inputs: Uint8Array[] = [];
113
  for await (const input of streamToAsyncIterable(stream)) {
114
  inputs.push(input);
115
+
116
+ const date = abortedGenerations.get(conversationId.toString());
117
+
118
+ if (date && date > promptedAt) {
119
+ abortController.abort("Cancelled by user");
120
+ const completeInput = concatUint8Arrays(inputs);
121
+
122
+ const lines = new TextDecoder()
123
+ .decode(completeInput)
124
+ .split("\n")
125
+ .filter((line) => line.startsWith("data:"));
126
+
127
+ const tokens = lines.map((line) => {
128
+ try {
129
+ const json: TextGenerationStreamOutput = JSON.parse(line.slice("data:".length));
130
+ return json.token.text;
131
+ } catch {
132
+ return "";
133
+ }
134
+ });
135
+ return tokens.join("");
136
+ }
137
  }
138
 
139
  // Merge inputs into a single Uint8Array
140
+ const completeInput = concatUint8Arrays(inputs);
 
 
 
 
 
141
 
142
  // Get last line starting with "data:" and parse it as JSON to get the generated text
143
  const message = new TextDecoder().decode(completeInput);
src/routes/conversation/[id]/stop-generating/+server.ts ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { collections } from "$lib/server/database";
2
+ import { error } from "@sveltejs/kit";
3
+ import { ObjectId } from "mongodb";
4
+
5
+ /**
6
+ * Ideally, we'd be able to detect the client-side abort, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850
7
+ */
8
+ export async function POST({ params, locals }) {
9
+ const conversationId = new ObjectId(params.id);
10
+
11
+ const conversation = await collections.conversations.findOne({
12
+ _id: conversationId,
13
+ sessionId: locals.sessionId,
14
+ });
15
+
16
+ if (!conversation) {
17
+ throw error(404, "Conversation not found");
18
+ }
19
+
20
+ await collections.abortedGenerations.updateOne(
21
+ { conversationId },
22
+ { $set: { updatedAt: new Date() }, $setOnInsert: { createdAt: new Date() } },
23
+ { upsert: true }
24
+ );
25
+
26
+ return new Response();
27
+ }