| import hashlib |
| from io import BytesIO |
| from typing import Optional |
|
|
| import safetensors.torch |
| import torch |
|
|
|
|
def model_hash(filename):
    """Old model hash used by stable-diffusion-webui.

    Only a 64 KiB window (0x10000 bytes) starting at byte offset 0x100000 is
    fed to SHA-256. A file shorter than that offset yields an empty window,
    so the result is then the truncated hash of empty input.

    Args:
        filename: path of the file to hash.

    Returns:
        The first 8 hex characters of the SHA-256 digest of the window, or a
        sentinel string: ``"NOFILE"`` if the path does not exist, and
        ``"IsADirectory"`` if opening/reading raises IsADirectoryError or
        PermissionError.
    """
    window_offset = 0x100000
    window_length = 0x10000
    try:
        with open(filename, "rb") as handle:
            handle.seek(window_offset)
            window = handle.read(window_length)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # NOTE(review): PermissionError shares the directory sentinel, presumably
        # because opening a directory raises PermissionError on some platforms
        # (e.g. Windows); a genuinely unreadable file is reported the same way.
        return "IsADirectory"
    return hashlib.sha256(window).hexdigest()[:8]
|
|
|
|
def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui.

    The whole file is hashed with SHA-256, read in 1 MiB chunks so large model
    files are never loaded into memory at once.

    Args:
        filename: path of the file to hash.

    Returns:
        The full hex SHA-256 digest, or a sentinel string: ``"NOFILE"`` if the
        path does not exist, and ``"IsADirectory"`` if opening/reading raises
        IsADirectoryError or PermissionError.
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    try:
        with open(filename, "rb") as handle:
            while True:
                chunk = handle.read(chunk_size)
                if not chunk:
                    break
                digest.update(chunk)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # NOTE(review): PermissionError shares the directory sentinel, presumably
        # for platforms where opening a directory raises PermissionError.
        return "IsADirectory"
    return digest.hexdigest()
|
|
|
|
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    Args:
        b: seekable binary stream (e.g. ``BytesIO``) holding the file contents.

    Returns:
        The first 8 hex characters of the SHA-256 digest of the 0x10000-byte
        window starting at offset 0x100000. The window is empty when the
        stream is shorter than the offset. The stream is left positioned just
        after the bytes that were read.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    return hashlib.sha256(window).hexdigest()[:8]
|
|
|
|
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    The stream begins with an 8-byte little-endian integer giving the header
    length. The header is skipped, and only the bytes after it are hashed.
    This means the hash does not depend on header contents.

    Args:
        b: seekable binary stream (e.g. ``BytesIO``) holding the file contents.

    Returns:
        The full hex SHA-256 digest of everything after the header.
    """
    chunk_size = 1024 * 1024
    hasher = hashlib.sha256()

    b.seek(0)
    header_length = int.from_bytes(b.read(8), "little")
    b.seek(8 + header_length)

    while True:
        chunk = b.read(chunk_size)
        if not chunk:
            break
        hasher.update(chunk)
    return hasher.hexdigest()
|
|
|
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    The tensors are serialized in memory with ``safetensors.torch.save``, and
    both hashes are computed from that in-memory buffer.

    Args:
        tensors: passed straight to ``safetensors.torch.save`` (presumably a
            ``dict[str, torch.Tensor]`` -- confirm against callers).
        metadata: mapping of metadata entries; only keys starting with
            ``"ss_"`` are kept for hashing.

    Returns:
        ``(model_hash, legacy_hash)``: the results of
        ``addnet_hash_safetensors`` and ``addnet_hash_legacy`` on the buffer.
    """
    # Keep only the "ss_" (training) entries. safetensors writes the metadata
    # into the serialized header, and addnet_hash_legacy hashes a fixed byte
    # offset. Any other (e.g. user-added) entries would change the header size,
    # shift that window and change the legacy hash.
    training_metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    # Avoid naming locals `bytes` / `model_hash`: they would shadow the builtin
    # and this module's model_hash() function.
    serialized = safetensors.torch.save(tensors, training_metadata)
    buffer = BytesIO(serialized)

    new_hash = addnet_hash_safetensors(buffer)
    legacy_hash = addnet_hash_legacy(buffer)
    return new_hash, legacy_hash
|
|
|
|
def dtype_to_str(dtype: torch.dtype) -> str:
    """Return the bare name of a torch dtype, e.g. ``torch.float32`` -> ``"float32"``.

    The result is the text after the last ``"."`` in ``str(dtype)``, or the
    whole string when it contains no dot.
    """
    _, _, bare_name = str(dtype).rpartition(".")
    return bare_name
|
|
|
|
| def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: |
| """ |
| Convert a string to a torch.dtype |
| |
| Args: |
| s: string representation of the dtype |
| default_dtype: default dtype to return if s is None |
| |
| Returns: |
| torch.dtype: the corresponding torch.dtype |
| |
| Raises: |
| ValueError: if the dtype is not supported |
| |
| Examples: |
| >>> str_to_dtype("float32") |
| torch.float32 |
| >>> str_to_dtype("fp32") |
| torch.float32 |
| >>> str_to_dtype("float16") |
| torch.float16 |
| >>> str_to_dtype("fp16") |
| torch.float16 |
| >>> str_to_dtype("bfloat16") |
| torch.bfloat16 |
| >>> str_to_dtype("bf16") |
| torch.bfloat16 |
| >>> str_to_dtype("fp8") |
| torch.float8_e4m3fn |
| >>> str_to_dtype("fp8_e4m3fn") |
| torch.float8_e4m3fn |
| >>> str_to_dtype("fp8_e4m3fnuz") |
| torch.float8_e4m3fnuz |
| >>> str_to_dtype("fp8_e5m2") |
| torch.float8_e5m2 |
| >>> str_to_dtype("fp8_e5m2fnuz") |
| torch.float8_e5m2fnuz |
| """ |
| if s is None: |
| return default_dtype |
| if s in ["bf16", "bfloat16"]: |
| return torch.bfloat16 |
| elif s in ["fp16", "float16"]: |
| return torch.float16 |
| elif s in ["fp32", "float32", "float"]: |
| return torch.float32 |
| elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: |
| return torch.float8_e4m3fn |
| elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: |
| return torch.float8_e4m3fnuz |
| elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: |
| return torch.float8_e5m2 |
| elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: |
| return torch.float8_e5m2fnuz |
| elif s in ["fp8", "float8"]: |
| return torch.float8_e4m3fn |
| else: |
| raise ValueError(f"Unsupported dtype: {s}") |
|
|