File size: 3,863 Bytes
21dd449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

/**
 * `ApiModelInfo` fields that `listModels` always requests: each key is sent as an
 * `expand` query parameter, in this order, on every listing request.
 *
 * These are the fields `listModels` reads to build a `ModelEntry`.
 */
export const MODEL_EXPAND_KEYS = [
	"pipeline_tag",
	"private",
	"gated",
	"downloads",
	"likes",
	"lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];

/**
 * `ApiModelInfo` keys that can be requested through the `expand` query parameter.
 *
 * The keys listed here that are not already part of `MODEL_EXPAND_KEYS` are the values
 * accepted by the `additionalFields` option of `listModels`.
 */
export const MODEL_EXPANDABLE_KEYS = [
	"author",
	"cardData",
	"config",
	"createdAt",
	"disabled",
	"downloads",
	"downloadsAllTime",
	"gated",
	"gitalyUid",
	"inferenceProviderMapping",
	"lastModified",
	"library_name",
	"likes",
	"model-index",
	"pipeline_tag",
	"private",
	"safetensors",
	"sha",
	// "siblings",
	"spaces",
	"tags",
	"transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];

/**
 * Shape of the entries yielded by `listModels`, built from an `ApiModelInfo` item.
 */
export interface ModelEntry {
	/** Taken from the API item's `_id` field. */
	id: string;
	/** Taken from the API item's `id` field. */
	name: string;
	/** Taken from the API item's `private` field. */
	private: boolean;
	/** Taken from the API item's `gated` field. */
	gated: false | "auto" | "manual";
	/** Taken from the API item's `pipeline_tag` field. */
	task?: PipelineType;
	/** Taken from the API item's `likes` field. */
	likes: number;
	/** Taken from the API item's `downloads` field. */
	downloads: number;
	/** The API item's `lastModified` value, converted to a `Date`. */
	updatedAt: Date;
}

/**
 * Lazily iterate over the models returned by the hub's `/api/models` endpoint.
 *
 * Results are fetched page by page, requesting at most 500 entries per page. The next page
 * is only requested once the current one has been fully consumed, using the `next` URL
 * parsed from the response's `Link` header. Iteration stops when there is no `next` URL or
 * when `limit` entries have been yielded.
 *
 * @param params.search - Optional filters, forwarded as the `search`, `author`,
 *   `pipeline_tag`, `filter` (one per tag) and `inference_provider` query parameters.
 * @param params.hubUrl - Base URL of the hub; falls back to `HUB_URL` when unset or empty.
 * @param params.additionalFields - Extra `ApiModelInfo` keys to request via `expand`; they
 *   are copied onto each yielded entry.
 * @param params.limit - Maximum number of entries to yield. A zero or negative value yields
 *   nothing and sends no request. Defaults to no limit.
 * @param params.fetch - Custom fetch implementation used instead of the global `fetch`.
 * @yields One `ModelEntry` per model, plus the requested `additionalFields`.
 * @throws The error produced by `createApiError` when a response is not OK.
 */
export async function* listModels<
	const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never,
>(
	params?: {
		search?: {
			/**
			 * Will search in the model name for matches
			 */
			query?: string;
			owner?: string;
			task?: PipelineType;
			tags?: string[];
			/**
			 * Will search for models that have one of the inference providers in the list.
			 */
			inferenceProviders?: string[];
		};
		hubUrl?: string;
		additionalFields?: T[];
		/**
		 * Set to limit the number of models returned.
		 */
		limit?: number;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> {
	const accessToken = params && checkCredentials(params);
	let totalToFetch = params?.limit ?? Infinity;

	// Nothing to yield for a zero or negative limit. Without this guard a request was still
	// sent and the first returned item was yielded, because the counter is only checked
	// after each yield.
	if (totalToFetch <= 0) {
		return;
	}

	const search = new URLSearchParams([
		...Object.entries({
			// Never ask for more than 500 entries per page, nor more than the caller wants.
			limit: String(Math.min(totalToFetch, 500)),
			...(params?.search?.owner ? { author: params.search.owner } : undefined),
			...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
			...(params?.search?.query ? { search: params.search.query } : undefined),
			...(params?.search?.inferenceProviders
				? { inference_provider: params.search.inferenceProviders.join(",") }
				: undefined),
		}),
		// Repeated keys: one `filter` per tag, one `expand` per requested field.
		...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
		...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
		...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
	]).toString();
	let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

	while (url) {
		const res: Response = await (params?.fetch ?? fetch)(url, {
			headers: {
				accept: "application/json",
				...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
			},
		});

		if (!res.ok) {
			throw await createApiError(res);
		}

		const items: ApiModelInfo[] = await res.json();

		for (const item of items) {
			yield {
				...(params?.additionalFields && pick(item, params.additionalFields)),
				id: item._id,
				name: item.id,
				private: item.private,
				task: item.pipeline_tag,
				downloads: item.downloads,
				gated: item.gated,
				likes: item.likes,
				updatedAt: new Date(item.lastModified),
			} as ModelEntry & Pick<ApiModelInfo, T>;
			totalToFetch--;

			// Stop as soon as the requested number of entries has been yielded, even mid-page.
			if (totalToFetch <= 0) {
				return;
			}
		}

		const linkHeader = res.headers.get("Link");

		// No `Link` header, or no `next` relation in it, ends the iteration.
		url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
		// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
	}
}