File size: 1,566 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from diffusers import DiffusionPipeline
class DiffusionInferencePipeline(DiffusionPipeline):
    """Iterative denoising pipeline wrapping a denoising network and a scheduler.

    Args:
        network: Module called as ``network(x_t, timestep, conditioner)`` at every
            step; its output is passed to ``scheduler.step`` as the model output.
        scheduler: A diffusers-style scheduler exposing ``set_timesteps``,
            ``timesteps`` and ``step(...).prev_sample``.
        num_inference_timesteps: Default number of denoising steps, used when
            ``__call__`` is not given ``n_inference_steps``.
    """

    def __init__(self, network, scheduler, num_inference_timesteps=1000):
        super().__init__()
        # register_modules makes both available as self.network / self.scheduler,
        # which __call__ relies on.
        self.register_modules(network=network, scheduler=scheduler)
        self.num_inference_timesteps = num_inference_timesteps

    @torch.inference_mode()
    def __call__(
        self,
        initial_noise: torch.Tensor,
        conditioner: Optional[torch.Tensor] = None,
        n_inference_steps: Optional[int] = None,
    ) -> torch.Tensor:
        r"""Denoise ``initial_noise`` by running the scheduler's reverse process.

        Args:
            initial_noise: The initial noise to be denoised. Its first dimension
                is treated as the batch size.
            conditioner: The conditioning input, forwarded unchanged to the
                network at every step (may be ``None``).
            n_inference_steps: The number of denoising steps. More denoising steps
                usually lead to a higher quality at the expense of slower
                inference. Defaults to ``self.num_inference_timesteps`` when
                ``None``.

        Returns:
            The denoised tensor, clamped to ``[-1, 1]``.
        """
        # Per-call override; the pipeline's configured default is left untouched.
        if n_inference_steps is None:
            n_inference_steps = self.num_inference_timesteps

        mel = initial_noise
        batch_size = mel.size(0)
        self.scheduler.set_timesteps(n_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # Repeat the current scalar timestep once per batch element.
            timestep = torch.full(
                (batch_size,), t, device=mel.device, dtype=torch.long
            )
            # 1. predict noise model_output
            model_output = self.network(mel, timestep, conditioner)
            # 2. denoise, compute previous step: x_t -> x_t-1
            mel = self.scheduler.step(model_output, t, mel).prev_sample
            # 3. clamp after every step so the running sample stays in [-1, 1]
            mel = mel.clamp(-1.0, 1.0)

        return mel
|