from typing import Any, Callable, Dict, Optional, List
import torch
import torch.nn as nn
from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample, karras_sample_addition_condition

# Default hyperparameters for the Karras (EDM) sampler schedule.
DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0

def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` with classifier-free guidance at the given ``scale``.

    The wrapped model expects a doubled batch whose conditioning kwargs carry
    the real conditioning in the first half and a zeroed copy in the second
    half (see ``sample_latents`` below).
    """

    def model_fn(x_t, ts, **kwargs):
        # Denoise one copy of the latents under both halves of the conditioning.
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        cond_out, uncond_out = torch.chunk(model_out, 2, dim=0)
        # Guidance: push the prediction away from the unconditional branch.
        cond_out = uncond_out + scale * (cond_out - uncond_out)
        return torch.cat([cond_out, cond_out], dim=0)

    return model_fn
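

# Minimal self-check sketch for ``uncond_guide_model``: a toy model whose
# conditional half predicts ones and whose unconditional half predicts zeros,
# so the guided output (uncond + scale * (cond - uncond)) can be verified by
# hand. Illustration only; nothing in this module calls it.
def _demo_uncond_guide(scale: float = 3.0) -> torch.Tensor:
    class _ToyModel(nn.Module):
        def forward(self, x_t, ts, **kwargs):
            out = torch.ones_like(x_t)
            out[len(x_t) // 2 :] = 0.0  # second (unconditional) half predicts zeros
            return out

    guided = uncond_guide_model(_ToyModel(), scale)
    x_t = torch.randn(4, 8)  # doubled batch: 2 "conditional" + 2 "unconditional" rows
    out = guided(x_t, torch.zeros(4))
    assert torch.allclose(out, torch.full_like(out, scale))  # 0 + scale * (1 - 0)
    return out
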
def sample_latents(
*,
batch_size: int,
model: nn.Module,
diffusion: GaussianDiffusion,
model_kwargs: Dict[str, Any],
guidance_scale: float,
clip_denoised: bool,
use_fp16: bool,
use_karras: bool,
karras_steps: int,
sigma_min: float,
sigma_max: float,
s_churn: float,
device: Optional[torch.device] = None,
progress: bool = False,
initial_noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if device is None:
device = next(model.parameters()).device
if hasattr(model, "cached_model_kwargs"):
model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
    if guidance_scale != 1.0 and guidance_scale != 0.0:
        # Append a zeroed copy of every conditioning tensor so the batch splits
        # into a conditional first half and an unconditional second half.
        for k, v in model_kwargs.copy().items():
            model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
    sample_shape = (batch_size, model.d_latent)
with torch.autocast(device_type=device.type, enabled=use_fp16):
if use_karras:
samples, sample_sequence = karras_sample(
diffusion=diffusion,
model=model,
shape=sample_shape,
steps=karras_steps,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
device=device,
sigma_min=sigma_min,
sigma_max=sigma_max,
s_churn=s_churn,
guidance_scale=guidance_scale,
progress=progress,
initial_noise=initial_noise,
)
        else:
            # Non-Karras fallback: sample with the diffusion model's own p_sample_loop.
            internal_batch_size = batch_size
            if guidance_scale != 1.0:
                # Wrap the model for classifier-free guidance; the batch is doubled
                # so the conditional and unconditional halves are denoised together.
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
samples = diffusion.p_sample_loop(
model,
shape=(internal_batch_size, *sample_shape[1:]),
model_kwargs=model_kwargs,
device=device,
clip_denoised=clip_denoised,
progress=progress,
)
return samples
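

# Illustrative invocation sketch for ``sample_latents``. The empty
# ``model_kwargs`` and the guidance scale are placeholders; a real caller
# would supply the conditioning tensors its model expects.
def _demo_sample_latents(model: nn.Module, diffusion: GaussianDiffusion) -> torch.Tensor:
    return sample_latents(
        batch_size=2,
        model=model,
        diffusion=diffusion,
        model_kwargs={},  # placeholder: e.g. text embeddings keyed as the model expects
        guidance_scale=3.0,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=DEFAULT_KARRAS_STEPS,
        sigma_min=DEFAULT_KARRAS_SIGMA_MIN,
        sigma_max=DEFAULT_KARRAS_SIGMA_MAX,
        s_churn=DEFAULT_KARRAS_S_CHURN,
        progress=True,
    )
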
def sample_latents_with_additional_latent(
*,
batch_size: int,
model: nn.Module,
diffusion: GaussianDiffusion,
model_kwargs: Dict[str, Any],
text_guidance_scale: float,
image_guidance_scale: float,
clip_denoised: bool,
use_fp16: bool,
use_karras: bool,
karras_steps: int,
sigma_min: float,
sigma_max: float,
s_churn: float,
device: Optional[torch.device] = None,
progress: bool = False,
condition_latent: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if device is None:
device = next(model.parameters()).device
if hasattr(model, "cached_model_kwargs"):
model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
    if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (
        image_guidance_scale != 1.0 and image_guidance_scale != 0.0
    ):
        # Triple the batch for two-way classifier-free guidance:
        # (1) text + condition latent, (2) condition latent only, (3) fully unconditional.
        for k, v in model_kwargs.copy().items():
            model_kwargs[k] = torch.cat([v, torch.zeros_like(v), torch.zeros_like(v)], dim=0)
        condition_latent = torch.cat(
            [condition_latent, condition_latent, torch.zeros_like(condition_latent)], dim=0
        )
    sample_shape = (batch_size, model.d_latent)
with torch.autocast(device_type=device.type, enabled=use_fp16):
if use_karras:
            samples, samples_sequence = karras_sample_addition_condition(
diffusion=diffusion,
model=model,
shape=sample_shape,
steps=karras_steps,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
device=device,
sigma_min=sigma_min,
sigma_max=sigma_max,
s_churn=s_churn,
text_guidance_scale=text_guidance_scale,
image_guidance_scale=image_guidance_scale,
progress=progress,
condition_latent=condition_latent,
)
        else:
            # Non-Karras fallback: only the text guidance scale is applied here and
            # condition_latent is not passed to p_sample_loop.
            internal_batch_size = batch_size
            if text_guidance_scale != 1.0:
                model = uncond_guide_model(model, text_guidance_scale)
                internal_batch_size *= 2
samples = diffusion.p_sample_loop(
model,
shape=(internal_batch_size, *sample_shape[1:]),
model_kwargs=model_kwargs,
device=device,
clip_denoised=clip_denoised,
progress=progress,
)
return samples
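

# Illustrative invocation sketch for ``sample_latents_with_additional_latent``,
# assuming ``condition_latent`` has shape (batch_size, model.d_latent). The two
# guidance scales are arbitrary example values for the text / condition-latent
# split handled above.
def _demo_sample_with_condition_latent(
    model: nn.Module,
    diffusion: GaussianDiffusion,
    condition_latent: torch.Tensor,
) -> torch.Tensor:
    return sample_latents_with_additional_latent(
        batch_size=condition_latent.shape[0],
        model=model,
        diffusion=diffusion,
        model_kwargs={},  # placeholder for the model's conditioning tensors
        text_guidance_scale=7.5,
        image_guidance_scale=3.0,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=DEFAULT_KARRAS_STEPS,
        sigma_min=DEFAULT_KARRAS_SIGMA_MIN,
        sigma_max=DEFAULT_KARRAS_SIGMA_MAX,
        s_churn=DEFAULT_KARRAS_S_CHURN,
        progress=True,
        condition_latent=condition_latent,
    )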