|
from __future__ import annotations |
|
|
|
import torch |
|
import torch.nn as nn |
|
from monai.utils import optional_import |
|
from torch.cuda.amp import autocast |
|
|
|
tqdm, has_tqdm = optional_import("tqdm", name="tqdm") |
|
|
|
|
|
class Sampler:
    """Latent-diffusion sampler using classifier-free guidance.

    Runs the reverse-diffusion loop with a doubled latent batch (one
    unconditional and one text-conditioned forward pass per step), then
    decodes the final latent with the autoencoder's stage-2 decoder.
    """

    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()
    def sampling_fn(
        self,
        noise: torch.Tensor,
        autoencoder_model: nn.Module,
        diffusion_model: nn.Module,
        scheduler: nn.Module,
        prompt_embeds: torch.Tensor,
        guidance_scale: float = 7.0,
        scale_factor: float = 0.3,
    ) -> torch.Tensor:
        """Sample an image from ``noise`` via classifier-free guidance.

        Args:
            noise: initial latent noise tensor (batch of latents).
            autoencoder_model: model providing ``decode_stage_2_outputs`` to map
                latents back to image space.
            diffusion_model: noise-prediction network called as
                ``diffusion_model(x, timesteps=..., context=...)``.
            scheduler: diffusion scheduler exposing ``timesteps`` and ``step``.
            prompt_embeds: text conditioning passed as ``context``.
                NOTE(review): the latent batch is doubled below, so this is
                presumably the unconditional embeddings concatenated with the
                conditional ones — confirm against the caller.
            guidance_scale: classifier-free guidance weight; higher values
                follow the prompt more strongly.
            scale_factor: factor the latents were scaled by during training;
                the final latent is divided by it before decoding.

        Returns:
            The decoded sample tensor.
        """
        if has_tqdm:
            progress_bar = tqdm(scheduler.timesteps)
        else:
            progress_bar = iter(scheduler.timesteps)

        for t in progress_bar:
            # Double the latent batch so the unconditional and conditional
            # predictions are produced in a single forward pass.
            noise_input = torch.cat([noise] * 2)
            # Build the timestep tensor directly on the target device with the
            # right dtype, instead of torch.Tensor((t,)).to(...).long() which
            # allocates a float32 CPU tensor first and then converts it.
            timestep = torch.tensor((t,), dtype=torch.long, device=noise.device)
            model_output = diffusion_model(noise_input, timesteps=timestep, context=prompt_embeds)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            # Classifier-free guidance: push the prediction away from the
            # unconditional estimate toward the text-conditioned one.
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # scheduler.step returns (previous_sample, predicted_original);
            # only the previous sample is carried forward.
            noise, _ = scheduler.step(noise_pred, t, noise)

        # NOTE(review): torch.cuda.amp.autocast is deprecated in recent torch
        # releases in favour of torch.amp.autocast("cuda"); kept as-is to
        # preserve compatibility with the version this file targets.
        with autocast():
            sample = autoencoder_model.decode_stage_2_outputs(noise / scale_factor)

        return sample
|
|