|
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { omit } from "../utils/omit";
import { toRepoId } from "../utils/toRepoId";
import { typedEntries } from "../utils/typedEntries";
import { downloadFile } from "./download-file";
import { fileExists } from "./file-exists";
import { promisesQueue } from "../utils/promisesQueue";
import type { SetRequired } from "../vendor/type-fest/set-required";

export const SAFETENSORS_FILE = "model.safetensors";
export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";

export const RE_SAFETENSORS_FILE = /\.safetensors$/;
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
export const RE_SAFETENSORS_SHARD_FILE =
	/^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5})-of-(?<total>\d{5})\.safetensors$/;
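// For illustration, a shard filename such as "model-00001-of-00002.safetensors" matches
// RE_SAFETENSORS_SHARD_FILE with prefix "model-", basePrefix "model", shard "00001" and total "00002".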
|
export interface SafetensorsShardFileInfo {
	prefix: string;
	basePrefix: string;
	shard: string;
	total: string;
}
|
export function parseSafetensorsShardFilename(filename: string): SafetensorsShardFileInfo | null {
	const match = RE_SAFETENSORS_SHARD_FILE.exec(filename);
	if (match && match.groups) {
		return {
			prefix: match.groups["prefix"],
			basePrefix: match.groups["basePrefix"],
			shard: match.groups["shard"],
			total: match.groups["total"],
		};
	}
	return null;
}

const PARALLEL_DOWNLOADS = 20;
const MAX_HEADER_LENGTH = 25_000_000;

class SafetensorParseError extends Error {}

type FileName = string;

export type TensorName = string;
export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";

export interface TensorInfo {
	dtype: Dtype;
	shape: number[];
	/** Start and end byte offsets of the tensor data within the file's data section (the bytes following the header) */
	data_offsets: [number, number];
}

export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
	__metadata__: Record<string, string>;
};
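
// Illustrative (truncated) safetensors header, as parsed into SafetensorsFileHeader:
// {
//   "__metadata__": { "format": "pt" },
//   "embed.weight": { "dtype": "F32", "shape": [32000, 4096], "data_offsets": [0, 524288000] }
// }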
|
|
|
export interface SafetensorsIndexJson {
	dtype?: string;
	metadata?: Record<string, string>;
	weight_map: Record<TensorName, FileName>;
}
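
// Illustrative (truncated) model.safetensors.index.json, mapping each tensor to its shard:
// {
//   "weight_map": {
//     "embed.weight": "model-00001-of-00002.safetensors",
//     "lm_head.weight": "model-00002-of-00002.safetensors"
//   }
// }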
|
|
|
export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;

export type SafetensorsParseFromRepo =
	| {
			sharded: false;
			header: SafetensorsFileHeader;
			parameterCount?: Partial<Record<Dtype, number>>;
	  }
	| {
			sharded: true;
			index: SafetensorsIndexJson;
			headers: SafetensorsShardedHeaders;
			parameterCount?: Partial<Record<Dtype, number>>;
	  };

async function parseSingleFile(
	path: string,
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsFileHeader> {
	const blob = await downloadFile({ ...params, path });

	if (!blob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`);
	}

	// A safetensors file starts with an 8-byte little-endian unsigned integer giving
	// the byte length of the JSON header that immediately follows it.
	const bufLengthOfHeaderLE = await blob.slice(0, 8).arrayBuffer();
	const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true);

	if (lengthOfHeader <= 0) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
	}
	if (lengthOfHeader > MAX_HEADER_LENGTH) {
		throw new SafetensorParseError(
			`Failed to parse file ${path}: safetensors header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.`
		);
	}

	try {
		// No validation beyond JSON parsing: the result is assumed to match SafetensorsFileHeader.
		const header: SafetensorsFileHeader = JSON.parse(await blob.slice(8, 8 + Number(lengthOfHeader)).text());
		return header;
	} catch {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`);
	}
}

async function parseShardedIndex(
	path: string,
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
	const indexBlob = await downloadFile({
		...params,
		path,
	});

	if (!indexBlob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
	}

	let index: SafetensorsIndexJson;
	try {
		// Only the first 10MB of the index file are read before parsing.
		index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
	} catch {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors index is not valid JSON.`);
	}

	const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
	const filenames = [...new Set(Object.values(index.weight_map))];
	// Fetch the header of every shard referenced by the index, at most PARALLEL_DOWNLOADS at a time.
	const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
		await promisesQueue(
			filenames.map(
				(filename) => async () =>
					[filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader]
			),
			PARALLEL_DOWNLOADS
		)
	);
	return { index, headers: shardedMap };
}

/**
 * Parse the safetensors metadata of a model repo: either a single safetensors file
 * (e.g. `model.safetensors`) or a sharded one (`model.safetensors.index.json` plus its shards).
 */
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		/**
		 * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
		 */
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount: true;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SetRequired<SafetensorsParseFromRepo, "parameterCount">>;
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		/**
		 * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
		 */
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo>;
export async function parseSafetensorsMetadata(
	params: {
		repo: RepoDesignation;
		path?: string;
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo> {
	const repoId = toRepoId(params.repo);

	if (repoId.type !== "model") {
		throw new TypeError("Only model repos should contain safetensors files.");
	}

	if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) {
		// Single, non-sharded safetensors file: the provided path, or the default model.safetensors.
		const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
		return {
			sharded: false,
			header,
			...(params.computeParametersCount && {
				parameterCount: computeNumOfParamsByDtypeSingleFile(header),
			}),
		};
	} else if (
		RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") ||
		(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
	) {
		// Sharded weights: parse the index, then the header of each referenced shard.
		const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
		return {
			sharded: true,
			index,
			headers,
			...(params.computeParametersCount && {
				parameterCount: computeNumOfParamsByDtypeSharded(headers),
			}),
		};
	} else {
		throw new Error("model id does not seem to contain safetensors weights");
	}
}
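
// Example usage (the repo name is illustrative):
//
//   const parse = await parseSafetensorsMetadata({
//     repo: { type: "model", name: "some-org/some-model" },
//     computeParametersCount: true,
//   });
//   if (!parse.sharded) {
//     console.log(parse.header.__metadata__);
//   }
//   // Total parameter count across all dtypes:
//   const total = Object.values(parse.parameterCount).reduce((sum, n) => sum + (n ?? 0), 0);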
|
|
|
function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	const tensors = omit(header, "__metadata__");

	for (const [, v] of typedEntries(tensors)) {
		if (v.shape.length === 0) {
			// Skip scalar tensors (empty shape): the reduce() below requires at least one dimension.
			continue;
		}
		counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a, b) => a * b);
	}
	return counter;
}

function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	for (const header of Object.values(shardedMap)) {
		for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) {
			counter[k] = (counter[k] ?? 0) + (v ?? 0);
		}
	}
	return counter;
}
|
|