import os.path as osp

import torch
from omegaconf import OmegaConf

from diffusionsfm.model.diffuser import RayDiffuser
from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
from diffusionsfm.model.scheduler import NoiseScheduler


def load_model(
    output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
):
    """
    Loads a model and config from an output directory.

    E.g. to load with different number of images,
    ```
    custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
    ```

    Args:
        output_dir (str): Path to the output directory.
        checkpoint (str or int): Path to the checkpoint to load. If None, loads the
            latest checkpoint.
        device (str): Device to load the model on.
        custom_keys (dict): Dictionary of custom keys to override in the config.
    """
    if checkpoint is None:
        # No checkpoint given: download the released weights from the Hugging Face Hub.
        _URL = "https://huggingface.co/qitaoz/DiffusionSfM/resolve/main/ckpt_00800000.pth"
        print("Loading checkpoint", _URL)
        data = torch.hub.load_state_dict_from_url(_URL)
    else:
        # Resolve a local checkpoint inside output_dir/checkpoints. An integer is
        # interpreted as a training iteration, e.g. 800000 -> ckpt_00800000.pth.
        if isinstance(checkpoint, int):
            checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
        else:
            checkpoint_name = checkpoint
        checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
        print("Loading checkpoint", checkpoint_path)
        data = torch.load(checkpoint_path, map_location="cpu")

    # Load the Hydra config saved at training time and apply any overrides.
    cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
    if custom_keys is not None:
        for k, v in custom_keys.items():
            OmegaConf.update(cfg, k, v)
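    # Recreate the noise scheduler with the hyperparameters used at training time.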
    noise_scheduler = NoiseScheduler(
        type=cfg.noise_scheduler.type,
        max_timesteps=cfg.noise_scheduler.max_timesteps,
        beta_start=cfg.noise_scheduler.beta_start,
        beta_end=cfg.noise_scheduler.beta_end,
    )

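    # The training config determines which denoiser was trained: the standard
    # RayDiffuser or the DPT-head variant (RayDiffuserDPT).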
    if not cfg.training.get("dpt_head", False):
        model = RayDiffuser(
            depth=cfg.model.depth,
            width=cfg.model.num_patches_x,
            P=1,
            max_num_images=cfg.model.num_images,
            noise_scheduler=noise_scheduler,
            feature_extractor=cfg.model.feature_extractor,
            append_ndc=cfg.model.append_ndc,
            diffuse_depths=cfg.training.get("diffuse_depths", False),
            depth_resolution=cfg.training.get("depth_resolution", 1),
            use_homogeneous=cfg.model.get("use_homogeneous", False),
            cond_depth_mask=cfg.model.get("cond_depth_mask", False),
        ).to(device)
    else:
        model = RayDiffuserDPT(
            depth=cfg.model.depth,
            width=cfg.model.num_patches_x,
            P=1,
            max_num_images=cfg.model.num_images,
            noise_scheduler=noise_scheduler,
            feature_extractor=cfg.model.feature_extractor,
            append_ndc=cfg.model.append_ndc,
            diffuse_depths=cfg.training.get("diffuse_depths", False),
            depth_resolution=cfg.training.get("depth_resolution", 1),
            encoder_features=cfg.training.get("dpt_encoder_features", False),
            use_homogeneous=cfg.model.get("use_homogeneous", False),
            cond_depth_mask=cfg.model.get("cond_depth_mask", False),
        ).to(device)

    # Drop parameters whose names contain any of the ignore_keys substrings
    # (e.g. "pos_table" when changing the number of images).
    state_dict = {
        k: v
        for k, v in data["state_dict"].items()
        if not any(ignore_key in k for ignore_key in ignore_keys)
    }

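    # strict=False so that intentionally skipped keys (ignore_keys) do not raise;
    # report anything that failed to match.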
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        print("Missing keys:", missing)
    if len(unexpected) > 0:
        print("Unexpected keys:", unexpected)
    model = model.eval()
    return model, cfg
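

if __name__ == "__main__":
    # Minimal usage sketch. The output directory path below is a placeholder; it is
    # assumed to contain the hydra/config.yaml written during training. This mirrors
    # the docstring example of overriding the number of images.
    model, cfg = load_model(
        output_dir="output_dir",  # placeholder path
        checkpoint=None,  # download the released weights
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        custom_keys={"model.num_images": 15},
        ignore_keys=["pos_table"],
    )
    print(type(model).__name__, "loaded with", cfg.model.num_images, "images")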