Spaces:
Running
Running
File size: 5,316 Bytes
b7b2c8c 447c0ca 7764421 cd6894d 2e6d1bb 2606dde 7764421 2606dde cf7ac8d 2606dde 12c3a5a 2606dde ce2231f 2606dde ce2231f 447c0ca f209301 447c0ca f209301 cd6894d 2606dde 447c0ca 2606dde 2e6d1bb 2606dde ad6275a 2606dde b7b2c8c 2606dde 7764421 2606dde f209301 447c0ca cf7ac8d 2606dde 7764421 2606dde b7b2c8c cd6894d 2e6d1bb 2606dde b7b2c8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import { HF_ACCESS_TOKEN, MODELS, OLD_MODELS } from "$env/static/private";
import type {
ChatTemplateInput,
WebSearchQueryTemplateInput,
WebSearchSummaryTemplateInput,
} from "$lib/types/Template";
import { compileTemplate } from "$lib/utils/template";
import { z } from "zod";
/**
 * `T` with the properties named by `K` made optional; all other properties keep
 * their original optionality.
 */
type Optional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;
/**
 * Schema for an endpoint whose `host` is exactly "sagemaker".
 * `accessKey` and `secretKey` are required non-empty strings;
 * `sessionToken` may be omitted.
 */
const sagemakerEndpoint = z.object({
	host: z.literal("sagemaker"),
	url: z.string().url(),
	// presumably AWS credentials for the SageMaker endpoint — confirm with the caller
	accessKey: z.string().min(1),
	secretKey: z.string().min(1),
	sessionToken: z.optional(z.string()),
});
/**
 * Schema for a TGI endpoint. `host` may be "tgi" or left out entirely.
 * `authorization` defaults to `Bearer ${HF_ACCESS_TOKEN}` when not supplied.
 */
const tgiEndpoint = z.object({
	host: z.literal("tgi").or(z.undefined()),
	url: z.string().url(),
	authorization: z.string().min(1).default(`Bearer ${HF_ACCESS_TOKEN}`),
});
/** Fields merged into every endpoint variant: an integer `weight` > 0, defaulting to 1. */
const commonEndpoint = z.object({
	weight: z.number().int().gt(0).default(1),
});
/**
 * Schema for a single endpoint entry. It accepts either the SageMaker variant or
 * the TGI variant (the TGI variant is also used when `host` is omitted), each
 * extended with the shared `commonEndpoint` fields.
 *
 * This is a plain union rather than `z.lazy(() => …)`: nothing here is
 * recursive, and `z.lazy` re-invokes its getter on every parse. That meant the
 * merged schemas and the union were rebuilt each time an endpoint was validated.
 */
const endpoint = z.union([
	sagemakerEndpoint.merge(commonEndpoint),
	tgiEndpoint.merge(commonEndpoint),
]);
/**
 * Endpoint schema used for a model's `endpoints` array.
 *
 * `endpoint` already validates each entry against exactly one of its variants
 * and applies that variant's defaults. The previous `.transform` re-parsed the
 * result with the matching variant schema, which repeated the same work and
 * produced the same value. Its "Invalid host" branch was unreachable, because the
 * union only admits "sagemaker", "tgi" or an omitted host. The schema is
 * therefore used directly; the parsed output type is unchanged.
 */
const combinedEndpoint = endpoint;
/** A string field for a special token; missing values become "". */
const tokenWithEmptyDefault = () => z.string().default("");

/**
 * Default Handlebars-style template that flattens a conversation into one
 * prompt, wrapping each message in the model's role tokens.
 */
const defaultChatPromptTemplate = [
	"{{preprompt}}",
	"{{#each messages}}",
	"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}",
	"{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}",
	"{{/each}}",
	"{{assistantMessageToken}}",
].join("");

/** Default template asking the model to summarize `{{answer}}` for `{{query}}`. */
const defaultWebSearchSummaryPromptTemplate = [
	"{{userMessageToken}}{{answer}}{{userMessageEndToken}}",
	"{{userMessageToken}}",
	"The text above should be summarized to best answer the query: {{query}}.",
	"{{userMessageEndToken}}",
	"{{assistantMessageToken}}Summary: ",
].join("");

/** Default template asking the model to turn the user's messages into a search query. */
const defaultWebSearchQueryPromptTemplate = [
	"{{userMessageToken}}",
	"The following messages were written by a user, trying to answer a question.",
	"{{userMessageEndToken}}",
	"{{#each messages}}",
	"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}",
	"{{/each}}",
	"{{userMessageToken}}",
	"What plain-text english sentence would you input into Google to answer the last question? Answer with a short (10 words max) simple sentence.",
	"{{userMessageEndToken}}",
	"{{assistantMessageToken}}Query: ",
].join("");

/** One example prompt: both `title` and `prompt` must be non-empty. */
const promptExampleSchema = z.object({
	title: z.string().min(1),
	prompt: z.string().min(1),
});

/**
 * Generation parameters. The four listed keys are checked; any extra keys are
 * kept as-is because of `.passthrough()`.
 */
const generationParametersSchema = z
	.object({
		temperature: z.number().min(0).max(1),
		truncate: z.number().int().positive(),
		max_new_tokens: z.number().int().positive(),
		stop: z.array(z.string()).optional(),
	})
	.passthrough();

