Spaces:
Paused
Paused
import torch | |
from omegaconf import OmegaConf | |
from sgm.util import instantiate_from_config | |
from sgm.modules.diffusionmodules.sampling import * | |
def init_model(cfgs):
    """Instantiate the model described by ``cfgs`` and load its checkpoint.

    Args:
        cfgs: Options object. Reads ``model_cfg_path`` (config file handed
            to ``OmegaConf.load``), ``load_ckpt_path`` (passed to the
            model's ``init_from_ckpt``), ``type`` and, outside training,
            ``gpu``.

    Returns:
        The instantiated model. When ``cfgs.type == "train"`` it is put in
        training mode; otherwise it is moved to ``cuda:<cfgs.gpu>``, put in
        eval mode and frozen.
    """
    config = OmegaConf.load(cfgs.model_cfg_path)
    checkpoint_path = cfgs.load_ckpt_path

    net = instantiate_from_config(config.model)
    net.init_from_ckpt(checkpoint_path)

    # Inference path first: place on the requested GPU, then lock the model.
    if cfgs.type != "train":
        net.to(torch.device("cuda", index=cfgs.gpu))
        net.eval()
        net.freeze()
        return net

    net.train()
    return net
def init_sampling(cfgs):
    """Build the ``EulerEDMSampler`` used for sampling.

    Args:
        cfgs: Options object. Reads ``scale`` (only its first element is
            used, as the guider's ``scale`` parameter), ``steps`` (number
            of sampler steps) and ``gpu`` (CUDA device index).

    Returns:
        An ``EulerEDMSampler`` configured with the ``LegacyDDPMDiscretization``
        and ``VanillaCFG`` targets and fixed churn/noise settings.
    """
    # Only the first entry of cfgs.scale is consumed here.
    guidance = dict(
        target="sgm.modules.diffusionmodules.guiders.VanillaCFG",
        params=dict(scale=cfgs.scale[0]),
    )
    discretization = dict(
        target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    )

    return EulerEDMSampler(
        num_steps=cfgs.steps,
        discretization_config=discretization,
        guider_config=guidance,
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )
def deep_copy(batch):
    """Return a copy of ``batch`` that does not share tensors or lists.

    Despite the name, containers are copied one level deep only:

    * ``torch.Tensor`` values are cloned.
    * ``list`` values are shallow-copied (their elements are shared).
    * every other value, including tuples, is stored as-is.

    Tuples are shared rather than copied because they are immutable and
    have no ``copy()`` method; calling ``.copy()`` on them raised
    ``AttributeError``.

    Args:
        batch: Mapping from keys to batch values.

    Returns:
        A new ``dict`` with the same keys as ``batch``.
    """
    c_batch = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            c_batch[key] = torch.clone(value)
        elif isinstance(value, list):
            c_batch[key] = value.copy()
        else:
            # Tuples land here: immutable, so sharing the object is safe.
            c_batch[key] = value
    return c_batch
def prepare_batch(cfgs, batch):
    """Move every tensor in ``batch`` to ``cuda:<cfgs.gpu>``, in place.

    Args:
        cfgs: Options object; ``gpu`` is read only when ``batch`` actually
            contains a tensor.
        batch: Dict of batch values. Tensor entries are replaced by their
            CUDA counterparts; other entries are left untouched.

    Returns:
        A ``(batch, batch_uc)`` pair. Both elements are the *same* dict
        object, so a change made through one is visible through the other.
    """
    tensor_keys = [name for name in batch if isinstance(batch[name], torch.Tensor)]
    for name in tensor_keys:
        target = torch.device("cuda", index=cfgs.gpu)
        batch[name] = batch[name].to(target)

    # The second batch is an alias of the first, not a copy.
    batch_uc = batch
    return batch, batch_uc