Spaces:
Starting
on
L4
Starting
on
L4
import argparse | |
import torch.nn as nn | |
# from icefall.utils import AttributeDict, str2bool | |
from .macros import ( | |
NUM_AUDIO_TOKENS, | |
NUM_MEL_BINS, | |
NUM_SPEAKER_CLASSES, | |
NUM_TEXT_TOKENS, | |
SPEAKER_EMBEDDING_DIM, | |
) | |
from .vallex import VALLE, VALLF | |
def _str2bool(value) -> bool:
    """Convert a command-line string into a ``bool`` for use as an argparse ``type``.

    ``type=bool`` is a trap in argparse: it calls ``bool(some_string)``, so every
    non-empty string -- including ``"False"`` and ``"0"`` -- becomes ``True``.
    This converter accepts the usual spellings instead.

    Args:
        value: The raw option string (an already-parsed ``bool`` is passed through).

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: For any other string, so argparse reports a
            normal usage error instead of silently picking a value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def add_model_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the model-architecture command-line options on ``parser``.

    Options are grouped as generic decoder settings, VALL-E / VALL-F specific
    settings, and Transformer settings. Boolean options take an explicit value
    (e.g. ``--norm-first false``) parsed by ``_str2bool``.

    Args:
        parser: The parser to extend in place.
    """
    # NOTE(review): get_model() in this module only builds VALL-E / VALL-F;
    # "Transformer" is listed here but raises ValueError there.
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    # Boolean options use _str2bool: with type=bool, "False" would parse as True.
    parser.add_argument(
        "--norm-first",
        type=_str2bool,
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=_str2bool,
        default=False,
        help="Whether add PreNet after Inputs.",
    )

    # VALL-E & F
    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=_str2bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=_str2bool,
        default=False,
        help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )

    # Transformer
    parser.add_argument(
        "--scaling-xformers",
        type=_str2bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )
def get_model(params) -> nn.Module:
    """Build the acoustic model named by ``params.model_name``.

    The name is matched case-insensitively. "vall-f"/"vallf" selects ``VALLF``
    and "vall-e"/"valle" selects ``VALLE``. Both are built with the same
    arguments taken from ``params``. ``params.scale_factor`` is forwarded as
    ``nar_scale_factor``.

    Args:
        params: Attribute container with model_name, decoder_dim, nhead,
            num_decoder_layers, norm_first, add_prenet, prefix_mode,
            share_embedding, scale_factor, prepend_bos and num_quantizers.
            Presumably the namespace produced from add_model_arguments;
            verify against callers.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``params.model_name`` names no supported model.
    """
    requested = params.model_name.lower()
    if requested in ["vall-f", "vallf"]:
        model_cls = VALLF
    elif requested in ["vall-e", "valle"]:
        model_cls = VALLE
    else:
        raise ValueError("No such model")

    # Both variants share one constructor signature, so build them identically.
    return model_cls(
        params.decoder_dim,
        params.nhead,
        params.num_decoder_layers,
        norm_first=params.norm_first,
        add_prenet=params.add_prenet,
        prefix_mode=params.prefix_mode,
        share_embedding=params.share_embedding,
        nar_scale_factor=params.scale_factor,
        prepend_bos=params.prepend_bos,
        num_quantizers=params.num_quantizers,
    )