import type { PipelineType } from "../pipelines.js";
import { getModelInputSnippet } from "./inputs.js";
import type { ModelDataMinimal } from "./types.js";

/**
 * Python snippet for zero-shot text classification.
 *
 * Posts a JSON payload whose "inputs" value is whatever `getModelInputSnippet(model)` returns,
 * together with a fixed example list of candidate labels.
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
  `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
    "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;

/**
 * Python snippet for zero-shot image classification.
 *
 * The generated `query` reads the file at `data["image_path"]`, base64-encodes it and posts it
 * as "inputs" next to the caller-supplied "parameters".
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetZeroShotImageClassification = (model: ModelDataMinimal): string =>
  // `import base64` is emitted here because the shared header only imports `requests`.
  `import base64

def query(data):
    with open(data["image_path"], "rb") as f:
        img = f.read()
    payload={
        "parameters": data["parameters"],
        "inputs": base64.b64encode(img).decode("utf-8")
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "image_path": ${getModelInputSnippet(model)},
    "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
})`;

/**
 * Generic Python snippet: posts `{"inputs": ...}` as JSON and returns the decoded JSON response.
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetBasic = (model: ModelDataMinimal): string =>
  `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
})`;

/**
 * Python snippet for tasks whose input is a file: the file's raw bytes are sent as the request body.
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetFile = (model: ModelDataMinimal): string =>
  `def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data)
    return response.json()

output = query(${getModelInputSnippet(model)})`;

/**
 * Python snippet for text-to-image: returns the raw response bytes and shows how to open them with PIL.
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetTextToImage = (model: ModelDataMinimal): string =>
  `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

image_bytes = query({
    "inputs": ${getModelInputSnippet(model)},
})
# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`;

/**
 * Python snippet for tabular tasks: wraps the model input in `{"data": ...}` and returns the raw
 * response bytes (`response.content`).
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetTabular = (model: ModelDataMinimal): string =>
  `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

response = query({
    "inputs": {"data": ${getModelInputSnippet(model)}},
})`;

/**
 * Python snippet for text-to-speech / text-to-audio.
 *
 * Two variants exist because the response format depends on `model.library_name`.
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetTextToAudio = (model: ModelDataMinimal): string => {
  // Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs have diverged
  // with the latest update to inference-api (IA).
  // Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
  if (model.library_name === "transformers") {
    return `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

audio_bytes = query({
    "inputs": ${getModelInputSnippet(model)},
})
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio_bytes)`;
  }

  return `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

audio, sampling_rate = query({
    "inputs": ${getModelInputSnippet(model)},
})
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio, rate=sampling_rate)`;
};

/**
 * Python snippet for document question answering.
 *
 * The generated `query` opens the path stored under `payload["image"]`, replaces that entry with
 * the base64-encoded file content and posts the payload as JSON.
 *
 * NOTE(review): `query` reads `payload["image"]`, but the call below passes a dict whose only
 * top-level key is "inputs". Confirm against the shape returned by `getModelInputSnippet`.
 *
 * @param model - Model the snippet is generated for.
 * @returns Python source that relies on `requests`, `API_URL` and `headers` being defined beforehand.
 */
export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): string =>
  // `import base64` is emitted here because the shared header only imports `requests`.
  `import base64

def query(payload):
    with open(payload["image"], "rb") as f:
        img = f.read()
    payload["image"] = base64.b64encode(img).decode("utf-8")
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
})`;

/**
 * Maps a pipeline tag to the function that builds the Python snippet body for it.
 * Pipeline tags without an entry have no Python snippet.
 */
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal) => string>> = {
  // Same order as in tasks/src/pipelines.ts
  "text-classification": snippetBasic,
  "token-classification": snippetBasic,
  "table-question-answering": snippetBasic,
  "question-answering": snippetBasic,
  "zero-shot-classification": snippetZeroShotClassification,
  translation: snippetBasic,
  summarization: snippetBasic,
  "feature-extraction": snippetBasic,
  "text-generation": snippetBasic,
  "text2text-generation": snippetBasic,
  "fill-mask": snippetBasic,
  "sentence-similarity": snippetBasic,
  "automatic-speech-recognition": snippetFile,
  "text-to-image": snippetTextToImage,
  "text-to-speech": snippetTextToAudio,
  "text-to-audio": snippetTextToAudio,
  "audio-to-audio": snippetFile,
  "audio-classification": snippetFile,
  "image-classification": snippetFile,
  "tabular-regression": snippetTabular,
  "tabular-classification": snippetTabular,
  "object-detection": snippetFile,
  "image-segmentation": snippetFile,
  "document-question-answering": snippetDocumentQuestionAnswering,
  "image-to-text": snippetFile,
  "zero-shot-image-classification": snippetZeroShotImageClassification,
};

/**
 * Looks up the snippet builder registered for the model's pipeline tag.
 *
 * Uses an own-property check rather than `in`, so names inherited from `Object.prototype`
 * (e.g. "constructor", "toString") are never treated as registered pipeline tags.
 */
function findPythonSnippet(model: ModelDataMinimal): ((model: ModelDataMinimal) => string) | undefined {
  const tag = model.pipeline_tag;
  if (!tag || !Object.prototype.hasOwnProperty.call(pythonSnippets, tag)) {
    return undefined;
  }
  return pythonSnippets[tag];
}

/**
 * Builds the complete Python snippet for a model: a header that imports `requests` and defines
 * `API_URL` and `headers`, followed by the task-specific body.
 *
 * @param model - Model to generate the snippet for; `model.id` is interpolated into the API URL.
 * @param accessToken - Token inlined into the Authorization header. When empty, the snippet
 *   references an `API_TOKEN` Python variable instead, which the snippet itself does not define.
 * @returns The Python source. The body is empty when no snippet is registered for the pipeline tag.
 */
export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
  const body = findPythonSnippet(model)?.(model) ?? "";
  const authorization = accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`;

  return `import requests

API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${authorization}}

${body}`;
}

/**
 * Tells whether a Python snippet is registered for the model's pipeline tag.
 *
 * @param model - Model to check.
 * @returns `true` when `model.pipeline_tag` is set and is a key of `pythonSnippets`.
 */
export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {
  return findPythonSnippet(model) !== undefined;
}