import torch
from torch import Tensor
from transformers import PreTrainedModel
from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler

from .config import VocoderConfig


class Vocoder(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a diffusion vocoder.

    Thin delegation layer: all real work happens in the wrapped
    ``DiffusionVocoder`` instance held in ``self.model``. Wrapping in
    ``PreTrainedModel`` gives the usual HF save/load machinery driven by
    ``VocoderConfig``.
    """

    config_class = VocoderConfig

    def __init__(self, config: VocoderConfig):
        super().__init__(config)
        # Fixed architecture hyperparameters for the underlying v-diffusion
        # vocoder: 80 mel bins, 1024-point FFT, 48 kHz sample rate, and a
        # 9-stage UNetV0 backbone (channel widths / downsampling factors /
        # block counts listed per stage, shallowest first).
        architecture = dict(
            net_t=UNetV0,
            mel_channels=80,
            mel_n_fft=1024,
            mel_sample_rate=48000,
            mel_normalize_log=True,
            channels=[8, 32, 64, 256, 256, 512, 512, 1024, 1024],
            factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
            diffusion_t=VDiffusion,
            sampler_t=VSampler,
        )
        self.model = DiffusionVocoder(**architecture)

    def to_spectrogram(self, *args, **kwargs):
        """Delegate spectrogram computation to the wrapped ``DiffusionVocoder``."""
        return self.model.to_spectrogram(*args, **kwargs)

    @torch.no_grad()
    def sample(self, *args, **kwargs):
        """Run the wrapped vocoder's sampler with gradients disabled.

        Inference-only entry point; ``torch.no_grad`` avoids building the
        autograd graph during sampling.
        """
        return self.model.sample(*args, **kwargs)