# ProDiff/vocoders/base_vocoder.py
# Hugging Face file-viewer header: uploaded by Rongjiehuang, commit "init" (64e7f2f), 841 bytes.
import importlib
# Registry of vocoder classes, keyed by class name (original and lower-cased).
VOCODERS = {}


def register_vocoder(cls):
    """Class decorator that records ``cls`` in the ``VOCODERS`` registry.

    The class is stored under two keys: its lower-cased ``__name__`` and its
    ``__name__`` exactly as written. An existing entry under either key is
    overwritten.

    :param cls: class to register
    :return: ``cls`` itself, unchanged, so this can be used as a decorator
    """
    name = cls.__name__
    for key in (name.lower(), name):
        VOCODERS[key] = cls
    return cls
def get_vocoder_cls(hparams):
    """Resolve the vocoder class named by ``hparams['vocoder']``.

    The value is first looked up in the ``VOCODERS`` registry. If it is not
    registered, it is treated as a dotted import path of the form
    ``package.module.ClassName``: everything before the last dot is imported
    as a module and the final component is fetched from it as an attribute.

    :param hparams: mapping with a ``'vocoder'`` entry (a registered name or
        a dotted import path)
    :return: the resolved vocoder class
    :raises KeyError: if ``hparams`` has no ``'vocoder'`` entry
    :raises ValueError: if the value is neither registered nor a dotted path
    :raises ImportError: if the module part of a dotted path cannot be imported
    :raises AttributeError: if the module has no attribute with the class name
    """
    vocoder_name = hparams['vocoder']
    if vocoder_name in VOCODERS:
        return VOCODERS[vocoder_name]

    # Not registered: interpret as "<module path>.<class name>".
    pkg, _, cls_name = vocoder_name.rpartition(".")
    if not pkg:
        # Without a module part, import_module("") would fail with the
        # unhelpful "Empty module name"; report what was actually wrong.
        raise ValueError(
            "Unknown vocoder {!r}: not a registered vocoder (registered: {}) "
            "and not a dotted import path like 'package.module.ClassName'."
            .format(vocoder_name, sorted(VOCODERS))
        )
    return getattr(importlib.import_module(pkg), cls_name)
class BaseVocoder:
    """Interface for vocoders; both methods here raise ``NotImplementedError``."""

    def spec2wav(self, mel):
        """Convert a mel-spectrogram to a waveform.

        :param mel: mel-spectrogram, shape [T, 80]
        :return: wav, shape [T']
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @staticmethod
    def wav2spec(wav_fn):
        """Produce the waveform and mel-spectrogram for an audio file.

        :param wav_fn: str, the wav file name
        :return: wav, and mel of shape [T, 80]
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError