from __future__ import annotations

import torch
import torch.nn as nn
from monai.utils import optional_import
from torch.cuda.amp import autocast

tqdm, has_tqdm = optional_import("tqdm", name="tqdm")


class Sampler:
    """Latent diffusion sampler with classifier-free guidance."""

    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()
    def sampling_fn(
        self,
        noise: torch.Tensor,
        autoencoder_model: nn.Module,
        diffusion_model: nn.Module,
        scheduler: nn.Module,
        prompt_embeds: torch.Tensor,
        guidance_scale: float = 7.0,
        scale_factor: float = 0.3,
    ) -> torch.Tensor:
        if has_tqdm:
            progress_bar = tqdm(scheduler.timesteps)
        else:
            progress_bar = iter(scheduler.timesteps)

        for t in progress_bar:
            # Duplicate the latent so the unconditional and conditional
            # predictions run in a single forward pass; `prompt_embeds` is
            # expected to hold both embeddings stacked along the batch axis.
            noise_input = torch.cat([noise] * 2)
            model_output = diffusion_model(
                noise_input, timesteps=torch.tensor((t,)).to(noise.device).long(), context=prompt_embeds
            )
            # Classifier-free guidance: push the prediction away from the
            # unconditional output and towards the text-conditioned one.
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise, _ = scheduler.step(noise_pred, t, noise)

        # Undo the latent scaling and decode the denoised latent back to
        # image space with the autoencoder.
        with autocast():
            sample = autoencoder_model.decode_stage_2_outputs(noise / scale_factor)

        return sample
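
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). A minimal, hedged example of
# driving Sampler.sampling_fn. The DDIMScheduler import follows MONAI's
# GenerativeModels package; the constructor arguments, latent shape, and the
# `autoencoder`, `unet`, `uncond_embeds`, and `cond_embeds` names are
# placeholder assumptions to adapt to your own pipeline.
#
#   from generative.networks.schedulers import DDIMScheduler
#
#   scheduler = DDIMScheduler(num_train_timesteps=1000)
#   scheduler.set_timesteps(num_inference_steps=50)
#
#   # Classifier-free guidance expects the unconditional and conditional
#   # prompt embeddings stacked along the batch dimension, mirroring the
#   # torch.cat([noise] * 2) duplication inside sampling_fn.
#   prompt_embeds = torch.cat([uncond_embeds, cond_embeds])
#
#   sampler = Sampler()
#   image = sampler.sampling_fn(
#       noise=torch.randn((1, 3, 64, 64), device="cuda"),
#       autoencoder_model=autoencoder,  # trained first-stage autoencoder
#       diffusion_model=unet,           # trained diffusion UNet
#       scheduler=scheduler,
#       prompt_embeds=prompt_embeds,
#   )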