machineuser
Sync widgets demo
5ee3e16
raw
history blame contribute delete
No virus
5.34 kB
import type { PipelineType } from "../pipelines.js";
import { getModelInputSnippet } from "./inputs.js";
import type { ModelDataMinimal } from "./types.js";
/**
 * Python snippet body for zero-shot text classification.
 *
 * POSTs the model's example input together with a fixed example set of
 * `candidate_labels` and returns the decoded JSON response. The generated code
 * uses `requests`, `API_URL` and `headers`, which are defined by the header that
 * `getPythonInferenceSnippet` prepends.
 *
 * @param model - model whose example input is embedded in the snippet
 * @returns Python source code (the `def` body is indented so it is valid Python)
 */
export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
	`def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
    "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;
/**
 * Python snippet body for zero-shot image classification.
 *
 * Reads the local image named by the example input, base64-encodes it and POSTs it
 * along with a fixed example set of `candidate_labels`. The snippet imports `base64`
 * itself, because the shared header emitted by `getPythonInferenceSnippet` only
 * imports `requests`. Without that import the generated code failed with NameError.
 *
 * @param model - model whose example input (an image path) is embedded in the snippet
 * @returns Python source code
 */
export const snippetZeroShotImageClassification = (model: ModelDataMinimal): string =>
	`import base64

def query(data):
    with open(data["image_path"], "rb") as f:
        img = f.read()
    payload = {
        "parameters": data["parameters"],
        "inputs": base64.b64encode(img).decode("utf-8"),
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "image_path": ${getModelInputSnippet(model)},
    "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
})`;
/**
 * Generic Python snippet body: POSTs `{"inputs": <example input>}` as JSON and
 * returns the decoded JSON response.
 *
 * Relies on `requests`, `API_URL` and `headers` from the header that
 * `getPythonInferenceSnippet` prepends.
 *
 * @param model - model whose example input is embedded in the snippet
 * @returns Python source code (the `def` body is indented so it is valid Python)
 */
export const snippetBasic = (model: ModelDataMinimal): string =>
	`def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
})`;
/**
 * Python snippet body for file-based tasks: reads the local file named by the
 * example input and POSTs its raw bytes (not JSON), returning the decoded JSON response.
 *
 * @param model - model whose example input (a file name literal) is passed to `query`
 * @returns Python source code (the `def`/`with` bodies are indented so it is valid Python)
 */
export const snippetFile = (model: ModelDataMinimal): string =>
	`def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data)
    return response.json()

output = query(${getModelInputSnippet(model)})`;
/**
 * Python snippet body for text-to-image: POSTs the prompt as JSON and returns the
 * raw response bytes, then shows how to open them with PIL.
 *
 * @param model - model whose example input (the prompt) is embedded in the snippet
 * @returns Python source code (the `def` body is indented so it is valid Python)
 */
export const snippetTextToImage = (model: ModelDataMinimal): string =>
	`def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

image_bytes = query({
    "inputs": ${getModelInputSnippet(model)},
})
# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`;
/**
 * Python snippet body for tabular tasks: wraps the example input under a `"data"`
 * key inside `"inputs"` and returns the raw response bytes.
 *
 * @param model - model whose example input (the table data) is embedded in the snippet
 * @returns Python source code (the `def` body is indented so it is valid Python)
 */
export const snippetTabular = (model: ModelDataMinimal): string =>
	`def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

response = query({
    "inputs": {"data": ${getModelInputSnippet(model)}},
})`;
/**
 * Python snippet body for text-to-speech / text-to-audio.
 *
 * The Transformers TTS pipeline and the api-inference-community (AIC) pipeline outputs
 * diverged with the latest update to the Inference API (IA). Transformers returns a byte
 * object (wav file), whereas AIC returns wav and sampling_rate. The snippet therefore
 * branches on `model.library_name`:
 * - for "transformers" it keeps the raw response bytes;
 * - otherwise it unpacks the JSON response into `audio, sampling_rate`.
 *
 * @param model - model whose `library_name` selects the variant and whose example input is embedded
 * @returns Python source code (the `def` bodies are indented so it is valid Python)
 */
