# vocoder-v1 / model.py
# flavioschneider's picture
# Upload Vocoder
# b8272e6
# raw
# history blame
# 1.01 kB
import torch
from torch import Tensor
from transformers import PreTrainedModel
from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler
from .config import VocoderConfig
class Vocoder(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a ``DiffusionVocoder``.

    The wrapped network is stored on ``self.model``; ``to_spectrogram`` and
    ``sample`` forward their arguments to it unchanged.
    """

    config_class = VocoderConfig

    def __init__(self, config: VocoderConfig):
        """Build the wrapped ``DiffusionVocoder``.

        Args:
            config: Handed to ``PreTrainedModel.__init__``. The vocoder
                hyperparameters below are fixed in code and are not read
                from ``config``.
        """
        super().__init__(config)

        # Collected in one mapping so the full architecture is visible at a
        # glance before it is handed to the constructor.
        vocoder_kwargs = dict(
            net_t=UNetV0,
            # Mel front-end settings.
            mel_channels=80,
            mel_n_fft=1024,
            mel_sample_rate=48000,
            mel_normalize_log=True,
            # These three lists each have nine entries (presumably one per
            # network level -- confirm against audio_diffusion_pytorch).
            channels=[8, 32, 64, 256, 256, 512, 512, 1024, 1024],
            factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
            # Diffusion objective and its sampler.
            diffusion_t=VDiffusion,
            sampler_t=VSampler,
        )
        self.model = DiffusionVocoder(**vocoder_kwargs)

    def to_spectrogram(self, *args, **kwargs):
        """Forward all arguments to ``self.model.to_spectrogram``."""
        spectrogram = self.model.to_spectrogram(*args, **kwargs)
        return spectrogram

    def sample(self, *args, **kwargs):
        """Forward all arguments to ``self.model.sample`` with autograd off."""
        # Gradient tracking is disabled for the whole sampling call.
        with torch.no_grad():
            return self.model.sample(*args, **kwargs)