File size: 2,736 Bytes
1b66f8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91e621a
 
 
1b66f8d
 
 
 
 
91e621a
 
1b66f8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91e621a
1b66f8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
<script lang="ts">
	import ChatWindow from '$lib/components/chat/ChatWindow.svelte';
	import { pendingMessage } from '$lib/stores/pendingMessage';
	import { onMount } from 'svelte';
	import type { PageData } from './$types';
	import { page } from '$app/stores';
	import {
		PUBLIC_ASSISTANT_MESSAGE_TOKEN,
		PUBLIC_SEP_TOKEN,
		PUBLIC_USER_MESSAGE_TOKEN
	} from '$env/static/public';
	import { HfInference } from '@huggingface/inference';

	/** Page data from this route's load function; `data.messages` seeds the conversation. */
	export let data: PageData;

	// Reactive declaration: re-derived from `data` whenever the `data` prop changes.
	// The functions below reassign `messages` (rather than only mutating it) so Svelte
	// re-renders the chat window.
	$: messages = data.messages;

	// Role and separator tokens used by writeMessage to serialize the whole conversation
	// into a single prompt string. Values come from SvelteKit's static public env vars.
	const userToken = PUBLIC_USER_MESSAGE_TOKEN;
	const assistantToken = PUBLIC_ASSISTANT_MESSAGE_TOKEN;
	const sepToken = PUBLIC_SEP_TOKEN;

	// Inference client (no access token passed) pointed at this app's own
	// same-origin /api/conversation route.
	const hf = new HfInference();
	const model = hf.endpoint(`${$page.url.origin}/api/conversation`);

	// True while writeMessage is in flight; passed to ChatWindow as `disabled`.
	let loading = false;

	/**
	 * Streams a reply for `inputs` from the conversation endpoint and appends the
	 * generated text to `messages` as tokens arrive, so the UI updates incrementally.
	 *
	 * The first non-special token opens a new assistant message; later tokens are
	 * appended to it. A bare `<` (one not adjacent to a backtick, so not inline code
	 * such as "`<img>`") is treated as the start of the stop sequence leaking into the
	 * text. Everything from that `<` onward is dropped and streaming stops. This check
	 * applies to every token, including the first.
	 *
	 * @param inputs - Full prompt, already serialized with role and separator tokens.
	 */
	async function getTextGenerationStream(inputs: string) {
		const stream = model.textGenerationStream(
			{
				inputs,
				parameters: {
					// Taken from https://huggingface.co/spaces/huggingface/open-assistant-private-testing/blob/main/app.py#L54
					// NOTE(review): `stop` is presumably absent from the typed parameters,
					// hence the ignore below; confirm against the installed client version.
					// @ts-ignore
					stop: ['<|endoftext|>'],
					max_new_tokens: 1024,
					truncate: 1024,
					typical_p: 0.2
				}
			},
			{
				use_cache: false
			}
		);

		// Matches a "<" that is not part of a piece of code like "`<img>`".
		// No `g` flag, so `exec` carries no state between calls.
		const endOfTextRegex = /(?<!`)<(?!`)/;

		// `output` (not `data`) to avoid shadowing the component's `data` prop.
		for await (const output of stream) {
			if (!output) break;
			if (output.token.special) continue;

			const endOfText = endOfTextRegex.exec(output.token.text);
			// Keep only the text before the matched marker; what follows it is
			// presumably the rest of the stop sequence and must not be shown.
			const text = endOfText ? output.token.text.slice(0, endOfText.index) : output.token.text;

			const lastMessage = messages.at(-1);
			if (lastMessage?.from !== 'assistant') {
				// First token has a space at the beginning, trim it
				messages = [...messages, { from: 'assistant', content: text.trimStart() }];
			} else {
				lastMessage.content += text;
				// Reassign so Svelte picks up the in-place mutation above.
				messages = [...messages];
			}

			if (endOfText) break;
		}

		// todo: if everything went well, store message + response in DB
	}

	/**
	 * Sends a user message.
	 *
	 * It appends `message` to the conversation as a user turn, serializes the whole
	 * history into a prompt, and streams the assistant's reply. Messages that are empty
	 * or whitespace-only are ignored. `loading` is true for the duration of the request
	 * and is reset even if streaming throws (the error still propagates to the caller).
	 *
	 * @param message - Raw text submitted by the user.
	 */
	async function writeMessage(message: string) {
		if (message.trim() === '') return;

		loading = true;
		try {
			messages = [...messages, { from: 'user', content: message }];

			let prompt = '';
			for (const { from, content } of messages) {
				const roleToken = from === 'user' ? userToken : assistantToken;
				// Avoid doubling the separator when the content already ends with it.
				const separator = content.endsWith(sepToken) ? '' : sepToken;
				prompt += `${roleToken}${content}${separator}`;
			}
			// Trailing assistant token, presumably so the model continues as the assistant.
			prompt += assistantToken;

			await getTextGenerationStream(prompt);
		} finally {
			loading = false;
		}
	}

	// On mount, consume any text waiting in the `pendingMessage` store (presumably set
	// by another page before navigating here). The store is cleared whenever it held a
	// value. The text is auto-sent only if this conversation has no messages yet;
	// otherwise it is dropped.
	onMount(async () => {
		const pending = $pendingMessage;
		if (!pending) return;

		$pendingMessage = '';

		if (messages.length > 0) return;

		// Fire-and-forget: not awaited, matching the event-handler call site.
		void writeMessage(pending);
	});
</script>

<!-- Chat UI: disabled while `loading`; each `message` event passes its `detail`
     (the submitted text) to writeMessage. -->
<ChatWindow disabled={loading} {messages} on:message={(message) => writeMessage(message.detail)} />