import torch
from pathlib import Path
from typing import Optional
from lightning_fabric import seed_everything
from huggingface_hub import hf_hub_download

from diffusion.diffusion_model import Diffusion_condition


'''
Steps to build a model:
1. Set up the model structure depending on the modality
2. Set up the AutoEncoder weights
3. Set up the Diffusion weights
*. Set up the condition flag (should be removed in the future)
4. Pick a random seed
**. Designate the output folder (this should also be removed in the future; it is not the responsibility of a model!)

A usage sketch is provided at the bottom of this file.
'''
class ModelBuilder:
    NUM_PROPOSALS = 32

    def __init__(self):
        self.reset()

    def set_up_model_template(self, model_class: type[Diffusion_condition]):
        # This shouldn't exist due to Diffusion_condition's inheritance
        self._model_class = model_class
    
    # Theoretically, this function should be the true Builder API
    # def set_up_modal(self, modal: Diffusion_condition):
    #     # Set up the modality for the model (pc, txt, sketch, svr, mvr)
    #     self._model_instance = modal
    
    def setup_autoencoder_weights(self, weights_path: Path | str):
        self._config["autoencoder_weights"] = weights_path
        
    def setup_diffusion_weights(self, weights_path: Path | str):
        self._config["diffusion_weights"] = weights_path
        
    def setup_condition(self, condition: str):
        self._config["condition"] = [condition]
       
    def setup_seed(self, seed: Optional[int] = None):
        seed_everything(seed if seed is not None else 0)
            
    def setup_output_dir(self, output_dir: Path | str):
        self._config["output_dir"] = output_dir
    
    def make_model(self, device: Optional[torch.device] = None):
        # Global torch settings
        torch.backends.cudnn.benchmark = False
        torch.set_float32_matmul_precision("medium")

        # Pick a device if none was given
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Instantiate the model from the current config
        # (needs to be refactored in the future)
        self._model_instance = self._model_class(self._config)
        
        # Load diffusion weights
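        # The configured path is split into "<repo_id>/<filename>" for the Hugging Face Hub:
        # the parent directory becomes the repo id and the file name is the checkpoint to fetch.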
        repo_id = Path(self._config["diffusion_weights"]).parent.as_posix()
        model_name = Path(self._config["diffusion_weights"]).name
        model_weights = hf_hub_download(repo_id=repo_id, filename=model_name)
        diffusion_weights = torch.load(model_weights, map_location=device, weights_only=False)["state_dict"]
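        # Drop any "ae_model" keys (the autoencoder is loaded from its own checkpoint below),
        # then keep keys whose name contains "model" and strip the first 6 characters
        # (the "model." prefix left over from the training wrapper's state_dict).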
        diffusion_weights = {k: v for k, v in diffusion_weights.items() if "ae_model" not in k}
        diffusion_weights = {k[6:]: v for k, v in diffusion_weights.items() if "model" in k}
        
        # Load Autoencoder weights
        AE_repo_id = Path(self._config["autoencoder_weights"]).parent.as_posix()
        AE_model_name = Path(self._config["autoencoder_weights"]).name
        AE_model_weights = hf_hub_download(repo_id=AE_repo_id, filename=AE_model_name)
        autoencoder_weights = torch.load(AE_model_weights, map_location=device, weights_only=False)["state_dict"]
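        # Strip the "model." prefix as above, then re-prefix every key with "ae_model." so the
        # autoencoder weights land in the ae_model submodule of the combined state dict.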
        autoencoder_weights = {k[6:]: v for k, v in autoencoder_weights.items() if "model" in k}
        autoencoder_weights = {"ae_model."+k: v for k, v in autoencoder_weights.items()}
        
        # Merge the autoencoder weights into the diffusion state dict and drop camera_embedding keys
        diffusion_weights.update(autoencoder_weights)
        diffusion_weights = {k: v for k, v in diffusion_weights.items() if "camera_embedding" not in k}
        
        self._model_instance.load_state_dict(diffusion_weights, strict=False)
        self._model_instance.to(device)
        self._model_instance.eval()
        
        return self._model_instance
    
    def reset(self):
        self._model_class = Diffusion_condition  # This shouldn't exist. See set_up_model_template()
        self._model_instance = None
        # Default model config
        self._config = {
            "name": "Diffusion_condition",
            "train_decoder": False,
            "stored_z": False,
            "use_mean": True,
            "diffusion_latent": 768,
            "diffusion_type": "epsilon",
            "loss": "l2",
            "pad_method": "random",
            "num_max_faces": 30,
            "beta_schedule": "squaredcos_cap_v2",
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "variance_type": "fixed_small",
            "addition_tag": False,
            "autoencoder": "AutoEncoder_1119_light",
            "with_intersection": True,
            "dim_latent": 8,
            "dim_shape": 768,
            "sigmoid": False,
            "in_channels": 6,
            "gaussian_weights": 1e-6,
            "norm": "layer",
            "autoencoder_weights": "",
            "is_aug": False,
            "condition": [],
            "cond_prob": []
        }   
    
    @property
    def model(self):
        return self._model_instance