import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import { buildPrompt } from "$lib/buildPrompt";
import { OPENAI_API_KEY } from "$env/static/private";
import type { Endpoint } from "../endpoints";

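/**
 * Configuration schema for an OpenAI-compatible endpoint, e.g.
 * { "type": "openai", "baseURL": "https://api.openai.com/v1" }.
 * `baseURL` may point at any server that implements the OpenAI API (for
 * example, a self-hosted proxy), and `completion` selects between the legacy
 * completions route and the chat completions route.
 */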
export const endpointOAIParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("openai"),
	baseURL: z.string().url().default("https://api.openai.com/v1"),
	apiKey: z.string().default(OPENAI_API_KEY ?? "sk-"),
	completion: z
		.union([z.literal("completions"), z.literal("chat_completions")])
		.default("chat_completions"),
	defaultHeaders: z.record(z.string()).optional(),
	defaultQuery: z.record(z.string()).optional(),
});

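/**
 * Builds an `Endpoint` that streams generations from an OpenAI-compatible
 * API, dispatching to the completions or chat completions route based on the
 * parsed configuration.
 */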
export async function endpointOai(
	input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
		endpointOAIParametersSchema.parse(input);
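	// Import the OpenAI SDK lazily so a missing package surfaces as a
	// descriptive runtime error rather than a module-load failure.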
	let OpenAI;
	try {
		OpenAI = (await import("openai")).OpenAI;
	} catch (e) {
		throw new Error("Failed to import OpenAI", { cause: e });
	}

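	// The client is configured once per endpoint from the validated settings.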
	const openai = new OpenAI({
		apiKey,
		baseURL,
		defaultHeaders,
		defaultQuery,
	});

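	// Legacy completions route: flatten the chat history into a single prompt
	// string with buildPrompt.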
	if (completion === "completions") {
		return async ({ messages, preprompt, continueMessage, generateSettings }) => {
			const prompt = await buildPrompt({
				messages,
				continueMessage,
				preprompt,
				model,
			});

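			// Per-request generation settings override the model's defaults.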
			const parameters = { ...model.parameters, ...generateSettings };

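			// Map chat-ui parameter names onto the OpenAI API; repetition_penalty
			// is forwarded as frequency_penalty, which is a related but not
			// identical sampling control.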
			return openAICompletionToTextGenerationStream(
				await openai.completions.create({
					model: model.id ?? model.name,
					prompt,
					stream: true,
					max_tokens: parameters?.max_new_tokens,
					stop: parameters?.stop,
					temperature: parameters?.temperature,
					top_p: parameters?.top_p,
					frequency_penalty: parameters?.repetition_penalty,
				})
			);
		};
	} else if (completion === "chat_completions") {
		return async ({ messages, preprompt, generateSettings }) => {
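			// Convert chat-ui messages ({ from, content }) into the OpenAI wire
			// format ({ role, content }).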
			let messagesOpenAI = messages.map((message) => ({
				role: message.from,
				content: message.content,
			}));

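			// Guarantee a leading system message, then place the preprompt in it
			// (overwriting any system content already present).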
			if (messagesOpenAI?.[0]?.role !== "system") {
				messagesOpenAI = [{ role: "system", content: "" }, ...messagesOpenAI];
			}

			if (messagesOpenAI?.[0]) {
				messagesOpenAI[0].content = preprompt ?? "";
			}

			const parameters = { ...model.parameters, ...generateSettings };

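			// As in the completions branch, repetition_penalty maps onto
			// frequency_penalty.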
			return openAIChatToTextGenerationStream(
				await openai.chat.completions.create({
					model: model.id ?? model.name,
					messages: messagesOpenAI,
					stream: true,
					max_tokens: parameters?.max_new_tokens,
					stop: parameters?.stop,
					temperature: parameters?.temperature,
					top_p: parameters?.top_p,
					frequency_penalty: parameters?.repetition_penalty,
				})
			);
		};
	} else {
		throw new Error("Invalid completion type");
	}
}