import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";
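
/**
 * Expanded fields that are always requested and mapped onto `ModelEntry`.
 */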
export const MODEL_EXPAND_KEYS = [
"pipeline_tag",
"private",
"gated",
"downloads",
"likes",
"lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];
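
/**
 * All fields the models API can expand. `listModels` accepts any of them in
 * `additionalFields`, minus the ones already covered by `MODEL_EXPAND_KEYS`.
 */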
export const MODEL_EXPANDABLE_KEYS = [
"author",
"cardData",
"config",
"createdAt",
"disabled",
"downloads",
"downloadsAllTime",
"gated",
"gitalyUid",
"inferenceProviderMapping",
"lastModified",
"library_name",
"likes",
"model-index",
"pipeline_tag",
"private",
"safetensors",
"sha",
// "siblings",
"spaces",
"tags",
"transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];

export interface ModelEntry {
/** Internal id of the model (the API's `_id` field) */
id: string;
/** Model name on the Hub, e.g. `"<owner>/<model>"` (the API's `id` field) */
name: string;
private: boolean;
gated: false | "auto" | "manual";
task?: PipelineType;
likes: number;
downloads: number;
updatedAt: Date;
}
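
/**
 * List models on the Hub, yielding one `ModelEntry` per model.
 *
 * Results are fetched 500 at a time; the response's `Link` header is followed
 * until the listing is exhausted or `limit` entries have been yielded.
 *
 * Minimal usage sketch (the owner, task and additional field below are
 * illustrative values, not defaults):
 *
 * ```ts
 * for await (const model of listModels({
 *   search: { owner: "huggingface", task: "text-generation" },
 *   additionalFields: ["library_name"],
 *   limit: 10,
 * })) {
 *   console.log(model.name, model.downloads, model.library_name);
 * }
 * ```
 */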
export async function* listModels<
const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never,
>(
params?: {
search?: {
/**
* Will search in the model name for matches
*/
query?: string;
owner?: string;
task?: PipelineType;
tags?: string[];
/**
* Will search for models that have one of the inference providers in the list.
*/
inferenceProviders?: string[];
};
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> {
const accessToken = params && checkCredentials(params);
let totalToFetch = params?.limit ?? Infinity;
const search = new URLSearchParams([
...Object.entries({
limit: String(Math.min(totalToFetch, 500)),
...(params?.search?.owner ? { author: params.search.owner } : undefined),
...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
...(params?.search?.query ? { search: params.search.query } : undefined),
...(params?.search?.inferenceProviders
? { inference_provider: params.search.inferenceProviders.join(",") }
: undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;
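// Page through the results: each response's `Link` header points to the next page.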
while (url) {
const res: Response = await (params?.fetch ?? fetch)(url, {
headers: {
accept: "application/json",
...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
},
});
if (!res.ok) {
throw await createApiError(res);
}
const items: ApiModelInfo[] = await res.json();
for (const item of items) {
yield {
...(params?.additionalFields && pick(item, params.additionalFields)),
id: item._id,
name: item.id,
private: item.private,
task: item.pipeline_tag,
downloads: item.downloads,
gated: item.gated,
likes: item.likes,
updatedAt: new Date(item.lastModified),
} as ModelEntry & Pick<ApiModelInfo, T>;
totalToFetch--;
if (totalToFetch <= 0) {
return;
}
}
const linkHeader = res.headers.get("Link");
url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
}
}