import { HF_ACCESS_TOKEN, HF_TOKEN } from "$env/static/private";
import { buildPrompt } from "$lib/buildPrompt";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { z } from "zod";

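// Configuration schema for a llama.cpp server endpoint. `weight` is the endpoint's
// relative weight when a model is served by several endpoints, `url` defaults to the
// llama.cpp server's standard local address, and `accessToken` falls back to the HF
// token from the environment (preferring HF_TOKEN when both are set).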
export const endpointLlamacppParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("llamacpp"),
	url: z.string().url().default("http://127.0.0.1:8080"),
	accessToken: z
		.string()
		.min(1)
		.default(HF_TOKEN ?? HF_ACCESS_TOKEN),
});

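/**
 * Builds an `Endpoint` backed by a llama.cpp server, streaming tokens from its
 * `/completion` route.
 *
 * Usage sketch (hypothetical wiring; in practice `model` and `conversation`
 * come from the app's model config and database):
 *
 * @example
 * const endpoint = endpointLlamacpp({ type: "llamacpp", model });
 * for await (const output of await endpoint({ conversation })) {
 * 	process.stdout.write(output.token.text);
 * }
 */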
export function endpointLlamacpp(
	input: z.input<typeof endpointLlamacppParametersSchema>
): Endpoint {
	const { url, model } = endpointLlamacppParametersSchema.parse(input);
	return async ({ conversation }) => {
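		// Flatten the conversation into a single prompt string using the model's
		// prompt template, including web search results attached to the last message.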
		const prompt = await buildPrompt({
			messages: conversation.messages,
			webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
			preprompt: conversation.preprompt,
			model,
		});

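		// Call llama.cpp's native completion API; note its parameter names differ
		// from HF's: `n_predict` is max_new_tokens and `repeat_penalty` is
		// repetition_penalty.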
		const r = await fetch(`${url}/completion`, {
			method: "POST",
			headers: {
				"Content-Type": "application/json",
			},
			body: JSON.stringify({
				prompt,
				stream: true,
				temperature: model.parameters.temperature,
				top_p: model.parameters.top_p,
				top_k: model.parameters.top_k,
				stop: model.parameters.stop,
				repeat_penalty: model.parameters.repetition_penalty,
				n_predict: model.parameters.max_new_tokens,
			}),
		});

		if (!r.ok) {
			throw new Error(`Failed to generate text: ${await r.text()}`);
		}

		const decoder = new TextDecoderStream();
		const reader = r.body?.pipeThrough(decoder).getReader();

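		// The server streams server-sent events ("data: {...}" lines); convert each
		// one into a TextGenerationStreamOutput token and emit the accumulated text
		// once the server signals `stop`.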
		return (async function* () {
			let stop = false;
			let generatedText = "";
			let tokenId = 0;
			while (!stop) {
				// read the next chunk from the decoded stream
				const out = (await reader?.read()) ?? { done: false, value: undefined };
				// if the stream is done, release the reader and end the generator
				if (out.done) {
					reader?.cancel();
					return;
				}

				if (!out.value) {
					return;
				}

				if (out.value.startsWith("data: ")) {
					let data = null;
					try {
						data = JSON.parse(out.value.slice(6));
					} catch {
						return;
					}
					if (data.content || data.stop) {
						generatedText += data.content ?? "";
						const output: TextGenerationStreamOutput = {
							token: {
								id: tokenId++,
								text: data.content ?? "",
								logprob: 0,
								special: false,
							},
							generated_text: data.stop ? generatedText : null,
							details: null,
						};
						if (data.stop) {
							stop = true;
							reader?.cancel();
						}
						// yield the token; generated_text is only set on the final chunk
						yield output;
					}
				}
			}
		})();
	};
}

export default endpointLlamacpp;