|
import os |
|
from typing import Dict, Optional, Union |
|
|
|
import safetensors |
|
import torch |
|
from diffusers.utils import _get_model_file, logging |
|
from safetensors import safe_open |
|
|
|
# Module-level logger obtained via the diffusers logging wrapper.
logger = logging.get_logger(__name__)
|
|
|
|
|
class CustomAdapterMixin:
    """Mixin that adds init / load / save support for custom adapter weights.

    Subclasses implement the private hooks ``_init_custom_adapter``,
    ``_load_custom_adapter`` and ``_save_custom_adapter``; this mixin handles
    hub/local file resolution and (de)serialization in both safetensors and
    ``torch.save`` formats.
    """

    def init_custom_adapter(self, *args, **kwargs):
        """Initialize the adapter modules; forwards everything to the subclass hook."""
        self._init_custom_adapter(*args, **kwargs)

    def _init_custom_adapter(self, *args, **kwargs):
        # Subclass hook: build the adapter modules.
        raise NotImplementedError

    def load_custom_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Load adapter weights from a repo id / local path or an in-memory state dict.

        Args:
            pretrained_model_name_or_path_or_dict: model id or directory, or an
                already-loaded ``{name: tensor}`` state dict (used as-is).
            weight_name: filename of the weight file inside the repo/directory.
            subfolder: optional subfolder containing the weight file.
            **kwargs: download options forwarded to ``_get_model_file``
                (``cache_dir``, ``force_download``, ``proxies``,
                ``local_files_only``, ``token``, ``revision``).
        """
        # Hub download/caching options.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                subfolder=subfolder,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                # Read tensor-by-tensor onto CPU; safetensors avoids pickle.
                state_dict = {}
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        state_dict[key] = f.get_tensor(key)
            else:
                # NOTE(review): torch.load is pickle-based and can execute code
                # from an untrusted checkpoint; consider weights_only=True
                # (torch >= 1.13) once the minimum torch version allows it.
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            # Caller supplied an in-memory state dict directly.
            state_dict = pretrained_model_name_or_path_or_dict

        self._load_custom_adapter(state_dict)

    def _load_custom_adapter(self, state_dict):
        # Subclass hook: consume the loaded state dict.
        raise NotImplementedError

    def save_custom_adapter(
        self,
        save_directory: Union[str, os.PathLike],
        weight_name: str,
        safe_serialization: bool = False,
        **kwargs,
    ):
        """Serialize the adapter state dict to ``save_directory/weight_name``.

        Args:
            save_directory: target directory (created if it does not exist).
            weight_name: filename for the saved weights.
            safe_serialization: if True, save with safetensors instead of
                ``torch.save``.
            **kwargs: forwarded to the subclass ``_save_custom_adapter`` hook.
        """
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return

        # Fix: create the target directory if missing — previously saving into
        # a nonexistent directory raised FileNotFoundError.
        os.makedirs(save_directory, exist_ok=True)

        if safe_serialization:

            def save_function(weights, filename):
                return safetensors.torch.save_file(
                    weights, filename, metadata={"format": "pt"}
                )

        else:
            save_function = torch.save

        state_dict = self._save_custom_adapter(**kwargs)
        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(
            f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
        )

    def _save_custom_adapter(self):
        # Subclass hook: produce the state dict to serialize.
        raise NotImplementedError
|
|