File size: 1,734 Bytes
94753b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import { isUrl } from "./isUrl";

/**
 * Process-wide cache of resolved tasks, keyed by `${model}:${accessToken}`.
 *
 * Calls to the Hugging Face Hub should be kept to a minimum: if someone calls
 * Inference Endpoints 1000 times per second, we don't want to send 1000
 * requests to the Hub just to look up the task name.
 */
const taskCache: Map<string, { task: string; date: Date }> = new Map();
/** How long a cached task stays valid, in milliseconds (10 minutes). */
const CACHE_DURATION = 10 * 60 * 1_000;
/** Maximum number of entries kept in `taskCache`. */
const MAX_CACHE_ITEMS = 1_000;
/** Base URL of the Hugging Face Hub. */
export const HF_HUB_URL = "https://huggingface.co";

/**
 * Options accepted by `getDefaultTask`.
 */
export interface DefaultTaskOptions {
	/** Custom `fetch` implementation used to query the Hub; the global `fetch` is used when omitted. */
	fetch?: typeof globalThis.fetch;
}

/**
 * Ask the Hub for a model's `pipeline_tag`.
 *
 * Best effort: network errors, non-JSON bodies, a synchronously throwing
 * `fetch`, and responses without a non-empty string `pipeline_tag` all
 * resolve to `null`.
 */
async function fetchPipelineTag(
	model: string,
	accessToken: string | undefined,
	fetchFn: typeof fetch
): Promise<string | null> {
	try {
		const resp = await fetchFn(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
			headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
		});
		const json: unknown = await resp.json();
		const tag =
			typeof json === "object" && json !== null ? (json as { pipeline_tag?: unknown }).pipeline_tag : undefined;
		// Only accept a real, non-empty string so the `string | null` contract holds.
		return typeof tag === "string" && tag !== "" ? tag : null;
	} catch {
		return null;
	}
}

/**
 * Get the default task (the Hub `pipeline_tag`) for a model. Uses an LRU cache
 * of at most `MAX_CACHE_ITEMS` (1000) entries, each valid for `CACHE_DURATION`
 * (10 minutes), to avoid making too many calls to the HF Hub.
 *
 * Failed lookups are not cached, so they are retried on the next call.
 *
 * @param model - Model id on the Hub; if `model` is a URL, no lookup is made.
 * @param accessToken - Optional token sent as a `Bearer` authorization header; also part of the cache key.
 * @param options - Optional overrides, e.g. a custom `fetch` implementation.
 * @returns The default task for the model, or `null` if it was impossible to get it
 */
export async function getDefaultTask(
	model: string,
	accessToken: string | undefined,
	options?: DefaultTaskOptions
): Promise<string | null> {
	if (isUrl(model)) {
		return null;
	}

	const key = `${model}:${accessToken}`;
	const cached = taskCache.get(key);

	if (cached !== undefined) {
		if (cached.date.getTime() >= Date.now() - CACHE_DURATION) {
			// Re-insert so the entry moves to the end of the Map's insertion order;
			// eviction below then drops the least recently *used* entry (true LRU).
			taskCache.delete(key);
			taskCache.set(key, cached);
			return cached.task;
		}
		// Expired: drop it and fall through to a fresh lookup.
		taskCache.delete(key);
	}

	const task = await fetchPipelineTag(model, accessToken, options?.fetch ?? fetch);
	if (task === null) {
		return null;
	}

	taskCache.set(key, { task, date: new Date() });

	if (taskCache.size > MAX_CACHE_ITEMS) {
		// Map iterates in insertion order, so the first key is the least recently used.
		const oldest = taskCache.keys().next();
		if (!oldest.done) {
			taskCache.delete(oldest.value);
		}
	}

	return task;
}