import type { ConversationEntityMembers } from "$lib/state/conversations.svelte";
import type { ChatCompletionInputMessage, GenerationParameters, PipelineType } from "@huggingface/tasks";
import {
	HF_HUB_INFERENCE_PROXY_TEMPLATE,
	type InferenceSnippet,
	type ModelDataMinimal,
	type SnippetInferenceProvider,
	getModelInputSnippet,
	stringifyGenerationConfig,
	stringifyMessages,
} from "@huggingface/tasks";

export const snippetBasic = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
): InferenceSnippet[] => {
	if (provider !== "hf-inference") {
		return [];
	}
	return [
		{
			client: "curl",
			content: `\
curl https://router.huggingface.co/hf-inference/models/${model.id} \\
	-X POST \\
	-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
	-H 'Content-Type: application/json' \\
	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
		},
	];
};
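
// A sketch of roughly what the snippet above renders to; the example input comes
// from getModelInputSnippet and varies by task, so all values here are illustrative:
//
//   curl https://router.huggingface.co/hf-inference/models/<model-id> \
//     -X POST \
//     -d '{"inputs": "<task example input>"}' \
//     -H 'Content-Type: application/json' \
//     -H 'Authorization: Bearer {API_TOKEN}'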

export const snippetTextGeneration = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider,
	providerModelId?: string,
	opts?: {
		streaming?: boolean;
		messages?: ChatCompletionInputMessage[];
		temperature?: GenerationParameters["temperature"];
		max_tokens?: GenerationParameters["max_tokens"];
		top_p?: GenerationParameters["top_p"];
		structured_output?: ConversationEntityMembers["structuredOutput"];
	}
): InferenceSnippet[] => {
	if (model.tags.includes("conversational")) {
		const baseUrl =
			provider === "hf-inference"
				? `https://router.huggingface.co/hf-inference/models/${model.id}/v1/chat/completions`
				: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
		const modelId = providerModelId ?? model.id;
		// Conversational model detected, so we display a code snippet that features the Messages API
		const streaming = opts?.streaming ?? true;
		const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
		const messages = opts?.messages ?? exampleMessages;
		const config = {
			// Check against undefined rather than truthiness, so a value of 0 is not silently dropped.
			...(opts?.temperature !== undefined ? { temperature: opts.temperature } : undefined),
			max_tokens: opts?.max_tokens ?? 500,
			...(opts?.top_p !== undefined ? { top_p: opts.top_p } : undefined),
			...(opts?.structured_output?.enabled
				? {
						response_format: JSON.stringify(
							{
								type: "json_schema",
								// Fall back to "{}" so JSON.parse never throws on a missing or empty schema.
								json_schema: JSON.parse(opts.structured_output.schema || "{}"),
							},
							null,
							6
						),
					}
				: undefined),
		};
		return [
			{
				client: "curl",
				content: `curl '${baseUrl}' \\
-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\
-H 'Content-Type: application/json' \\
--data '{
    "model": "${modelId}",
    "messages": ${stringifyMessages(messages, {
			indent: "\t",
			attributeKeyQuotes: true,
			customContentEscaper: str => str.replace(/'/g, "'\\''"),
		})},
    ${stringifyGenerationConfig(config, {
			indent: "\n    ",
			attributeKeyQuotes: true,
			attributeValueConnector: ": ",
		})}
    "stream": ${!!streaming}
}'`,
			},
		];
	} else {
		return snippetBasic(model, accessToken, provider);
	}
};
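
// A sketch of roughly what the conversational branch renders when structured output
// is enabled; the model id, messages, and schema are illustrative placeholders:
//
//   curl 'https://router.huggingface.co/hf-inference/models/<model-id>/v1/chat/completions' \
//     -H 'Authorization: Bearer {API_TOKEN}' \
//     -H 'Content-Type: application/json' \
//     --data '{
//         "model": "<model-id>",
//         "messages": [{"role": "user", "content": "What is the capital of France?"}],
//         "max_tokens": 500,
//         "response_format": {"type": "json_schema", "json_schema": {...}},
//         "stream": true
//     }'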

export const snippetZeroShotClassification = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
): InferenceSnippet[] => {
	if (provider !== "hf-inference") {
		return [];
	}
	return [
		{
			client: "curl",
			content: `curl https://router.huggingface.co/hf-inference/models/${model.id} \\
	-X POST \\
	-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
	-H 'Content-Type: application/json' \\
	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
		},
	];
};

export const snippetFile = (
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider
): InferenceSnippet[] => {
	if (provider !== "hf-inference") {
		return [];
	}
	return [
		{
			client: "curl",
			content: `curl https://router.huggingface.co/hf-inference/models/${model.id} \\
	-X POST \\
	--data-binary '@${getModelInputSnippet(model, true, true)}' \\
	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
		},
	];
};
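
// A sketch of the rendered command for a file-input task such as automatic speech
// recognition: the request body is the raw bytes of a local file whose name comes
// from getModelInputSnippet (the file name below is illustrative):
//
//   curl https://router.huggingface.co/hf-inference/models/<model-id> \
//     -X POST \
//     --data-binary '@sample.flac' \
//     -H 'Authorization: Bearer {API_TOKEN}'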

export const curlSnippets: Partial<
	Record<
		PipelineType,
		(
			model: ModelDataMinimal,
			accessToken: string,
			provider: SnippetInferenceProvider,
			providerModelId?: string,
			opts?: Record<string, unknown>
		) => InferenceSnippet[]
	>
> = {
	// Same order as in tasks/src/pipelines.ts
	"text-classification": snippetBasic,
	"token-classification": snippetBasic,
	"table-question-answering": snippetBasic,
	"question-answering": snippetBasic,
	"zero-shot-classification": snippetZeroShotClassification,
	"translation": snippetBasic,
	"summarization": snippetBasic,
	"feature-extraction": snippetBasic,
	"text-generation": snippetTextGeneration,
	"image-text-to-text": snippetTextGeneration,
	"text2text-generation": snippetBasic,
	"fill-mask": snippetBasic,
	"sentence-similarity": snippetBasic,
	"automatic-speech-recognition": snippetFile,
	"text-to-image": snippetBasic,
	"text-to-speech": snippetBasic,
	"text-to-audio": snippetBasic,
	"audio-to-audio": snippetFile,
	"audio-classification": snippetFile,
	"image-classification": snippetFile,
	"image-to-text": snippetFile,
	"object-detection": snippetFile,
	"image-segmentation": snippetFile,
};

export function getCurlInferenceSnippet(
	model: ModelDataMinimal,
	accessToken: string,
	provider: SnippetInferenceProvider,
	providerModelId?: string,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	return model.pipeline_tag && model.pipeline_tag in curlSnippets
		? (curlSnippets[model.pipeline_tag]?.(model, accessToken, provider, providerModelId, opts) ?? [])
		: [];
}
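
// Minimal usage sketch. The model object is hypothetical and only carries the
// ModelDataMinimal fields this module actually reads (id, pipeline_tag, tags);
// extend or cast the literal if the type requires more fields:
//
// const [snippet] = getCurlInferenceSnippet(
// 	{ id: "<namespace>/<model-name>", pipeline_tag: "text-generation", tags: ["conversational"] },
// 	"{API_TOKEN}",
// 	"hf-inference"
// );
// console.log(snippet?.content);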