Spaces:
Running
on
T4
Running
on
T4
File size: 3,782 Bytes
4562a06 d2ecda1 4562a06 d2ecda1 4562a06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import os.path as osp
from glob import glob
import torch
from omegaconf import OmegaConf
from diffusionsfm.model.diffuser import RayDiffuser
from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
from diffusionsfm.model.scheduler import NoiseScheduler
# Hosted release weights. Used when ``checkpoint`` is None, which preserves the
# behaviour callers previously got unconditionally.
_DEFAULT_CHECKPOINT_URL = (
    "https://huggingface.co/qitaoz/DiffusionSfM/resolve/main/ckpt_00800000.pth"
)


def _load_checkpoint(output_dir, checkpoint):
    """Load and return the raw checkpoint dict selected by ``checkpoint``.

    Tensors are mapped to CPU so that a checkpoint saved on GPU can still be
    loaded on a CPU-only machine. ``load_state_dict`` later copies the weights
    onto the model's own device.
    """
    checkpoint_dir = osp.join(output_dir, "checkpoints")
    if checkpoint is None:
        checkpoint = _DEFAULT_CHECKPOINT_URL
    elif isinstance(checkpoint, int):
        checkpoint = f"ckpt_{checkpoint:08d}.pth"
    elif checkpoint == "latest":
        candidates = sorted(glob(osp.join(checkpoint_dir, "*.pth")))
        if not candidates:
            raise FileNotFoundError(f"No *.pth checkpoints found in {checkpoint_dir}")
        # Keep only the file name; it is re-joined with checkpoint_dir below.
        checkpoint = osp.basename(candidates[-1])
    checkpoint = str(checkpoint)

    if checkpoint.startswith(("http://", "https://")):
        print("Loading checkpoint", checkpoint)
        return torch.hub.load_state_dict_from_url(checkpoint, map_location="cpu")

    # Relative names resolve inside <output_dir>/checkpoints. osp.join leaves
    # absolute paths untouched.
    checkpoint_path = osp.join(checkpoint_dir, checkpoint)
    print("Loading checkpoint", checkpoint_path)
    return torch.load(checkpoint_path, map_location="cpu")


def load_model(
    output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
):
    """
    Loads a model and config from an output directory.

    E.g. to load with different number of images,
    ```
    custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
    ```

    Args:
        output_dir (str): Path to the output directory. Must contain
            ``hydra/config.yaml``. Local checkpoints are looked up in
            ``<output_dir>/checkpoints``.
        checkpoint (str or int or None): Which weights to load.
            - None: download the released checkpoint from the Hugging Face Hub
              via ``torch.hub`` (default).
            - "latest": the last ``*.pth`` file, in sorted order, in
              ``<output_dir>/checkpoints``.
            - int: ``<output_dir>/checkpoints/ckpt_{checkpoint:08d}.pth``.
            - http(s) URL: downloaded via ``torch.hub``.
            - any other str: a file name inside ``<output_dir>/checkpoints``,
              or an absolute path.
        device (str): Device to load the model on.
        custom_keys (dict): Dotted config keys to override (via
            ``OmegaConf.update``) before the model is built.
        ignore_keys (iterable of str): A state-dict entry is skipped when its
            key contains any of these substrings.

    Returns:
        tuple: ``(model, cfg)``. ``model`` is in eval mode on ``device``;
        ``cfg`` is the loaded config with ``custom_keys`` applied.

    Raises:
        FileNotFoundError: If ``checkpoint="latest"`` and no ``*.pth`` file
            exists in ``<output_dir>/checkpoints``.
    """
    # Load the config first so a bad output_dir fails before a large download.
    cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
    if custom_keys is not None:
        for k, v in custom_keys.items():
            OmegaConf.update(cfg, k, v)

    noise_scheduler = NoiseScheduler(
        type=cfg.noise_scheduler.type,
        max_timesteps=cfg.noise_scheduler.max_timesteps,
        beta_start=cfg.noise_scheduler.beta_start,
        beta_end=cfg.noise_scheduler.beta_end,
    )

    # Both architectures take the same constructor arguments. The DPT variant
    # additionally takes ``encoder_features``.
    model_kwargs = dict(
        depth=cfg.model.depth,
        width=cfg.model.num_patches_x,
        P=1,
        max_num_images=cfg.model.num_images,
        noise_scheduler=noise_scheduler,
        feature_extractor=cfg.model.feature_extractor,
        append_ndc=cfg.model.append_ndc,
        diffuse_depths=cfg.training.get("diffuse_depths", False),
        depth_resolution=cfg.training.get("depth_resolution", 1),
        use_homogeneous=cfg.model.get("use_homogeneous", False),
        cond_depth_mask=cfg.model.get("cond_depth_mask", False),
    )
    if cfg.training.get("dpt_head", False):
        model_cls = RayDiffuserDPT
        model_kwargs["encoder_features"] = cfg.training.get(
            "dpt_encoder_features", False
        )
    else:
        model_cls = RayDiffuser
    model = model_cls(**model_kwargs).to(device)

    data = _load_checkpoint(output_dir, checkpoint)
    state_dict = {
        k: v
        for k, v in data["state_dict"].items()
        if not any(ignore_key in k for ignore_key in ignore_keys)
    }

    # strict=False: ignored or renamed keys are reported rather than raised.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print("Missing keys:", missing)
    if unexpected:
        print("Unexpected keys:", unexpected)

    model = model.eval()
    return model, cfg
|