import torch
from transformers import PreTrainedModel

from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler

from .config import VocoderConfig


class Vocoder(PreTrainedModel):
    """Diffusion-based vocoder that reconstructs waveforms from mel spectrograms.

    Wraps audio_diffusion_pytorch's DiffusionVocoder behind the Hugging Face
    PreTrainedModel interface so it works with save_pretrained / from_pretrained.
    """

    config_class = VocoderConfig

    def __init__(self, config: VocoderConfig):
        super().__init__(config)
        self.model = DiffusionVocoder(
            net_t=UNetV0,  # U-Net backbone used for denoising
            # Mel front-end: 80 mel bins, 1024-point FFT, 48 kHz audio,
            # with log-magnitude normalization.
            mel_channels=80,
            mel_n_fft=1024,
            mel_sample_rate=48000,
            mel_normalize_log=True,
            # U-Net layout: channel width, down/upsampling factor, and number
            # of repeated blocks at each resolution level.
            channels=[8, 32, 64, 256, 256, 512, 512, 1024, 1024],
            factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
            diffusion_t=VDiffusion,  # v-objective diffusion training
            sampler_t=VSampler,  # matching v-objective sampler for inference
        )

    def to_spectrogram(self, *args, **kwargs):
        # Convert a raw waveform into the mel spectrogram representation
        # the diffusion model is conditioned on.
        return self.model.to_spectrogram(*args, **kwargs)

    @torch.no_grad()
    def sample(self, *args, **kwargs):
        # Reconstruct a waveform from a mel spectrogram via diffusion sampling.
        return self.model.sample(*args, **kwargs)
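

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module). It assumes VocoderConfig
# can be instantiated with its defaults; the sample(spectrogram, num_steps=...)
# call follows the audio_diffusion_pytorch README, which suggests 10-50 steps.
if __name__ == "__main__":
    model = Vocoder(VocoderConfig())  # assumption: default config is valid
    audio = torch.randn(1, 1, 2**18)  # [batch, channels, samples] at 48 kHz
    spectrogram = model.to_spectrogram(audio)  # mel spectrogram, e.g. [1, 80, frames]
    reconstruction = model.sample(spectrogram, num_steps=10)  # waveform back out
    print(reconstruction.shape)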