import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import EulerEDMSampler


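def init_model(cfgs):
    """Build the model from its OmegaConf config and load checkpoint weights.

    In "train" mode the model is only switched to training mode and device
    placement is left to the caller. In any other mode it is moved to GPU
    ``cfgs.gpu``, put in eval mode and frozen (LightningModule.freeze stops
    its parameters from requiring gradients).
    """
    model_cfg = OmegaConf.load(cfgs.model_cfg_path)
    ckpt = cfgs.load_ckpt_path

    model = instantiate_from_config(model_cfg.model)
    model.init_from_ckpt(ckpt)

    if cfgs.type == "train":
        model.train()
    else:
        model.to(torch.device("cuda", index=cfgs.gpu))
        model.eval()
        model.freeze()

    return model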

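def init_sampling(cfgs):
    """Create an Euler EDM sampler with legacy DDPM discretization and
    classifier-free guidance (VanillaCFG).

    Only the first entry of ``cfgs.scale`` is used as the guidance scale.
    """
    discretization_config = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    }

    guider_config = {
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": cfgs.scale[0]},
    }

    sampler = EulerEDMSampler(
        num_steps=cfgs.steps,
        discretization_config=discretization_config,
        guider_config=guider_config,
        # s_churn=0.0 disables stochastic churn, so sampling is deterministic;
        # s_tmin/s_tmax only bound where churn would apply.
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )

    return sampler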

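def deep_copy(batch):
    """Copy a batch dict one level deep.

    Tensors are cloned, lists get a new (shallow) container, and all other
    values, including tuples, are shared with the original batch.
    """
    c_batch = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            c_batch[key] = torch.clone(value)
        elif isinstance(value, (tuple, list)):
            # value[:] shallow-copies a list; tuples have no .copy() method,
            # and being immutable they need no copy.
            c_batch[key] = value[:]
        else:
            c_batch[key] = value

    return c_batch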

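def prepare_batch(cfgs, batch):
    """Move every tensor in ``batch`` to GPU ``cfgs.gpu`` (in place) and
    return the conditional and unconditional batches."""
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))

    # batch_uc is the same dict object as batch, not a copy; use
    # deep_copy(batch) if the unconditional inputs must be edited separately.
    batch_uc = batch

    return batch, batch_uc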