from diffusers import DDPMScheduler, DiffusionPipeline
from typing import List, Any, Union, Type
from utils.loader import get_class
from copy import deepcopy
from modules.loader.module_loader_config import ModuleLoaderConfig
import torch
import pytorch_lightning as pl
import jsonargparse

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

class GenericModuleLoader:
    def __init__(self,
                 pipeline_repo: str = None,
                 pipeline_obj: str = None,
                 set_prediction_type: str = "",
                 module_names: List[str] = [
                     "scheduler", "text_encoder", "tokenizer", "vae", "unet"],
                 module_config: dict[str,
                                     Union[ModuleLoaderConfig, torch.nn.Module, Any]] = None,
                 fast_dev_run: Union[int, bool] = False,
                 root_cls: Type[Any] = None,
                 ) -> None:
        self.module_config = module_config
        self.pipeline_repo = pipeline_repo
        self.pipeline_obj = pipeline_obj
        self.set_prediction_type = set_prediction_type
        self.module_names = module_names
        self.fast_dev_run = fast_dev_run
        self.root_cls = root_cls
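
    # Each entry in module_config maps a module name to either
    #   * an already instantiated object (e.g. built by jsonargparse), which is
    #     attached to the trainer as-is, or
    #   * a ModuleLoaderConfig describing which loader class/method to call and
    #     with which arguments (the Diffusers-style from_pretrained pattern).
    # Module names missing from module_config are taken from the loaded
    # diffusers pipeline instead.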

    def load_custom_scheduler(self):
        module_obj = DDPMScheduler.from_pretrained(
            self.pipeline_repo, subfolder="scheduler")
        if len(self.set_prediction_type) > 0:
            scheduler_config = module_obj.load_config(
                self.pipeline_repo, subfolder="scheduler")
            scheduler_config["prediction_type"] = self.set_prediction_type
            module_obj = module_obj.from_config(scheduler_config)
        return module_obj

    def load_pipeline(self):
        return DiffusionPipeline.from_pretrained(self.pipeline_repo) if self.pipeline_repo is not None else None

    def __call__(self, trainer: pl.LightningModule, diff_trainer_params):
        # load diffusers pipeline object if set
        pipe = self.load_pipeline() if self.pipeline_obj is not None else None
        if pipe is not None:
            # store the entire diffusers pipeline object under the name given by pipeline_obj
            setattr(trainer, self.pipeline_obj, pipe)
        for module_name in self.module_names:
            print(f" --- START: Loading module: {module_name} ---")
            if module_name not in self.module_config.keys() and pipe is not None:
                # stores models from already loaded diffusers pipeline
                module_obj = getattr(pipe, module_name)
                if module_name == "scheduler":
                    module_obj = self.load_custom_scheduler()
                setattr(trainer, module_name, module_obj)
            else:
                if not isinstance(self.module_config[module_name], ModuleLoaderConfig):
                    # already instantiated by jsonargparse; store it as-is
                    module = self.module_config[module_name]
                    # TODO we want to be able to load ckpt still.
                    config_obj = None
                else:
                    # instantiate object from a class method (as used by Diffusers, e.g. DiffusionPipeline.from_pretrained)
                    config_obj = self.module_config[module_name]
                    # retrieve loader class
                    loader_cls = get_class(
                        config_obj.loader_cls_path)
                    # retrieve loader method
                    if config_obj.cls_func != "":
                        # a separate method can be specified for fast loading (e.g. in diffusers, from_config instead of from_pretrained),
                        # which makes loading faster for quick testing
                        if not self.fast_dev_run or config_obj.cls_func_fast_dev_run == "":
                            cls_func = getattr(
                                loader_cls, config_obj.cls_func)
                        else:
                            print(
                                f"Model {module_name}: loading fast_dev_run class loader")
                            cls_func = getattr(
                                loader_cls, config_obj.cls_func_fast_dev_run)
                    else:
                        cls_func = loader_cls
                    # retrieve parameters
                    # load parameters specified in diff_trainer_params (so it links them)
                    kwargs_trainer_params = config_obj.kwargs_diff_trainer_params
                    kwargs_diffusers = config_obj.kwargs_diffusers
                    # names of dependent modules that we need as input
                    dependent_modules = config_obj.dependent_modules
                    # names of dependent modules that we need as input. Modules will be cloned
                    dependent_modules_cloned = config_obj.dependent_modules_cloned
                    # model kwargs. Can be just a dict, or a parameter class (derived from modules.params.params_mixin.AsDictMixin) so inputs are verified
                    model_params = config_obj.model_params
                    # kwargs used only in fast_dev_run mode
                    model_params_fast_dev_run = config_obj.model_params_fast_dev_run
                    if model_params is not None:
                        if isinstance(model_params, dict):
                            model_dict = model_params
                        else:
                            model_dict = model_params.to_dict()
                    else:
                        model_dict = {}
                    if (model_params_fast_dev_run is None) or (not self.fast_dev_run):
                        model_params_fast_dev_run = {}
                    else:
                        print(
                            f"Module {module_name}: loading fast_dev_run params")
                    loaded_modules_dict = {}
                    if dependent_modules is not None:
                        for key, dependent_module in dependent_modules.items():
                            assert hasattr(
                                trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}"
                            loaded_modules_dict[key] = getattr(
                                trainer, dependent_module)
                    if dependent_modules_cloned is not None:
                        for key, dependent_module in dependent_modules_cloned.items():
                            assert hasattr(
                                trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}"
                            # deepcopy the module itself (not its name) so the dependent module is cloned
                            loaded_modules_dict[key] = deepcopy(
                                getattr(trainer, dependent_module))
                    if kwargs_trainer_params is not None:
                        for key, param in kwargs_trainer_params.items():
                            if param is not None:
                                kwargs_trainer_params[key] = getattr(
                                    diff_trainer_params, param)
                            else:
                                kwargs_trainer_params[key] = diff_trainer_params
                    else:
                        kwargs_trainer_params = {}
                    if kwargs_diffusers is None:
                        kwargs_diffusers = {}
                    else:
                        for key, value in kwargs_diffusers.items():
                            if key == "torch_dtype":
                                if value == "torch.float16":
                                    kwargs_diffusers[key] = torch.float16
                    kwargs = kwargs_diffusers | loaded_modules_dict | kwargs_trainer_params | model_dict | model_params_fast_dev_run
                    args = config_obj.args
                    # instantiate object
                    module = cls_func(*args, **kwargs)
                module: torch.nn.Module
                if self.root_cls is not None:
                    assert isinstance(module, self.root_cls)
                if config_obj is not None and config_obj.state_dict_path != "" and not self.fast_dev_run:
                    # TODO extend loading to hf spaces
                    print(
                        f" * Loading checkpoint {config_obj.state_dict_path} - STARTED")
                    module_state_dict = torch.load(
                        config_obj.state_dict_path, map_location=torch.device("cpu"))
                    module_state_dict = module_state_dict["state_dict"]
                    if len(config_obj.state_dict_filters) > 0:
                        assert not config_obj.strict_loading
                        ckpt_params_dict = {}
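                        # each filter is a "*"-separated pattern; all parts must occur in the
                        # parameter name, in order. For example, a filter like "attn*to_q" would
                        # match a name such as "unet....attn1.to_q.weight" (illustrative name).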
                        for name, param in module.named_parameters(prefix=module_name):
                            for filter_str in config_obj.state_dict_filters:
                                filter_groups = filter_str.split("*")
                                has_all_parts = True
                                for filter_group in filter_groups:
                                    has_all_parts = has_all_parts and filter_group in name
                                if has_all_parts:
                                    # verify the filter groups also appear in the given order
                                    validate_name = name
                                    for filter_group in filter_groups:
                                        if filter_group in validate_name:
                                            shift = validate_name.index(
                                                filter_group)
                                            validate_name = validate_name[shift+len(
                                                filter_group):]
                                        else:
                                            has_all_parts = False
                                            break
                                if has_all_parts and name in module_state_dict:
                                    # take the matching parameter from the checkpoint (not from the
                                    # module itself), keyed without the module-name prefix
                                    ckpt_params_dict[name[len(
                                        module_name+"."):]] = module_state_dict[name]
                    else:
                        ckpt_params_dict = dict(filter(lambda x: x[0].startswith(
                            module_name), module_state_dict.items()))
                        ckpt_params_dict = {
                            k.split(module_name+".")[1]: v for (k, v) in ckpt_params_dict.items()}
                    if len(ckpt_params_dict) > 0:
                        miss, unex = module.load_state_dict(
                            ckpt_params_dict, strict=config_obj.strict_loading)
                        ckpt_params_dict = {}
                        assert len(
                            unex) == 0, f"Unexpected parameters in checkpoint: {unex}"
                        if len(miss) > 0:
                            print(
                                f"Checkpoint {config_obj.state_dict_path} is missing parameters for module {module_name}.")
                            print(miss)
                    print(
                        f" * Loading checkpoint {config_obj.state_dict_path} - FINISHED")
                if isinstance(module, jsonargparse.Namespace) or isinstance(module, dict):
                    print(bcolors.WARNING +
                          f"Warning: It seems object {module_name} was not built correctly." + bcolors.ENDC)
                setattr(trainer, module_name, module)
            print(f" --- FINISHED: Loading module: {module_name} ---")