export const snippetTextToAudio = (model: ModelDataMinimal): string => {
	if (model.library_name === "transformers") {
		return `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

audio_bytes = query({
    "inputs": ${getModelInputSnippet(model)},
})
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio_bytes)`;
	}
	return `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

audio, sampling_rate = query({
    "inputs": ${getModelInputSnippet(model)},
})
# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio, rate=sampling_rate)`;
};
/**
 * Python snippet body for document question answering.
 *
 * The payload passed to `query` has the shape `{"inputs": <example input>}`, so the image
 * path is read from `payload["inputs"]["image"]` and replaced in place with the base64
 * content. The original code read `payload["image"]`, a key the top-level dict never has,
 * so it raised KeyError. It also used `base64` without importing it (the shared header
 * only imports `requests`).
 * NOTE(review): assumes the example input is a dict with an "image" path key, as the
 * original `query` implied — confirm against getModelInputSnippet.
 *
 * @param model - model whose example input is embedded in the snippet
 * @returns Python source code
 */
export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): string =>
	`import base64

def query(payload):
    inputs = payload["inputs"]
    with open(inputs["image"], "rb") as f:
        img = f.read()
    inputs["image"] = base64.b64encode(img).decode("utf-8")
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
})`;
/**
 * Registry of Python snippet renderers, keyed by pipeline tag.
 *
 * Each renderer returns only the task-specific part of the snippet. The shared header
 * (`import requests`, `API_URL`, `headers`) is prepended by `getPythonInferenceSnippet`.
 * Pipeline tags that are absent from this map have no Python snippet, which is what
 * `hasPythonInferenceSnippet` reports. Several tags share a renderer (e.g. most
 * text tasks use `snippetBasic`, file-upload tasks use `snippetFile`).
 */
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal) => string>> = {
// Same order as in tasks/src/pipelines.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
translation: snippetBasic,
summarization: snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"automatic-speech-recognition": snippetFile,
"text-to-image": snippetTextToImage,
"text-to-speech": snippetTextToAudio,
"text-to-audio": snippetTextToAudio,
"audio-to-audio": snippetFile,
"audio-classification": snippetFile,
"image-classification": snippetFile,
"tabular-regression": snippetTabular,
"tabular-classification": snippetTabular,
"object-detection": snippetFile,
"image-segmentation": snippetFile,
"document-question-answering": snippetDocumentQuestionAnswering,
"image-to-text": snippetFile,
"zero-shot-image-classification": snippetZeroShotImageClassification,
};
/**
 * Builds a complete Python snippet that calls the Inference API for `model`.
 *
 * The output starts with a header (`import requests`, `API_URL`, `headers`), followed by
 * the task-specific body rendered by `pythonSnippets[model.pipeline_tag]`. When the model
 * has no `pipeline_tag`, or no renderer is registered for it, only the header is returned.
 *
 * @param model - model whose `id` selects the endpoint URL and whose `pipeline_tag` selects the body
 * @param accessToken - token embedded verbatim in the Authorization header. When empty, the
 *   header instead uses a Python f-string referencing `API_TOKEN`, which the snippet itself
 *   does not define.
 * @returns the Python source code as a single string
 */
export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
	const tag = model.pipeline_tag;
	// Own-property check: the `in` operator would also match inherited Object.prototype
	// members (e.g. "constructor", "toString"), which are not snippet renderers.
	const render =
		tag && Object.prototype.hasOwnProperty.call(pythonSnippets, tag) ? pythonSnippets[tag] : undefined;
	const body = render?.(model) ?? "";
	const authorization = accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`;
	return `import requests
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${authorization}}
${body}`;
}
/**
 * Reports whether `getPythonInferenceSnippet` would render a task-specific body for `model`,
 * i.e. whether its `pipeline_tag` is set and registered in `pythonSnippets`.
 *
 * Uses an own-property check so inherited Object.prototype names (e.g. "toString") are
 * not mistaken for registered pipeline tags, as they were with the `in` operator.
 *
 * @param model - model whose `pipeline_tag` is looked up
 * @returns true when a Python snippet renderer exists for the model's pipeline tag
 */
export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {
	return !!model.pipeline_tag && Object.prototype.hasOwnProperty.call(pythonSnippets, model.pipeline_tag);
}