# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from diffusers import DiffusionPipeline


class DiffusionInferencePipeline(DiffusionPipeline):
    """Iteratively denoise a tensor using a network and a diffusers-style scheduler.

    Args:
        network: Callable invoked as ``network(x, timestep, conditioner)`` on every
            step. Its output is passed to ``scheduler.step`` as the model output
            (the noise prediction, per the loop comments).
        scheduler: Scheduler exposing ``set_timesteps``, ``timesteps`` and
            ``step(model_output, t, sample).prev_sample``.
        num_inference_timesteps: Default number of denoising steps passed to
            ``scheduler.set_timesteps`` when a call does not override it.
    """

    def __init__(self, network, scheduler, num_inference_timesteps=1000):
        super().__init__()
        # Registered so they are reachable as self.network / self.scheduler below.
        self.register_modules(network=network, scheduler=scheduler)
        self.num_inference_timesteps = num_inference_timesteps

    @torch.inference_mode()
    def __call__(
        self,
        initial_noise: torch.Tensor,
        conditioner: Optional[torch.Tensor] = None,
        num_inference_timesteps: Optional[int] = None,
    ) -> torch.Tensor:
        r"""Run the reverse diffusion loop starting from ``initial_noise``.

        Args:
            initial_noise: The initial noise to be denoised. Its first dimension
                is treated as the batch size. Presumably a noisy mel-spectrogram
                batch, judging by the variable name; confirm against callers.
            conditioner: Conditioning input forwarded unchanged to the network
                on every step. May be ``None``.
            num_inference_timesteps: Number of denoising steps for this call.
                More denoising steps usually lead to higher quality at the
                expense of slower inference. Defaults to the value given to the
                constructor.

        Returns:
            The denoised tensor, with values clamped to ``[-1, 1]``.
        """
        if num_inference_timesteps is None:
            num_inference_timesteps = self.num_inference_timesteps

        # NOTE(review): initial_noise is used as-is (not multiplied by
        # scheduler.init_noise_sigma) and scheduler.scale_model_input is never
        # called. That is only correct for schedulers where both are identity
        # (e.g. DDPM). Confirm before using other schedulers.
        mel = initial_noise
        batch_size = mel.size(0)
        self.scheduler.set_timesteps(num_inference_timesteps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # Broadcast the scalar step to one timestep per batch element.
            timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)

            # 1. predict noise model_output
            model_output = self.network(mel, timestep, conditioner)

            # 2. denoise, compute previous step: x_t -> x_t-1
            mel = self.scheduler.step(model_output, t, mel).prev_sample

            # 3. clamp to [-1, 1] after every step, not only at the end
            mel = mel.clamp(-1.0, 1.0)

        return mel