# StreamingSVD/modules/loader/module_loader.py
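# Generic module loader: instantiates the sub-modules of a diffusers pipeline
# (scheduler, text_encoder, tokenizer, vae, unet, ...) and attaches them as
# attributes to a PyTorch Lightning trainer.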
from diffusers import DDPMScheduler, DiffusionPipeline
from typing import List, Any, Union, Type, Optional
from utils.loader import get_class
from copy import deepcopy
from modules.loader.module_loader_config import ModuleLoaderConfig
import torch
import pytorch_lightning as pl
import jsonargparse
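# ANSI escape codes for colored console output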
class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
class GenericModuleLoader:
    def __init__(self,
                 pipeline_repo: Optional[str] = None,
                 pipeline_obj: Optional[str] = None,
                 set_prediction_type: str = "",
                 module_names: List[str] = [
                     "scheduler", "text_encoder", "tokenizer", "vae", "unet"],
                 module_config: Optional[dict[
                     str, Union[ModuleLoaderConfig, torch.nn.Module, Any]]] = None,
                 fast_dev_run: Union[int, bool] = False,
                 root_cls: Optional[Type[Any]] = None,
                 ) -> None:
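        """Loads and attaches the (sub-)modules of a diffusion pipeline.

        pipeline_repo: repo id of the diffusers pipeline to pull modules from.
        pipeline_obj: attribute name under which the entire pipeline is stored
            on the trainer (the pipeline is only loaded if this is set).
        set_prediction_type: if non-empty, overrides the scheduler's
            prediction_type.
        module_names: names of the modules to load, in dependency order.
        module_config: per-module ModuleLoaderConfig (or an already built
            object); modules without an entry are taken from the pipeline.
        fast_dev_run: if truthy, prefers the fast loaders/params configured
            for quick testing and skips checkpoint loading.
        root_cls: if set, every configured module must be an instance of it.
        """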
        # default to an empty mapping so the membership checks below do not
        # fail when no per-module configuration is given
        self.module_config = module_config if module_config is not None else {}
self.pipeline_repo = pipeline_repo
self.pipeline_obj = pipeline_obj
self.set_prediction_type = set_prediction_type
self.module_names = module_names
self.fast_dev_run = fast_dev_run
self.root_cls = root_cls
def load_custom_scheduler(self):
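        """Loads the DDPM scheduler from the pipeline repo, optionally
        overriding its prediction_type (e.g. with "v_prediction")."""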
module_obj = DDPMScheduler.from_pretrained(
self.pipeline_repo, subfolder="scheduler")
if len(self.set_prediction_type) > 0:
            scheduler_config = DDPMScheduler.load_config(
                self.pipeline_repo, subfolder="scheduler")
            scheduler_config["prediction_type"] = self.set_prediction_type
            module_obj = DDPMScheduler.from_config(scheduler_config)
return module_obj
    def load_pipeline(self):
        if self.pipeline_repo is None:
            return None
        return DiffusionPipeline.from_pretrained(self.pipeline_repo)
def __call__(self, trainer: pl.LightningModule, diff_trainer_params):
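        """Instantiates every module listed in module_names and attaches each
        one to the trainer as an attribute under its module name."""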
        # load the diffusers pipeline object if requested
        pipe = self.load_pipeline() if self.pipeline_obj is not None else None
        if pipe is not None:
            # store the entire diffusers pipeline object under the name given by
            # pipeline_obj (reusing the already loaded pipeline instead of loading it again)
            setattr(trainer, self.pipeline_obj, pipe)
for module_name in self.module_names:
print(f" --- START: Loading module: {module_name} ---")
            if module_name not in self.module_config and pipe is not None:
# stores models from already loaded diffusers pipeline
module_obj = getattr(pipe, module_name)
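                # the scheduler is always rebuilt so that a prediction_type
                # override (set_prediction_type) takes effect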
if module_name == "scheduler":
module_obj = self.load_custom_scheduler()
setattr(trainer, module_name, module_obj)
else:
if not isinstance(self.module_config[module_name], ModuleLoaderConfig):
                    # module was already instantiated by jsonargparse; use it directly
module = self.module_config[module_name]
# TODO we want to be able to load ckpt still.
config_obj = None
else:
                    # instantiate object via a class method (as used by diffusers, e.g. DiffusionPipeline.from_pretrained)
config_obj = self.module_config[module_name]
# retrieve loader class
loader_cls = get_class(
config_obj.loader_cls_path)
# retrieve loader method
if config_obj.cls_func != "":
                        # a separate loader method can be configured for fast loading
                        # (e.g. diffusers' from_config instead of from_pretrained),
                        # which speeds up quick testing
if not self.fast_dev_run or config_obj.cls_func_fast_dev_run == "":
cls_func = getattr(
loader_cls, config_obj.cls_func)
else:
print(
f"Model {module_name}: loading fast_dev_run class loader")
cls_func = getattr(
loader_cls, config_obj.cls_func_fast_dev_run)
else:
cls_func = loader_cls
                    # retrieve parameters
                    # parameters taken from diff_trainer_params (so they stay linked to the trainer's values)
                    kwargs_trainer_params = config_obj.kwargs_diff_trainer_params
kwargs_diffusers = config_obj.kwargs_diffusers
                    # names of dependent modules that we need as input
                    dependent_modules = config_obj.dependent_modules
                    # names of dependent modules that we need as input; these modules will be deep-copied
                    dependent_modules_cloned = config_obj.dependent_modules_cloned
                    # model kwargs: either a plain dict or a parameter class (derived from
                    # modules.params.params_mixin.AsDictMixin) so that inputs are validated
                    model_params = config_obj.model_params
                    # kwargs used only in fast_dev_run mode
                    model_params_fast_dev_run = config_obj.model_params_fast_dev_run
if model_params is not None:
if isinstance(model_params, dict):
model_dict = model_params
else:
model_dict = model_params.to_dict()
else:
model_dict = {}
if (model_params_fast_dev_run is None) or (not self.fast_dev_run):
model_params_fast_dev_run = {}
else:
print(
f"Module {module_name}: loading fast_dev_run params")
loaded_modules_dict = {}
if dependent_modules is not None:
for key, dependent_module in dependent_modules.items():
assert hasattr(
trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}"
loaded_modules_dict[key] = getattr(
trainer, dependent_module)
                    if dependent_modules_cloned is not None:
                        for key, dependent_module in dependent_modules_cloned.items():
                            assert hasattr(
                                trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}"
                            # deep-copy the module itself (not its name) so the
                            # clone is independent of the original
                            loaded_modules_dict[key] = deepcopy(getattr(
                                trainer, dependent_module))
if kwargs_trainer_params is not None:
for key, param in kwargs_trainer_params.items():
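                            # a parameter name selects a single attribute of
                            # diff_trainer_params; None passes the whole object through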
if param is not None:
kwargs_trainer_params[key] = getattr(
diff_trainer_params, param)
else:
kwargs_trainer_params[key] = diff_trainer_params
else:
kwargs_trainer_params = {}
if kwargs_diffusers is None:
kwargs_diffusers = {}
else:
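                        # convert the literal string "torch.float16" from the config into the actual dtype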
for key, value in kwargs_diffusers.items():
if key == "torch_dtype":
if value == "torch.float16":
kwargs_diffusers[key] = torch.float16
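                    # merge all keyword sources; later dicts win on key collisions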
kwargs = kwargs_diffusers | loaded_modules_dict | kwargs_trainer_params | model_dict | model_params_fast_dev_run
args = config_obj.args
# instantiate object
module = cls_func(*args, **kwargs)
module: torch.nn.Module
if self.root_cls is not None:
assert isinstance(module, self.root_cls)
if config_obj is not None and config_obj.state_dict_path != "" and not self.fast_dev_run:
# TODO extend loading to hf spaces
print(
f" * Loading checkpoint {config_obj.state_dict_path} - STARTED")
module_state_dict = torch.load(
config_obj.state_dict_path, map_location=torch.device("cpu"))
module_state_dict = module_state_dict["state_dict"]
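                    # state_dict_filters are glob-like patterns ("*" as wildcard);
                    # keep only checkpoint entries whose parameter name contains
                    # all pattern parts in order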
if len(config_obj.state_dict_filters) > 0:
assert not config_obj.strict_loading
ckpt_params_dict = {}
for name, param in module.named_parameters(prefix=module_name):
for filter_str in config_obj.state_dict_filters:
filter_groups = filter_str.split("*")
has_all_parts = True
for filter_group in filter_groups:
has_all_parts = has_all_parts and filter_group in name
if has_all_parts:
validate_name = name
for filter_group in filter_groups:
if filter_group in validate_name:
shift = validate_name.index(
filter_group)
validate_name = validate_name[shift+len(
filter_group):]
else:
has_all_parts = False
break
                                if has_all_parts:
                                    # take the value from the checkpoint rather than the
                                    # module's own parameter, stripping the module prefix
                                    ckpt_params_dict[name[len(
                                        module_name+"."):]] = module_state_dict[name]
else:
ckpt_params_dict = dict(filter(lambda x: x[0].startswith(
module_name), module_state_dict.items()))
ckpt_params_dict = {
k.split(module_name+".")[1]: v for (k, v) in ckpt_params_dict.items()}
if len(ckpt_params_dict) > 0:
miss, unex = module.load_state_dict(
ckpt_params_dict, strict=config_obj.strict_loading)
ckpt_params_dict = {}
assert len(
unex) == 0, f"Unexpected parameters in checkpoint: {unex}"
if len(miss) > 0:
print(
f"Checkpoint {config_obj.state_dict_path} is missing parameters for module {module_name}.")
print(miss)
print(
f" * Loading checkpoint {config_obj.state_dict_path} - FINISHED")
                if isinstance(module, (jsonargparse.Namespace, dict)):
                    print(bcolors.WARNING +
                          f"Warning: It seems object {module_name} was not built correctly." + bcolors.ENDC)
setattr(trainer, module_name, module)
print(f" --- FINSHED: Loading module: {module_name} ---")