"""Latent sampling helpers for diffusion models.

Provides a classifier-free-guidance model wrapper and two latent-sampling
entry points: one using the Karras sampler and one using the plain
ancestral ``p_sample_loop`` sampler.
"""
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample, karras_sample_addition_condition

# Default hyperparameters for the Karras et al. (2022) sampler schedule.
DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0
def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` for classifier-free guidance.

    The returned callable expects a batch whose two halves hold the same
    latents (first half paired with conditional kwargs, second half with
    unconditional kwargs).  It evaluates ``model`` once on the duplicated
    first half and blends the two outputs as
    ``uncond + scale * (cond - uncond)``.

    Args:
        model: Denoiser called as ``model(x_t, ts, **kwargs)``; must return a
            tensor with the same batch dimension as its input.
        scale: Guidance weight; ``scale == 1.0`` reproduces the conditional
            prediction unchanged.

    Returns:
        A callable with the same ``(x_t, ts, **kwargs)`` signature whose
        output repeats the guided prediction for both halves of the batch.
    """

    def model_fn(x_t, ts, **kwargs):
        # Both halves of x_t are assumed identical, so only the first half is
        # kept; it is duplicated so each copy pairs with the cond/uncond
        # halves of the (already doubled) model kwargs.
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        cond_out, uncond_out = torch.chunk(model_out, 2, dim=0)
        cond_out = uncond_out + scale * (cond_out - uncond_out)
        # Repeat the guided result so the output batch matches the input.
        return torch.cat([cond_out, cond_out], dim=0)

    return model_fn
def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: "GaussianDiffusion",
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    initial_noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample a batch of latents of shape ``(batch_size, model.d_latent)``.

    Args:
        batch_size: Number of latents to draw.
        model: Denoiser module; must expose ``d_latent``.  If it defines
            ``cached_model_kwargs`` that hook is used to precompute the
            conditioning tensors for the batch.
        diffusion: Diffusion process used by the ancestral sampler.
        model_kwargs: Conditioning tensors, each with a leading batch dim.
        guidance_scale: Classifier-free guidance weight; any value other than
            0.0 or 1.0 doubles the batch with zeroed (unconditional) kwargs.
        clip_denoised: Forwarded to the sampler.
        use_fp16: Enables ``torch.autocast`` around sampling.
        use_karras: Selects the Karras sampler over ``p_sample_loop``.
        karras_steps / sigma_min / sigma_max / s_churn: Karras sampler
            hyperparameters (only used when ``use_karras`` is True).
        device: Target device; inferred from the model's parameters if None.
        progress: Show sampler progress.
        initial_noise: Optional fixed starting noise for the Karras sampler.

    Returns:
        The sampled latents tensor.  (The Karras sampler also produces a
        per-step sequence, which is discarded here.)
    """
    if device is None:
        device = next(model.parameters()).device
    # Let the model precompute/expand conditioning for the whole batch.
    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
    if guidance_scale != 1.0 and guidance_scale != 0.0:
        # Classifier-free guidance: append an all-zero (unconditional) copy
        # of every conditioning tensor, doubling the effective batch.
        for k, v in model_kwargs.copy().items():
            model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples, _sample_sequence = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
                initial_noise=initial_noise,
            )
        else:
            internal_batch_size = batch_size
            if guidance_scale != 1.0:
                # The wrapped model expects the doubled batch produced above.
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )
    return samples
def sample_latents_with_additional_latent(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: "GaussianDiffusion",
    model_kwargs: Dict[str, Any],
    text_guidance_scale: float,
    image_guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    condition_latent: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample latents conditioned on both text kwargs and an extra latent.

    Like :func:`sample_latents`, but supports two guidance weights and an
    additional conditioning latent.  When either guidance scale is active
    (not 0.0 and not 1.0), the batch is tripled: conditioning kwargs get two
    zeroed copies appended, and ``condition_latent`` is laid out as
    [cond, cond, zeros] to match the three guidance branches.

    Args:
        batch_size: Number of latents to draw.
        model: Denoiser module; must expose ``d_latent``.
        diffusion: Diffusion process used by the ancestral sampler.
        model_kwargs: Conditioning tensors, each with a leading batch dim.
        text_guidance_scale: Guidance weight for the text condition.
        image_guidance_scale: Guidance weight for the latent condition.
        clip_denoised / use_fp16 / use_karras / karras_steps / sigma_min /
            sigma_max / s_churn / device / progress: As in ``sample_latents``.
        condition_latent: Extra conditioning latent.
            NOTE(review): when guidance is active this is concatenated with
            ``torch.cat`` and must not be None — confirm callers always pass
            it in that case.

    Returns:
        The sampled latents tensor (the Karras per-step sequence is
        discarded).  The non-Karras fallback path ignores
        ``condition_latent`` and ``image_guidance_scale`` entirely, matching
        the original behavior.
    """
    if device is None:
        device = next(model.parameters()).device
    # Let the model precompute/expand conditioning for the whole batch.
    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
    if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (
        image_guidance_scale != 1.0 and image_guidance_scale != 0.0
    ):
        # Triple the batch: [text+image cond, image-only cond, uncond].
        for k, v in model_kwargs.copy().items():
            model_kwargs[k] = torch.cat(
                [v, torch.zeros_like(v), torch.zeros_like(v)], dim=0
            )
        condition_latent = torch.cat(
            [condition_latent, condition_latent, torch.zeros_like(condition_latent)],
            dim=0,
        )
    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples, _sample_sequence = karras_sample_addition_condition(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                text_guidance_scale=text_guidance_scale,
                image_guidance_scale=image_guidance_scale,
                progress=progress,
                condition_latent=condition_latent,
            )
        else:
            internal_batch_size = batch_size
            if text_guidance_scale != 1.0:
                # Fallback path applies text guidance only (batch doubled);
                # condition_latent is not forwarded here.
                model = uncond_guide_model(model, text_guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )
    return samples