from functools import partial

import torch

from opensora.registry import SCHEDULERS

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
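

# IDDPM wraps SpacedDiffusion with the beta schedule, loss type, and variance
# parameterisation used by improved DDPM, and adds a classifier-free-guidance
# sampling helper on top of it.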
@SCHEDULERS.register_module("iddpm")
class IDDPM(SpacedDiffusion):
    def __init__(
        self,
        num_sampling_steps=None,
        timestep_respacing=None,
        noise_schedule="linear",
        use_kl=False,
        sigma_small=False,
        predict_xstart=False,
        learn_sigma=True,
        rescale_learned_sigmas=False,
        diffusion_steps=1000,
        cfg_scale=4.0,
    ):
        betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
        if use_kl:
            loss_type = gd.LossType.RESCALED_KL
        elif rescale_learned_sigmas:
            loss_type = gd.LossType.RESCALED_MSE
        else:
            loss_type = gd.LossType.MSE
        # `num_sampling_steps` is a convenience alias for `timestep_respacing`:
        # passing an int respaces the full schedule down to that many steps.
        if num_sampling_steps is not None:
            assert timestep_respacing is None
            timestep_respacing = str(num_sampling_steps)
        if timestep_respacing is None or timestep_respacing == "":
            timestep_respacing = [diffusion_steps]
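        # Mean/variance parameterisation: predict epsilon unless predict_xstart
        # is set, and learn the variance range unless learn_sigma is disabled.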
        super().__init__(
            use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
            betas=betas,
            model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
            model_var_type=(
                (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
                if not learn_sigma
                else gd.ModelVarType.LEARNED_RANGE
            ),
            loss_type=loss_type,
            # rescale_timesteps=rescale_timesteps,
        )

        self.cfg_scale = cfg_scale
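
    # Sampling with classifier-free guidance: the latent batch is duplicated,
    # the first half is conditioned on the text embeddings and the second half
    # on the null embedding, and both halves share the guided noise prediction
    # produced by forward_with_cfg below.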
    def sample(
        self,
        model,
        text_encoder,
        z_size,
        prompts,
        device,
        additional_args=None,
    ):
        n = len(prompts)
        z = torch.randn(n, *z_size, device=device)
        z = torch.cat([z, z], 0)  # duplicate the batch: [conditional | unconditional]
        model_args = text_encoder.encode(prompts)
        y_null = text_encoder.null(n)
        model_args["y"] = torch.cat([model_args["y"], y_null], 0)
        if additional_args is not None:
            model_args.update(additional_args)
        forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale)
        samples = self.p_sample_loop(
            forward,
            z.shape,
            z,
            clip_denoised=False,
            model_kwargs=model_args,
            progress=True,
            device=device,
        )
        # Keep only the conditional half of the duplicated batch.
        samples, _ = samples.chunk(2, dim=0)
        return samples
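

# Classifier-free-guidance forward pass, adapted from the GLIDE text2im
# notebook: run the model on the duplicated batch, split the noise prediction
# into its conditional and unconditional halves, and blend them with cfg_scale.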
def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs):
    # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = model.forward(combined, timestep, y, **kwargs)
    model_out = model_out["x"] if isinstance(model_out, dict) else model_out
    # First 3 channels are the predicted noise; the rest hold the learned variance values.
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    # Both halves now carry the same guided noise prediction.
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)
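

# A minimal, self-contained sketch (not part of the scheduler) of how
# forward_with_cfg blends the two halves of a duplicated batch. "ToyModel" and
# the tensor shapes below are illustrative assumptions, not the real Open-Sora
# model: the toy network just returns 3 "noise" channels plus 3 extra channels.
if __name__ == "__main__":
    class ToyModel:
        def forward(self, x, timestep, y):
            # Pretend the network predicts noise (3 channels) and variance values (3 channels).
            return torch.cat([x * 0.1, torch.zeros_like(x)], dim=1)

    n = 2
    x = torch.randn(2 * n, 3, 8, 8)           # duplicated batch: [cond | uncond]
    t = torch.zeros(2 * n, dtype=torch.long)  # dummy timesteps
    y = torch.zeros(2 * n, 4)                 # dummy conditioning embeddings
    out = forward_with_cfg(ToyModel(), x, t, y, cfg_scale=4.0)
    print(out.shape)  # torch.Size([4, 6, 8, 8]); both halves share the guided eps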