import importlib
import re

from coqpit import Coqpit

# Class names used by ``setup_model`` for the model names it special-cases
# ("gan" is the one whose class name differs from what ``to_camel`` produces).
_MODEL_CLASS_NAMES = {
    "wavernn": "Wavernn",
    "gan": "GAN",
    "wavegrad": "Wavegrad",
}

# (out_channels, base_channels) for the MelGAN-style generators; every other
# constructor argument is shared between them in ``setup_generator``.
_MELGAN_GENERATOR_SHAPES = {
    "melgan_generator": (1, 512),
    "multiband_melgan_generator": (4, 384),
    "fullband_melgan_generator": (1, 512),
}


def to_camel(text):
    """Convert a snake_case name to CamelCase.

    The text is first passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased). Every underscore followed by an ASCII
    letter is then dropped and that letter upper-cased; an underscore at the
    very start of the string is left alone.

    Example: ``"hifigan_generator"`` -> ``"HifiganGenerator"``.
    """
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(config: Coqpit):
    """Load models directly from configuration.

    When the config contains both ``discriminator_model`` and
    ``generator_model`` the ``GAN`` class from ``TTS.vocoder.models.gan`` is
    used. Otherwise the module ``TTS.vocoder.models.<config.model lower-cased>``
    is imported and the class is looked up in it by name.

    Args:
        config: Model configuration; ``config.model`` names the model.

    Returns:
        The result of calling ``init_from_config(config)`` on the resolved class.

    Raises:
        ValueError: If the imported module has no class with the expected name.
            Errors raised by ``importlib.import_module`` propagate unchanged.
    """
    if "discriminator_model" in config and "generator_model" in config:
        module = importlib.import_module("TTS.vocoder.models.gan")
        model_class = getattr(module, "GAN")
    else:
        model_name = config.model.lower()
        module = importlib.import_module("TTS.vocoder.models." + model_name)
        class_name = _MODEL_CLASS_NAMES.get(model_name) or to_camel(config.model)
        try:
            model_class = getattr(module, class_name)
        except AttributeError as e:
            # getattr signals a missing class with AttributeError; the import
            # above is what would raise ModuleNotFoundError.
            raise ValueError(f"Model {config.model} not exist!") from e
    print(" > Vocoder Model: {}".format(config.model))
    return model_class.init_from_config(config)


def setup_generator(c):
    """Build the generator named by ``c.generator_model``.

    The class is imported from ``TTS.vocoder.models.<name lower-cased>`` and
    looked up under ``to_camel(c.generator_model)``. Constructor arguments are
    taken from ``c.audio`` and ``c.generator_model_params`` depending on the
    generator type.

    Args:
        c: Configuration exposing ``generator_model``, ``audio`` and
            ``generator_model_params``.

    Returns:
        The instantiated generator.

    Raises:
        ValueError: For the retired name ``melgan_fb_generator``.
        NotImplementedError: If the generator name matches no known type.
    """
    # TODO: use config object as arguments
    print(" > Generator Model: {}".format(c.generator_model))
    name = c.generator_model.lower()

    # Checked before the import so the rename hint is raised even when no
    # module with the old name can be imported.
    if name == "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")

    module = importlib.import_module("TTS.vocoder.models." + name)
    model_class = getattr(module, to_camel(c.generator_model))

    # NOTE: names are compared for equality; a substring test
    # (``name in "hifigan_generator"``) would also accept "" or "generator".
    if name == "hifigan_generator":
        return model_class(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)

    if name in _MELGAN_GENERATOR_SHAPES:
        out_channels, base_channels = _MELGAN_GENERATOR_SHAPES[name]
        return model_class(
            in_channels=c.audio["num_mels"],
            out_channels=out_channels,
            proj_kernel=7,
            base_channels=base_channels,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )

    if name == "parallel_wavegan_generator":
        return model_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )

    if name == "univnet_generator":
        return model_class(**c.generator_model_params)

    raise NotImplementedError(f"Model {c.generator_model} not implemented!")


def setup_discriminator(c):
    """Build the discriminator named by ``c.discriminator_model``.

    Names containing ``parallel_wavegan`` are imported from
    ``TTS.vocoder.models.parallel_wavegan_discriminator``; any other name is
    imported from ``TTS.vocoder.models.<name lower-cased>``. The class is
    looked up under ``to_camel`` of the lower-cased name.

    Args:
        c: Configuration exposing ``discriminator_model``, ``audio`` and
            ``discriminator_model_params``.

    Returns:
        The instantiated discriminator.

    Raises:
        NotImplementedError: If the discriminator name matches no known type.
    """
    # TODO: use config object as arguments
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    # Normalised once so the import, the class lookup and the branch
    # selection all agree on the same spelling.
    name = c.discriminator_model.lower()

    if "parallel_wavegan" in name:
        module = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
    else:
        module = importlib.import_module("TTS.vocoder.models." + name)
    model_class = getattr(module, to_camel(name))

    if name in ("hifigan_discriminator", "univnet_discriminator"):
        return model_class()

    if name == "random_window_discriminator":
        return model_class(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
            window_sizes=c.discriminator_model_params["window_sizes"],
        )

    if name == "melgan_multiscale_discriminator":
        return model_class(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=c.discriminator_model_params["base_channels"],
            max_channels=c.discriminator_model_params["max_channels"],
            downsample_factors=c.discriminator_model_params["downsample_factors"],
        )

    if name == "residual_parallel_wavegan_discriminator":
        return model_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            stacks=c.discriminator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            dropout=0.0,
            bias=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        )

    if name == "parallel_wavegan_discriminator":
        return model_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
        )

    # Without this fallback an unmatched name left the result unbound and
    # failed with UnboundLocalError at the return statement.
    raise NotImplementedError(f"Model {c.discriminator_model} not implemented!")