/** Schema for a single entry of the `MODELS` JSON array. */
const modelConfigSchema = z.object({
	/** Used as an identifier in DB */
	id: z.string().optional(),
	/** Used to link to the model page, and for inference */
	name: z.string().min(1),
	displayName: z.string().min(1).optional(),
	description: z.string().min(1).optional(),
	websiteUrl: z.string().url().optional(),
	modelUrl: z.string().url().optional(),
	datasetName: z.string().min(1).optional(),
	datasetUrl: z.string().url().optional(),
	userMessageToken: tokenWithEmptyDefault(),
	userMessageEndToken: tokenWithEmptyDefault(),
	assistantMessageToken: tokenWithEmptyDefault(),
	assistantMessageEndToken: tokenWithEmptyDefault(),
	messageEndToken: tokenWithEmptyDefault(),
	preprompt: z.string().min(1).optional(),
	prepromptUrl: z.string().url().optional(),
	chatPromptTemplate: z.string().default(defaultChatPromptTemplate),
	webSearchSummaryPromptTemplate: z.string().default(defaultWebSearchSummaryPromptTemplate),
	webSearchQueryPromptTemplate: z.string().default(defaultWebSearchQueryPromptTemplate),
	promptExamples: promptExampleSchema.array().optional(),
	endpoints: z.array(combinedEndpoint).optional(),
	parameters: generationParametersSchema.optional(),
});

/** `MODELS` parsed as JSON and validated; throws if it is malformed. */
const modelsRaw = modelConfigSchema.array().parse(JSON.parse(MODELS));
/**
 * Downloads a preprompt from `url`.
 *
 * A non-2xx response throws instead of returning the error page body as the
 * preprompt. Network failures from `fetch` propagate unchanged.
 *
 * @param url - a model's `prepromptUrl`.
 * @returns the response body as text.
 * @throws Error when the response status is not OK.
 */
async function fetchPreprompt(url: string): Promise<string> {
	const response = await fetch(url);
	if (!response.ok) {
		throw new Error(
			`Failed to fetch preprompt from ${url}: ${response.status} ${response.statusText}`
		);
	}
	return response.text();
}

/**
 * The configured models, normalized for the rest of the app.
 *
 * - An empty `userMessageEndToken` / `assistantMessageEndToken` falls back to
 *   `messageEndToken`.
 * - `id` and `displayName` default to `name`.
 * - `preprompt` is downloaded from `prepromptUrl` when one is given.
 * - The chat, web-search summary and web-search query templates are compiled
 *   into `*Render` functions.
 *
 * All preprompt downloads run concurrently via `Promise.all`.
 */
export const models = await Promise.all(
	modelsRaw.map(async (m) => {
		// Resolve the end-token fallback before compiling. Previously the templates
		// were compiled against the raw `m`, so the fallback appeared on the model
		// object but not in the template context.
		// NOTE(review): assumes compileTemplate exposes its second argument as the
		// template root (the defaults read `{{@root.userMessageEndToken}}`) —
		// confirm in $lib/utils/template.
		const withEndTokens = {
			...m,
			userMessageEndToken: m.userMessageEndToken || m.messageEndToken,
			assistantMessageEndToken: m.assistantMessageEndToken || m.messageEndToken,
		};
		return {
			...withEndTokens,
			chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, withEndTokens),
			webSearchSummaryPromptRender: compileTemplate<WebSearchSummaryTemplateInput>(
				m.webSearchSummaryPromptTemplate,
				withEndTokens
			),
			webSearchQueryPromptRender: compileTemplate<WebSearchQueryTemplateInput>(
				m.webSearchQueryPromptTemplate,
				withEndTokens
			),
			id: m.id || m.name,
			displayName: m.displayName || m.name,
			preprompt: m.prepromptUrl ? await fetchPreprompt(m.prepromptUrl) : m.preprompt,
		};
	})
);
// Models that have been deprecated

/** Schema for one entry of the `OLD_MODELS` JSON array. */
const oldModelSchema = z.object({
	id: z.string().optional(),
	name: z.string().min(1),
	displayName: z.string().min(1).optional(),
});

/**
 * Parses and validates the `OLD_MODELS` JSON. Missing `id`/`displayName`
 * fall back to `name`.
 */
function parseOldModels(raw: string) {
	return oldModelSchema
		.array()
		.parse(JSON.parse(raw))
		.map((entry) => {
			const id = entry.id || entry.name;
			const displayName = entry.displayName || entry.name;
			return { ...entry, id, displayName };
		});
}

/** Deprecated models; an empty list when `OLD_MODELS` is unset or empty. */
export const oldModels = OLD_MODELS ? parseOldModels(OLD_MODELS) : [];
/** An element of `models`, with `preprompt` made optional. */
export type BackendModel = Optional<(typeof models)[number], "preprompt">;
/** The parsed (post-default) shape of one endpoint entry. */
export type Endpoint = z.output<typeof endpoint>;
/** The first entry of `models`. */
export const [defaultModel] = models;
/**
 * Builds a zod enum schema that accepts exactly the `id`s of the given models,
 * in the order given.
 *
 * @param _models - the models whose ids are the allowed values.
 * @returns a `z.enum` schema over those ids.
 * @throws Error when `_models` is empty. Previously this surfaced as an opaque
 *   TypeError from reading `_models[0].id`.
 */
export const validateModel = (_models: BackendModel[]) => {
	if (_models.length === 0) {
		throw new Error("validateModel: at least one model is required to build the enum");
	}
	// z.enum is typed to take a non-empty tuple ([first, ...rest]), so the first
	// id is written out explicitly rather than passing a plain string[].
	return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]);
};
|