# UDiffText / util.py
# ZYMPKU's picture
# v1
# ed25868
# raw
# history blame
# 1.71 kB
import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import *
def init_model(cfgs):
    """Build the model described by ``cfgs`` and load its checkpoint.

    The model config is read with ``OmegaConf.load`` from
    ``cfgs.model_cfg_path``; its ``model`` node is handed to
    ``instantiate_from_config`` and the resulting object is initialised
    from ``cfgs.load_ckpt_path`` via ``init_from_ckpt``.

    Args:
        cfgs: Settings object exposing ``model_cfg_path``,
            ``load_ckpt_path``, ``type`` and (outside training) ``gpu``.

    Returns:
        The instantiated model. When ``cfgs.type == "train"`` it is only
        switched to train mode; otherwise it is moved to the CUDA device
        with index ``cfgs.gpu``, switched to eval mode and frozen.
    """
    config = OmegaConf.load(cfgs.model_cfg_path)
    checkpoint_path = cfgs.load_ckpt_path

    net = instantiate_from_config(config.model)
    net.init_from_ckpt(checkpoint_path)

    # Training: no device placement happens here, only the mode switch.
    if cfgs.type == "train":
        net.train()
        return net

    # Any other run type: place on the requested GPU and lock the weights.
    target_device = torch.device("cuda", index=cfgs.gpu)
    net.to(target_device)
    net.eval()
    net.freeze()
    return net
def init_sampling(cfgs):
    """Create the ``EulerEDMSampler`` used for sampling.

    Args:
        cfgs: Settings object exposing ``scale`` (indexable; only element
            0 is used as the guidance scale), ``steps`` and ``gpu``.

    Returns:
        An ``EulerEDMSampler`` configured with the legacy DDPM
        discretization, a ``VanillaCFG`` guider and fixed churn/noise
        settings, targeting the CUDA device with index ``cfgs.gpu``.
    """
    # Only the first entry of cfgs.scale is used as the guidance scale.
    guider = dict(
        target="sgm.modules.diffusionmodules.guiders.VanillaCFG",
        params=dict(scale=cfgs.scale[0]),
    )
    discretization = dict(
        target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    )

    sampler_kwargs = dict(
        num_steps=cfgs.steps,
        discretization_config=discretization,
        guider_config=guider,
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )
    return EulerEDMSampler(**sampler_kwargs)
def deep_copy(batch):
    """Return a copy of ``batch`` whose tensors and lists are independent.

    Despite the name, the copy is one level deep only:

    * ``torch.Tensor`` values are cloned with ``torch.clone``;
    * ``list`` values are shallow-copied (their elements are shared);
    * every other value, including tuples, is stored as the same object.

    Tuples are passed through rather than copied because they are
    immutable and have no ``.copy()`` method; calling it raises
    ``AttributeError``.

    Args:
        batch: Mapping from keys to arbitrary values.

    Returns:
        A new ``dict`` with the same keys as ``batch``.
    """
    c_batch = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            c_batch[key] = torch.clone(value)
        elif isinstance(value, list):
            c_batch[key] = value.copy()
        else:
            # Tuples land here too: immutable, so sharing them is safe.
            c_batch[key] = value
    return c_batch
def prepare_batch(cfgs, batch):
    """Move every tensor in ``batch`` to the configured CUDA device.

    ``batch`` is modified in place: each ``torch.Tensor`` value is replaced
    by the result of ``.to(torch.device("cuda", index=cfgs.gpu))``.
    Non-tensor values are left untouched.

    Args:
        cfgs: Settings object exposing ``gpu``; it is only read when the
            batch actually contains a tensor.
        batch: Mapping from keys to values, updated in place.

    Returns:
        A 2-tuple ``(batch, batch_uc)``. Both elements are the *same*
        object as the input ``batch`` -- no copy is made.
    """
    # Re-assigning existing keys does not change the dict's size, so
    # updating entries while iterating over items() is safe.
    for name, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        batch[name] = value.to(torch.device("cuda", index=cfgs.gpu))

    # NOTE(review): "uc" presumably means "unconditional"; it is an alias
    # of ``batch``, not an independent copy -- confirm callers expect that.
    batch_uc = batch
    return batch, batch_uc