import { isUrl } from "./isUrl";

/**
 * We want to make as few calls to the Hugging Face Hub as possible: for
 * example, if someone is calling Inference Endpoints 1,000 times per second,
 * we don't want to make 1,000 calls to the hub just to get the task name.
 */
const taskCache = new Map<string, { task: string; date: Date }>();
const CACHE_DURATION = 10 * 60 * 1000; // 10 minutes, in milliseconds
const MAX_CACHE_ITEMS = 1000; // beyond this, the oldest entry is evicted
export const HF_HUB_URL = "https://huggingface.co";
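
/**
 * Options accepted by `getDefaultTask`. The optional `fetch` lets callers
 * inject a custom fetch implementation (e.g. for testing or instrumentation).
 */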
export interface DefaultTaskOptions {
	fetch?: typeof fetch;
}

/**
 * Get the default task for a model. Uses an LRU cache of 1,000 items with a
 * 10-minute expiration to avoid making too many calls to the HF hub.
 *
 * @returns The default task for the model, or `null` if it was impossible to get it
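 *
 * @example
 * // Illustrative only: the result depends on the `pipeline_tag` the hub
 * // reports for the model at call time.
 * const task = await getDefaultTask("gpt2", undefined);
 * if (task === "text-generation") {
 * 	// route the request to a text-generation pipeline
 * }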
 */
export async function getDefaultTask(
	model: string,
	accessToken: string | undefined,
	options?: DefaultTaskOptions
): Promise<string | null> {
	// URLs point at deployed endpoints rather than hub model IDs, so there is
	// no task metadata to look up.
	if (isUrl(model)) {
		return null;
	}

	// The access token is part of the cache key so callers with different
	// permissions never share cached results.
	const key = `${model}:${accessToken}`;
	let cachedTask = taskCache.get(key);

	// Expire the cached entry if it is older than CACHE_DURATION.
	if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
		taskCache.delete(key);
		cachedTask = undefined;
	}

	if (cachedTask === undefined) {
		// Cache miss: ask the hub for the model's `pipeline_tag`, using the
		// caller-provided fetch if one was injected.
		const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
			headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
		})
			.then((resp) => resp.json())
			.then((json) => json.pipeline_tag)
			.catch(() => null);

		if (!modelTask) {
			return null;
		}

		cachedTask = { task: modelTask, date: new Date() };
		taskCache.set(key, cachedTask);

		// `Map` iterates keys in insertion order, so the first key is the
		// oldest entry; deleting it keeps the cache at MAX_CACHE_ITEMS.
		if (taskCache.size > MAX_CACHE_ITEMS) {
			taskCache.delete(taskCache.keys().next().value);
		}
	}

	return cachedTask.task;
}
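
// A minimal usage sketch, not part of the original module: `exampleResolveTask`
// and `loggingFetch` are illustrative names, and "gpt2" is only an example
// model ID. It shows injecting a custom fetch via DefaultTaskOptions to
// observe hub traffic while resolving a model's default task.
async function exampleResolveTask(): Promise<void> {
	const loggingFetch: typeof fetch = (input, init) => {
		console.log("hub request:", input);
		return fetch(input, init);
	};
	const task = await getDefaultTask("gpt2", undefined, { fetch: loggingFetch });
	console.log(task); // e.g. "text-generation", or null if the lookup failed
}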