import os.path as osp
from glob import glob

import torch
from omegaconf import OmegaConf

from diffusionsfm.model.diffuser import RayDiffuser
from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
from diffusionsfm.model.scheduler import NoiseScheduler


def load_model(
    output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
):
    """
    Loads a model and config from an output directory.

    E.g. to load with a different number of images,
    ```
    custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
    ```

    Args:
        output_dir (str): Path to the output directory (must contain hydra/config.yaml).
        checkpoint (str or int): Path to the checkpoint to load. If None, loads the
            latest checkpoint. Currently unused: the released checkpoint is always
            downloaded from the Hugging Face Hub (see the commented-out local
            loading below).
        device (str): Device to load the model on.
        custom_keys (dict): Dictionary of custom keys to override in the config.
        ignore_keys (tuple): Substrings of state-dict keys to skip when loading.
    """
    # Local checkpoint loading, disabled in favor of the released checkpoint:
    # if checkpoint is None:
    #     checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
    # else:
    #     if isinstance(checkpoint, int):
    #         checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
    #     else:
    #         checkpoint_name = checkpoint
    #     checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
    _URL = "https://huggingface.co/qitaoz/DiffusionSfM/resolve/main/ckpt_00800000.pth"
    print("Loading checkpoint", _URL)
    data = torch.hub.load_state_dict_from_url(_URL)

    cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
    if custom_keys is not None:
        for k, v in custom_keys.items():
            OmegaConf.update(cfg, k, v)

    noise_scheduler = NoiseScheduler(
        type=cfg.noise_scheduler.type,
        max_timesteps=cfg.noise_scheduler.max_timesteps,
        beta_start=cfg.noise_scheduler.beta_start,
        beta_end=cfg.noise_scheduler.beta_end,
    )

    if not cfg.training.get("dpt_head", False):
        model = RayDiffuser(
            depth=cfg.model.depth,
            width=cfg.model.num_patches_x,
            P=1,
            max_num_images=cfg.model.num_images,
            noise_scheduler=noise_scheduler,
            feature_extractor=cfg.model.feature_extractor,
            append_ndc=cfg.model.append_ndc,
            diffuse_depths=cfg.training.get("diffuse_depths", False),
            depth_resolution=cfg.training.get("depth_resolution", 1),
            use_homogeneous=cfg.model.get("use_homogeneous", False),
            cond_depth_mask=cfg.model.get("cond_depth_mask", False),
        ).to(device)
    else:
        model = RayDiffuserDPT(
            depth=cfg.model.depth,
            width=cfg.model.num_patches_x,
            P=1,
            max_num_images=cfg.model.num_images,
            noise_scheduler=noise_scheduler,
            feature_extractor=cfg.model.feature_extractor,
            append_ndc=cfg.model.append_ndc,
            diffuse_depths=cfg.training.get("diffuse_depths", False),
            depth_resolution=cfg.training.get("depth_resolution", 1),
            encoder_features=cfg.training.get("dpt_encoder_features", False),
            use_homogeneous=cfg.model.get("use_homogeneous", False),
            cond_depth_mask=cfg.model.get("cond_depth_mask", False),
        ).to(device)

    # data = torch.load(checkpoint_path)
    # Drop parameters whose names contain any of the ignore_keys substrings.
    state_dict = {}
    for k, v in data["state_dict"].items():
        include = True
        for ignore_key in ignore_keys:
            if ignore_key in k:
                include = False
        if include:
            state_dict[k] = v

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        print("Missing keys:", missing)
    if len(unexpected) > 0:
        print("Unexpected keys:", unexpected)
    model = model.eval()
    return model, cfg
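

# Minimal usage sketch (assumption: a training output directory containing
# hydra/config.yaml exists at the placeholder path below; the path and the
# overrides are illustrative, not part of the released code).
if __name__ == "__main__":
    model, cfg = load_model(
        output_dir="output/diffusionsfm",  # hypothetical output directory
        device="cuda:0",
        # Override the number of input images and drop the positional table so
        # it is re-initialized for the new image count.
        custom_keys={"model.num_images": 15},
        ignore_keys=["pos_table"],
    )
    print("Loaded", type(model).__name__, "for", cfg.model.num_images, "images")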