File size: 1,710 Bytes
9e7a39a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import *
def init_model(cfgs):
    """Build the model described by a config file and load its checkpoint.

    Reads ``cfgs.model_cfg_path`` with OmegaConf, instantiates the ``model``
    node of that config, and initialises it from ``cfgs.load_ckpt_path``.

    When ``cfgs.type == "train"`` the model is only switched to training
    mode; it is not moved to a device here. For any other value the model is
    moved to CUDA device ``cfgs.gpu``, switched to eval mode and frozen.

    Returns the prepared model.
    """
    config = OmegaConf.load(cfgs.model_cfg_path)
    checkpoint_path = cfgs.load_ckpt_path

    net = instantiate_from_config(config.model)
    net.init_from_ckpt(checkpoint_path)

    # Training path: leave device placement to the caller.
    if cfgs.type == "train":
        net.train()
        return net

    # Inference path: place on the requested GPU and lock the weights.
    net.to(torch.device("cuda", index=cfgs.gpu))
    net.eval()
    net.freeze()
    return net
def init_sampling(cfgs):
    """Create the Euler EDM sampler used for generation.

    The sampler is configured with the legacy DDPM discretization and a
    vanilla classifier-free-guidance guider whose scale is the first entry of
    ``cfgs.scale``. It runs ``cfgs.steps`` steps on CUDA device ``cfgs.gpu``.

    Returns the constructed ``EulerEDMSampler``.
    """
    discretization = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    }
    guidance = {
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        # Only the first configured scale is used.
        "params": {"scale": cfgs.scale[0]},
    }

    sampler_kwargs = dict(
        num_steps=cfgs.steps,
        discretization_config=discretization,
        guider_config=guidance,
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )
    return EulerEDMSampler(**sampler_kwargs)
def deep_copy(batch):
    """Return a one-level copy of a batch dictionary.

    Each value is handled according to its type:

    * ``torch.Tensor`` -- cloned with ``torch.clone``, so the copy has its
      own storage.
    * ``list`` -- shallow-copied, so appending to or reassigning entries of
      the copy does not affect the original list (the elements themselves
      are shared).
    * anything else, including ``tuple`` -- stored by reference.

    Tuples are immutable and have no ``.copy()`` method; they are therefore
    passed through unchanged instead of being sent down the list branch,
    which used to raise ``AttributeError``.

    Args:
        batch: mapping from keys to tensors, lists, tuples or other objects.

    Returns:
        A new ``dict`` with the same keys as ``batch``.
    """
    copied = {}
    for key in batch:
        value = batch[key]
        if isinstance(value, torch.Tensor):
            copied[key] = torch.clone(value)
        elif isinstance(value, list):
            copied[key] = value.copy()
        else:
            # Tuples and other objects are shared; tuples cannot be mutated,
            # so sharing them is safe.
            copied[key] = value
    return copied
def prepare_batch(cfgs, batch):
    """Move every tensor in ``batch`` to CUDA device ``cfgs.gpu``.

    ``batch`` is modified in place: tensor values are replaced by their
    on-device versions, and non-tensor values are left untouched.

    Returns ``(batch, batch_uc)``. ``batch_uc`` is the very same dictionary
    object as ``batch``, not a copy, so a change made through one name is
    visible through the other.
    """
    for key in batch:
        value = batch[key]
        if not isinstance(value, torch.Tensor):
            continue
        # The device is only built once a tensor is actually found.
        target = torch.device("cuda", index=cfgs.gpu)
        batch[key] = value.to(target)

    batch_uc = batch
    return batch, batch_uc