import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import { buildPrompt } from "$lib/buildPrompt";
import { OPENAI_API_KEY } from "$env/static/private";
import type { Endpoint } from "../endpoints";
import { format } from "date-fns";

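// Configuration schema for an OpenAI-compatible endpoint, validated with zod.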
export const endpointOAIParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("openai"),
	baseURL: z.string().url().default("https://api.openai.com/v1"),
	apiKey: z.string().default(OPENAI_API_KEY ?? "sk-"),
	completion: z
		.union([z.literal("completions"), z.literal("chat_completions")])
		.default("chat_completions"),
});

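// Creates an Endpoint backed by an OpenAI-compatible API, supporting both the
// legacy "completions" API and the "chat_completions" API.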
export async function endpointOai(
	input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
	const { baseURL, apiKey, completion, model } = endpointOAIParametersSchema.parse(input);
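	// Import the OpenAI client lazily so the dependency is only required when this endpoint type is actually used.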
	let OpenAI;
	try {
		OpenAI = (await import("openai")).OpenAI;
	} catch (e) {
		throw new Error("Failed to import OpenAI", { cause: e });
	}

	const openai = new OpenAI({
		apiKey: apiKey ?? "sk-",
		baseURL: baseURL,
	});

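	// Legacy "completions" API: flatten the whole conversation into a single prompt string with buildPrompt.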
	if (completion === "completions") {
		return async ({ conversation }) => {
			return openAICompletionToTextGenerationStream(
				await openai.completions.create({
					model: model.id ?? model.name,
					prompt: await buildPrompt({
						messages: conversation.messages,
						webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
						preprompt: conversation.preprompt,
						model,
					}),
					stream: true,
					max_tokens: model.parameters?.max_new_tokens,
					stop: model.parameters?.stop,
					temperature: model.parameters?.temperature,
					top_p: model.parameters?.top_p,
					frequency_penalty: model.parameters?.repetition_penalty,
				})
			);
		};
	} else if (completion === "chat_completions") {
		return async ({ conversation }) => {
			let messages = conversation.messages;
			const webSearch = conversation.messages[conversation.messages.length - 1].webSearch;

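			// If the latest message carries web search results, rewrite the final user message to
			// include the search context, previous questions, and today's date.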
			if (webSearch && webSearch.context) {
				const lastMsg = messages.slice(-1)[0];
				const messagesWithoutLastUsrMsg = messages.slice(0, -1);
				const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);

				const previousQuestions =
					previousUserMessages.length > 0
						? `Previous questions: \n${previousUserMessages
								.map(({ content }) => `- ${content}`)
								.join("\n")}`
						: "";
				const currentDate = format(new Date(), "MMMM d, yyyy");
				messages = [
					...messagesWithoutLastUsrMsg,
					{
						from: "user",
						content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
						=====================
						${webSearch.context}
						=====================
						${previousQuestions}
						Answer the question: ${lastMsg.content} 
						`,
					},
				];
			}

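			// Convert internal messages to the role/content shape expected by the OpenAI chat API.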
			const messagesOpenAI = messages.map((message) => ({
				role: message.from,
				content: message.content,
			}));

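			// Prepend the preprompt as a system message when present, then stream the chat completion.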
			return openAIChatToTextGenerationStream(
				await openai.chat.completions.create({
					model: model.id ?? model.name,
					messages: conversation.preprompt
						? [{ role: "system", content: conversation.preprompt }, ...messagesOpenAI]
						: messagesOpenAI,
					stream: true,
					max_tokens: model.parameters?.max_new_tokens,
					stop: model.parameters?.stop,
					temperature: model.parameters?.temperature,
					top_p: model.parameters?.top_p,
					frequency_penalty: model.parameters?.repetition_penalty,
				})
			);
		};
	} else {
		throw new Error("Invalid completion type");
	}
}