| | import hashlib |
| | from io import BytesIO |
| | from typing import Optional |
| |
|
| | import safetensors.torch |
| | import torch |
| |
|
| |
|
def model_hash(filename):
    """Compute the old (legacy) model hash used by stable-diffusion-webui.

    The hash is the first 8 hex digits of the SHA-256 of the 0x10000 bytes
    that start at byte offset 0x100000. If the file is shorter, fewer bytes
    (possibly none) are hashed.

    Args:
        filename: path of the model file to hash.

    Returns:
        The 8-character hex digest. Returns ``"NOFILE"`` if the file does not
        exist. Returns ``"IsADirectory"`` if opening or reading it raises
        ``IsADirectoryError`` or ``PermissionError``.
    """
    window_start, window_size = 0x100000, 0x10000
    try:
        with open(filename, "rb") as fh:
            fh.seek(window_start)
            window = fh.read(window_size)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # NOTE(review): PermissionError deliberately shares the directory
        # sentinel. Presumably this is because opening a directory on Windows
        # raises PermissionError. It also masks genuinely unreadable files.
        # Callers may compare against this exact string, so confirm before
        # changing it.
        return "IsADirectory"
    return hashlib.sha256(window).hexdigest()[:8]
| |
|
| |
|
def calculate_sha256(filename):
    """Compute the new model hash used by stable-diffusion-webui.

    This is the SHA-256 of the entire file, read in 1 MiB chunks so that
    large checkpoints are never loaded into memory at once.

    Args:
        filename: path of the model file to hash.

    Returns:
        The full 64-character hex digest. Returns ``"NOFILE"`` if the file
        does not exist. Returns ``"IsADirectory"`` if opening or reading it
        raises ``IsADirectoryError`` or ``PermissionError``.
    """
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024
    try:
        with open(filename, "rb") as fh:
            while True:
                chunk = fh.read(chunk_size)
                if chunk == b"":
                    break
                digest.update(chunk)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # NOTE(review): same sentinel as model_hash. Presumably this is because
        # Windows reports directories via PermissionError. Confirm before
        # changing it.
        return "IsADirectory"
    return digest.hexdigest()
| |
|
| |
|
def addnet_hash_legacy(b):
    """Compute the old sd-webui-additional-networks hash of a .safetensors stream.

    The function hashes the 0x10000 bytes starting at offset 0x100000 of the
    seekable binary stream ``b``, using the same window as ``model_hash``.

    Args:
        b: seekable binary file-like object. Its position is moved by this call.

    Returns:
        The first 8 hex digits of the SHA-256 of that window.
    """
    b.seek(0x100000)
    sample = b.read(0x10000)
    return hashlib.sha256(sample).hexdigest()[:8]
| |
|
| |
|
def addnet_hash_safetensors(b):
    """Compute the new sd-webui-additional-networks hash of a .safetensors stream.

    The first 8 bytes of ``b`` are read as a little-endian integer ``n``.
    The function then hashes everything after the ``8 + n`` byte prefix.
    In the safetensors layout, that prefix is the header-length field plus
    the header itself, so the hash covers only the data section.

    Args:
        b: seekable binary file-like object. Its position is moved by this call.

    Returns:
        The full 64-character SHA-256 hex digest of the bytes after the header.
    """
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024

    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")
    b.seek(8 + header_len)

    while True:
        chunk = b.read(chunk_size)
        if chunk == b"":
            break
        digest.update(chunk)
    return digest.hexdigest()
| |
|
| |
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks.

    Computing them at save time avoids re-reading the model file when it is
    indexed later. The tensors are serialized in memory with
    ``safetensors.torch.save``. Both the new and the legacy addnet hashes are
    then computed over that buffer.

    Args:
        tensors: tensors to serialize, passed straight to
            ``safetensors.torch.save``.
        metadata: metadata mapping. Only entries whose key starts with
            ``"ss_"`` are serialized for the hash computation. ``None`` is
            treated as an empty mapping.

    Returns:
        A tuple ``(model_hash, legacy_hash)``:
        - ``model_hash`` is from ``addnet_hash_safetensors``.
        - ``legacy_hash`` is from ``addnet_hash_legacy``.
    """
    if metadata is None:
        metadata = {}

    # NOTE(review): only the "ss_" (training) metadata is kept. The header
    # length shifts the byte window used by the legacy hash. Presumably this
    # filter keeps user-added metadata from changing that hash. Confirm
    # against the callers.
    training_metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    # Named to avoid shadowing the `bytes` builtin and the sibling
    # `model_hash` function.
    serialized = safetensors.torch.save(tensors, training_metadata)
    buffer = BytesIO(serialized)

    # Each helper seeks on its own, so the call order does not matter.
    new_hash = addnet_hash_safetensors(buffer)
    old_hash = addnet_hash_legacy(buffer)
    return new_hash, old_hash
| |
|
| |
|
def dtype_to_str(dtype: torch.dtype) -> str:
    """Return the bare name of a torch dtype, e.g. ``torch.float32`` -> ``"float32"``.

    The name is whatever follows the last ``.`` in ``str(dtype)``. If there is
    no ``.``, the whole string is returned.
    """
    _, _, name = str(dtype).rpartition(".")
    return name
| |
|
| |
|
| | def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: |
| | """ |
| | Convert a string to a torch.dtype |
| | |
| | Args: |
| | s: string representation of the dtype |
| | default_dtype: default dtype to return if s is None |
| | |
| | Returns: |
| | torch.dtype: the corresponding torch.dtype |
| | |
| | Raises: |
| | ValueError: if the dtype is not supported |
| | |
| | Examples: |
| | >>> str_to_dtype("float32") |
| | torch.float32 |
| | >>> str_to_dtype("fp32") |
| | torch.float32 |
| | >>> str_to_dtype("float16") |
| | torch.float16 |
| | >>> str_to_dtype("fp16") |
| | torch.float16 |
| | >>> str_to_dtype("bfloat16") |
| | torch.bfloat16 |
| | >>> str_to_dtype("bf16") |
| | torch.bfloat16 |
| | >>> str_to_dtype("fp8") |
| | torch.float8_e4m3fn |
| | >>> str_to_dtype("fp8_e4m3fn") |
| | torch.float8_e4m3fn |
| | >>> str_to_dtype("fp8_e4m3fnuz") |
| | torch.float8_e4m3fnuz |
| | >>> str_to_dtype("fp8_e5m2") |
| | torch.float8_e5m2 |
| | >>> str_to_dtype("fp8_e5m2fnuz") |
| | torch.float8_e5m2fnuz |
| | """ |
| | if s is None: |
| | return default_dtype |
| | if s in ["bf16", "bfloat16"]: |
| | return torch.bfloat16 |
| | elif s in ["fp16", "float16"]: |
| | return torch.float16 |
| | elif s in ["fp32", "float32", "float"]: |
| | return torch.float32 |
| | elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: |
| | return torch.float8_e4m3fn |
| | elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: |
| | return torch.float8_e4m3fnuz |
| | elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: |
| | return torch.float8_e5m2 |
| | elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: |
| | return torch.float8_e5m2fnuz |
| | elif s in ["fp8", "float8"]: |
| | return torch.float8_e4m3fn |
| | else: |
| | raise ValueError(f"Unsupported dtype: {s}") |
| |
|