|
import { HUB_URL } from "../consts"; |
|
import { createApiError } from "../error"; |
|
import type { ApiModelInfo } from "../types/api/api-model"; |
|
import type { CredentialsParams, PipelineType } from "../types/public"; |
|
import { checkCredentials } from "../utils/checkCredentials"; |
|
import { parseLinkHeader } from "../utils/parseLinkHeader"; |
|
import { pick } from "../utils/pick"; |
|
|
|
export const MODEL_EXPAND_KEYS = [ |
|
"pipeline_tag", |
|
"private", |
|
"gated", |
|
"downloads", |
|
"likes", |
|
"lastModified", |
|
] as const satisfies readonly (keyof ApiModelInfo)[]; |
|
|
|
export const MODEL_EXPANDABLE_KEYS = [ |
|
"author", |
|
"cardData", |
|
"config", |
|
"createdAt", |
|
"disabled", |
|
"downloads", |
|
"downloadsAllTime", |
|
"gated", |
|
"gitalyUid", |
|
"inferenceProviderMapping", |
|
"lastModified", |
|
"library_name", |
|
"likes", |
|
"model-index", |
|
"pipeline_tag", |
|
"private", |
|
"safetensors", |
|
"sha", |
|
|
|
"spaces", |
|
"tags", |
|
"transformersInfo", |
|
] as const satisfies readonly (keyof ApiModelInfo)[]; |
|
|
|
export interface ModelEntry { |
|
id: string; |
|
name: string; |
|
private: boolean; |
|
gated: false | "auto" | "manual"; |
|
task?: PipelineType; |
|
likes: number; |
|
downloads: number; |
|
updatedAt: Date; |
|
} |
|
|
|
export async function* listModels< |
|
const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never, |
|
>( |
|
params?: { |
|
search?: { |
|
|
|
|
|
|
|
query?: string; |
|
owner?: string; |
|
task?: PipelineType; |
|
tags?: string[]; |
|
|
|
|
|
|
|
inferenceProviders?: string[]; |
|
}; |
|
hubUrl?: string; |
|
additionalFields?: T[]; |
|
|
|
|
|
|
|
limit?: number; |
|
|
|
|
|
|
|
fetch?: typeof fetch; |
|
} & Partial<CredentialsParams> |
|
): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> { |
|
const accessToken = params && checkCredentials(params); |
|
let totalToFetch = params?.limit ?? Infinity; |
|
const search = new URLSearchParams([ |
|
...Object.entries({ |
|
limit: String(Math.min(totalToFetch, 500)), |
|
...(params?.search?.owner ? { author: params.search.owner } : undefined), |
|
...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined), |
|
...(params?.search?.query ? { search: params.search.query } : undefined), |
|
...(params?.search?.inferenceProviders |
|
? { inference_provider: params.search.inferenceProviders.join(",") } |
|
: undefined), |
|
}), |
|
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []), |
|
...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), |
|
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []), |
|
]).toString(); |
|
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`; |
|
|
|
while (url) { |
|
const res: Response = await (params?.fetch ?? fetch)(url, { |
|
headers: { |
|
accept: "application/json", |
|
...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined), |
|
}, |
|
}); |
|
|
|
if (!res.ok) { |
|
throw await createApiError(res); |
|
} |
|
|
|
const items: ApiModelInfo[] = await res.json(); |
|
|
|
for (const item of items) { |
|
yield { |
|
...(params?.additionalFields && pick(item, params.additionalFields)), |
|
id: item._id, |
|
name: item.id, |
|
private: item.private, |
|
task: item.pipeline_tag, |
|
downloads: item.downloads, |
|
gated: item.gated, |
|
likes: item.likes, |
|
updatedAt: new Date(item.lastModified), |
|
} as ModelEntry & Pick<ApiModelInfo, T>; |
|
totalToFetch--; |
|
|
|
if (totalToFetch <= 0) { |
|
return; |
|
} |
|
} |
|
|
|
const linkHeader = res.headers.get("Link"); |
|
|
|
url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; |
|
|
|
} |
|
} |
|
|