# Spaces: Running on Zero
from copy import deepcopy
from typing import List, Any, Union, Type, Optional

import jsonargparse
import pytorch_lightning as pl
import torch
from diffusers import DDPMScheduler, DiffusionPipeline

from modules.loader.module_loader_config import ModuleLoaderConfig
from utils.loader import get_class
class bcolors:
    """ANSI escape sequences for styling terminal output.

    Wrap a message between one of the style codes and ``ENDC`` to style it,
    e.g. ``bcolors.WARNING + "message" + bcolors.ENDC``.
    """

    # Foreground colours (SGR codes 91-96).
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    # SGR code 0: resets all styling applied so far.
    ENDC = '\033[0m'
    # Text attributes (SGR codes 1 and 4).
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
class GenericModuleLoader():
    """Load a list of named modules and attach them to a trainer object.

    Every name in ``module_names`` is resolved in one of three ways:

    * taken from a diffusers pipeline loaded from ``pipeline_repo`` (only when
      ``pipeline_obj`` is set and the name has no entry in ``module_config``),
    * built from a ``ModuleLoaderConfig`` entry of ``module_config``, or
    * used as-is when the ``module_config`` entry is any other object.

    The resolved object is stored with ``setattr(trainer, name, obj)``.
    """

    def __init__(self,
                 pipeline_repo: str = None,
                 pipeline_obj: str = None,
                 set_prediction_type: str = "",
                 module_names: List[str] = [
                     "scheduler", "text_encoder", "tokenizer", "vae", "unet",],
                 module_config: dict[str,
                                     Union[ModuleLoaderConfig, torch.nn.Module, Any]] = None,
                 fast_dev_run: Union[int, bool] = False,
                 root_cls: Type[Any] = None,
                 ) -> None:
        """Store the loader configuration; nothing is loaded until the call.

        Args:
            pipeline_repo: Identifier passed to ``from_pretrained`` for the
                diffusers pipeline and its scheduler.
            pipeline_obj: Attribute name under which the whole pipeline is
                stored on the trainer. ``None`` disables pipeline loading.
            set_prediction_type: If non-empty, overrides ``prediction_type``
                in the scheduler config.
            module_names: Names of the modules to load, in loading order.
            module_config: Per-module configuration or ready-made objects.
            fast_dev_run: When truthy, the fast-dev-run loader function and
                parameters of a config are used (if given) and checkpoint
                loading is skipped.
            root_cls: If set, every module instantiated from a
                ``ModuleLoaderConfig`` must be an instance of this class.
        """
        self.module_config = module_config
        self.pipeline_repo = pipeline_repo
        self.pipeline_obj = pipeline_obj
        self.set_prediction_type = set_prediction_type
        # Copy so that instances never share (or mutate) the default list
        # object that lives in the signature.
        self.module_names = list(module_names)
        self.fast_dev_run = fast_dev_run
        self.root_cls = root_cls

    def load_custom_scheduler(self):
        """Load the ``DDPMScheduler`` from the ``scheduler`` subfolder of the repo.

        When ``set_prediction_type`` is set, the scheduler is re-created from
        its config with ``prediction_type`` replaced by that value.

        Returns:
            The scheduler instance.
        """
        module_obj = DDPMScheduler.from_pretrained(
            self.pipeline_repo, subfolder="scheduler")

        # Truthiness also covers ``None``, for which ``len()`` would raise.
        if self.set_prediction_type:
            scheduler_config = module_obj.load_config(
                self.pipeline_repo, subfolder="scheduler")
            scheduler_config["prediction_type"] = self.set_prediction_type
            module_obj = module_obj.from_config(scheduler_config)
        return module_obj

    def load_pipeline(self) -> Optional[DiffusionPipeline]:
        """Return the pipeline loaded from ``pipeline_repo``, or ``None`` if no repo is set."""
        if self.pipeline_repo is None:
            return None
        return DiffusionPipeline.from_pretrained(self.pipeline_repo)

    def __call__(self, trainer: pl.LightningModule, diff_trainer_params):
        """Load every module in ``module_names`` and set it on ``trainer``.

        Args:
            trainer: Object that receives the loaded modules as attributes.
                Modules loaded earlier can be used as dependencies of modules
                loaded later, so the order of ``module_names`` matters.
            diff_trainer_params: Object from which the values named in a
                config's ``kwargs_diff_trainer_params`` are read.

        Raises:
            KeyError: If a module has no entry in ``module_config`` and no
                pipeline is loaded to take it from.
            AssertionError: If a dependency is not yet set on ``trainer``, a
                module is not a ``root_cls`` instance, or a checkpoint holds
                unexpected parameters.
        """
        # ``module_config`` defaults to ``None``; treat that as "nothing configured".
        module_config = self.module_config if self.module_config is not None else {}

        # load diffusers pipeline object if set
        pipe = self.load_pipeline() if self.pipeline_obj is not None else None

        if pipe is not None:
            # store the entire diffusers pipeline object under the name given by pipeline_obj
            # NOTE(review): this loads the pipeline a second time, so the stored
            # pipeline is a separate instance from the one the individual
            # modules below are taken from. Reusing ``pipe`` would avoid the
            # second load but make both share the same module objects -
            # confirm which is intended before changing it.
            setattr(trainer, self.pipeline_obj, self.load_pipeline())

        for module_name in self.module_names:
            print(f" --- START: Loading module: {module_name} ---")

            if module_name not in module_config and pipe is not None:
                # stores models from already loaded diffusers pipeline
                module_obj = getattr(pipe, module_name)
                if module_name == "scheduler":
                    module_obj = self.load_custom_scheduler()
                setattr(trainer, module_name, module_obj)
            else:
                if module_name not in module_config:
                    raise KeyError(
                        f"Module {module_name} has no entry in module_config and no pipeline is loaded to take it from.")
                entry = module_config[module_name]

                if isinstance(entry, ModuleLoaderConfig):
                    # instantiate object from class method (as used by Diffusers, e.g. DiffusionPipeline.load_from_pretrained)
                    module = self._build_module(
                        trainer, module_name, entry, diff_trainer_params)
                    module: torch.nn.Module
                    if self.root_cls is not None:
                        assert isinstance(module, self.root_cls)
                    if entry.state_dict_path != "" and not self.fast_dev_run:
                        self._load_checkpoint(module, module_name, entry)
                else:
                    # instantiate model by jsonargparse and store it
                    # TODO we want to be able to load ckpt still.
                    module = entry

                # A raw namespace/dict here means the entry was never turned
                # into an actual object.
                if isinstance(module, (jsonargparse.Namespace, dict)):
                    print(bcolors.WARNING +
                          f"Warning: Seems object {module_name} was not build correct." + bcolors.ENDC)

                setattr(trainer, module_name, module)
            print(f" --- FINSHED: Loading module: {module_name} ---")

    def _resolve_loader(self, module_name: str, config_obj):
        """Return the callable used to instantiate the module of ``config_obj``."""
        # retrieve loader class
        loader_cls = get_class(config_obj.loader_cls_path)

        # retrieve loader method; without one the class itself is called
        if config_obj.cls_func == "":
            return loader_cls

        # we allow to specify a method for fast loading (e.g. in diffusers, from_config instead of from_pretrained)
        # makes loading faster for quick testing
        if not self.fast_dev_run or config_obj.cls_func_fast_dev_run == "":
            return getattr(loader_cls, config_obj.cls_func)

        print(f"Model {module_name}: loading fast_dev_run class loader")
        return getattr(loader_cls, config_obj.cls_func_fast_dev_run)

    @staticmethod
    def _collect_dependencies(trainer, module_name: str,
                              dependent_modules: Optional[dict],
                              module_names: List[str],
                              clone: bool) -> dict:
        """Map each kwarg name to the dependency already set on ``trainer``.

        With ``clone=True`` the dependency itself is deep-copied, so the new
        module does not share it with the trainer.
        """
        collected = {}
        if dependent_modules is None:
            return collected

        for key, dependent_module in dependent_modules.items():
            assert hasattr(
                trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {module_names}"
            dependency = getattr(trainer, dependent_module)
            collected[key] = deepcopy(dependency) if clone else dependency
        return collected

    @staticmethod
    def _resolve_torch_dtype(value):
        """Turn a ``"torch.<dtype>"`` string into the dtype; return anything else unchanged."""
        if isinstance(value, str) and value.startswith("torch."):
            dtype = getattr(torch, value[len("torch."):], None)
            if isinstance(dtype, torch.dtype):
                return dtype
        return value

    def _build_module(self, trainer, module_name: str, config_obj, diff_trainer_params):
        """Instantiate the module described by the ``ModuleLoaderConfig`` ``config_obj``."""
        cls_func = self._resolve_loader(module_name, config_obj)

        # model kwargs. Can be just a dict, or a parameter class (derived from modules.params.params_mixin.AsDictMixin) so we have verification of inputs
        model_params = config_obj.model_params
        if model_params is None:
            model_dict = {}
        elif isinstance(model_params, dict):
            model_dict = model_params
        else:
            model_dict = model_params.to_dict()

        # kwargs used only if on fast_dev_run mode
        model_params_fast_dev_run = config_obj.model_params_fast_dev_run
        if (model_params_fast_dev_run is None) or (not self.fast_dev_run):
            model_params_fast_dev_run = {}
        else:
            print(f"Module {module_name}: loading fast_dev_run params")

        # names of dependent modules that we need as input
        loaded_modules_dict = self._collect_dependencies(
            trainer, module_name, config_obj.dependent_modules,
            self.module_names, clone=False)
        # names of dependent modules that we need as input. Modules will be cloned
        loaded_modules_dict.update(self._collect_dependencies(
            trainer, module_name, config_obj.dependent_modules_cloned,
            self.module_names, clone=True))

        # load parameters specified in diff_trainer_params (so it links them).
        # A new dict is built so the config keeps its attribute names and the
        # loader can be called again.
        kwargs_trainer_params = {}
        if config_obj.kwargs_diff_trainer_params is not None:
            for key, param in config_obj.kwargs_diff_trainer_params.items():
                if param is not None:
                    kwargs_trainer_params[key] = getattr(
                        diff_trainer_params, param)
                else:
                    # no attribute name given: pass the whole params object
                    kwargs_trainer_params[key] = diff_trainer_params

        # Copy for the same reason; dtype strings are resolved on the copy.
        kwargs_diffusers = dict(config_obj.kwargs_diffusers or {})
        if "torch_dtype" in kwargs_diffusers:
            kwargs_diffusers["torch_dtype"] = self._resolve_torch_dtype(
                kwargs_diffusers["torch_dtype"])

        # On duplicate keys the right-most dict wins.
        kwargs = kwargs_diffusers | loaded_modules_dict | kwargs_trainer_params | model_dict | model_params_fast_dev_run
        args = config_obj.args
        # instantiate object
        return cls_func(*args, **kwargs)

    @staticmethod
    def _matches_filter(name: str, filter_str: str) -> bool:
        """Return whether the ``*``-separated parts of ``filter_str`` occur in ``name`` in order."""
        remaining = name
        for part in filter_str.split("*"):
            position = remaining.find(part)
            if position < 0:
                return False
            # continue searching behind the part that was just matched
            remaining = remaining[position + len(part):]
        return True

    def _load_checkpoint(self, module, module_name: str, config_obj) -> None:
        """Load the parameters of ``module_name`` from ``config_obj.state_dict_path`` into ``module``.

        Checkpoint keys are expected to be prefixed with ``"<module_name>."``.
        With ``state_dict_filters`` set, only module parameters whose prefixed
        name matches at least one filter are taken from the checkpoint.
        """
        # TODO extend loading to hf spaces
        print(
            f" * Loading checkpoint {config_obj.state_dict_path} - STARTED")
        # NOTE: torch.load unpickles the file, so state_dict_path must only
        # point to trusted checkpoints.
        module_state_dict = torch.load(
            config_obj.state_dict_path, map_location=torch.device("cpu"))
        module_state_dict = module_state_dict["state_dict"]

        prefix = module_name + "."
        filters = config_obj.state_dict_filters
        if len(filters) > 0:
            # a partial load can never satisfy strict loading
            assert not config_obj.strict_loading
            # Take the tensors from the checkpoint (not from the module
            # itself). Names absent from the checkpoint are skipped, which is
            # consistent with the non-strict loading required above.
            ckpt_params_dict = {
                name[len(prefix):]: module_state_dict[name]
                for name, _ in module.named_parameters(prefix=module_name)
                if name in module_state_dict
                and any(self._matches_filter(name, filter_str) for filter_str in filters)
            }
        else:
            # Match the full "<module_name>." prefix so that modules sharing a
            # name prefix (e.g. "unet" / "unet_extra") are not mixed up, and
            # strip it only once from the front of the key.
            ckpt_params_dict = {
                key[len(prefix):]: value
                for key, value in module_state_dict.items()
                if key.startswith(prefix)
            }

        if len(ckpt_params_dict) > 0:
            miss, unex = module.load_state_dict(
                ckpt_params_dict, strict=config_obj.strict_loading)
            assert len(
                unex) == 0, f"Unexpected parameters in checkpoint: {unex}"
            if len(miss) > 0:
                print(
                    f"Checkpoint {config_obj.state_dict_path} is missing parameters for module {module_name}.")
                print(miss)
        else:
            print(bcolors.WARNING +
                  f"Warning: Checkpoint {config_obj.state_dict_path} provided no parameters for module {module_name}." + bcolors.ENDC)
        print(
            f" * Loading checkpoint {config_obj.state_dict_path} - FINISHED")