UDiffText / util.py
ZYMPKU's picture
first
6497501
raw
history blame
No virus
3.63 kB
import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import *
SD_XL_BASE_RATIOS = {
"0.5": (704, 1408),
"0.52": (704, 1344),
"0.57": (768, 1344),
"0.6": (768, 1280),
"0.68": (832, 1216),
"0.72": (832, 1152),
"0.78": (896, 1152),
"0.82": (896, 1088),
"0.88": (960, 1088),
"0.94": (960, 1024),
"1.0": (1024, 1024),
"1.07": (1024, 960),
"1.13": (1088, 960),
"1.21": (1088, 896),
"1.29": (1152, 896),
"1.38": (1152, 832),
"1.46": (1216, 832),
"1.67": (1280, 768),
"1.75": (1344, 768),
"1.91": (1344, 704),
"2.0": (1408, 704),
"2.09": (1472, 704),
"2.4": (1536, 640),
"2.5": (1600, 640),
"2.89": (1664, 576),
"3.0": (1728, 576),
}
def init_model(cfg):
model_cfg = OmegaConf.load(cfg.model_cfg_path)
ckpt = cfg.load_ckpt_path
model = instantiate_from_config(model_cfg.model)
model.init_from_ckpt(ckpt)
if cfg.type == "train":
model.train()
else:
model.to(torch.device("cuda", index=cfg.gpu))
model.eval()
model.freeze()
return model
def init_sampling(cfgs):
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
if cfgs.dual_conditioner:
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.DualCFG",
"params": {"scale": cfgs.scale},
}
sampler = EulerEDMDualSampler(
num_steps=cfgs.steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=0.0,
s_tmin=0.0,
s_tmax=999.0,
s_noise=1.0,
verbose=True,
device=torch.device("cuda", index=cfgs.gpu)
)
else:
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": cfgs.scale[0]},
}
sampler = EulerEDMSampler(
num_steps=cfgs.steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=0.0,
s_tmin=0.0,
s_tmax=999.0,
s_noise=1.0,
verbose=True,
device=torch.device("cuda", index=cfgs.gpu)
)
return sampler
def deep_copy(batch):
c_batch = {}
for key in batch:
if isinstance(batch[key], torch.Tensor):
c_batch[key] = torch.clone(batch[key])
elif isinstance(batch[key], (tuple, list)):
c_batch[key] = batch[key].copy()
else:
c_batch[key] = batch[key]
return c_batch
def prepare_batch(cfgs, batch):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
if not cfgs.dual_conditioner:
batch_uc = deep_copy(batch)
if "ntxt" in batch:
batch_uc["txt"] = batch["ntxt"]
else:
batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))]
if "label" in batch:
batch_uc["label"] = ["" for _ in range(len(batch["label"]))]
return batch, batch_uc, None
else:
batch_uc_1 = deep_copy(batch)
batch_uc_2 = deep_copy(batch)
batch_uc_1["ref"] = torch.zeros_like(batch["ref"])
batch_uc_2["ref"] = torch.zeros_like(batch["ref"])
batch_uc_1["label"] = ["" for _ in range(len(batch["label"]))]
return batch, batch_uc_1, batch_uc_2