|
|
|
|
|
|
|
|
|
|
|
from typing import Optional

import torch
from diffusers import DiffusionPipeline
|
|
|
|
|
class DiffusionInferencePipeline(DiffusionPipeline):
    """Reverse-diffusion inference pipeline.

    Wraps a denoising ``network`` and a diffusers noise ``scheduler`` and
    iteratively denoises an initial noise tensor into a sample.
    """

    def __init__(self, network, scheduler, num_inference_timesteps: int = 1000):
        """
        Args:
            network: Denoising model, called as ``network(x, timestep, conditioner)``.
            scheduler: A diffusers scheduler exposing ``set_timesteps``,
                ``timesteps`` and ``step``.
            num_inference_timesteps: Default number of denoising steps used
                when ``__call__`` is not given an explicit value.
        """
        super().__init__()
        # register_modules exposes the components as self.network /
        # self.scheduler and hooks them into diffusers' device handling
        # and (de)serialization machinery.
        self.register_modules(network=network, scheduler=scheduler)
        self.num_inference_timesteps = num_inference_timesteps

    @torch.inference_mode()
    def __call__(
        self,
        initial_noise: torch.Tensor,
        conditioner: Optional[torch.Tensor] = None,
        num_inference_timesteps: Optional[int] = None,
    ):
        r"""Run the denoising loop starting from ``initial_noise``.

        Args:
            initial_noise: The initial noise to be denoised. The first
                dimension is treated as the batch dimension.
            conditioner: Optional conditioning tensor, forwarded to the
                network unchanged at every step.
            num_inference_timesteps: The number of denoising steps; defaults
                to the value passed at construction time. More denoising
                steps usually lead to a higher quality at the expense of
                slower inference.

        Returns:
            The denoised tensor, clamped to ``[-1.0, 1.0]``.
        """
        # Fall back to the construction-time step count (backward compatible
        # with callers that pass only the noise and conditioner).
        if num_inference_timesteps is None:
            num_inference_timesteps = self.num_inference_timesteps

        mel = initial_noise
        batch_size = mel.size(0)
        self.scheduler.set_timesteps(num_inference_timesteps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # The network expects one timestep value per batch element.
            timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)

            model_output = self.network(mel, timestep, conditioner)

            # prev_sample is the estimate at the previous (less noisy) step.
            mel = self.scheduler.step(model_output, t, mel).prev_sample

        # Keep the final sample inside the expected normalized range.
        mel = mel.clamp(-1.0, 1.0)

        return mel
|
|