import argparse

import torch.nn as nn

# from icefall.utils import AttributeDict, str2bool
from .macros import (
    NUM_AUDIO_TOKENS,
    NUM_MEL_BINS,
    NUM_SPEAKER_CLASSES,
    NUM_TEXT_TOKENS,
    SPEAKER_EMBEDDING_DIM,
)
from .transformer import Transformer
from .vallex import VALLE, VALLF
from .visualizer import visualize


def str2bool(v) -> bool:
    """Parse a command-line boolean value.

    Local replacement for icefall's str2bool (import commented out above):
    argparse's `type=bool` treats any non-empty string, including "False",
    as True, so boolean flags must be parsed explicitly.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {v!r}")


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor; its meaning differs between models.",
    )
    parser.add_argument(
        "--norm-first",
        type=str2bool,
        default=True,
        help="Use pre-normalization (True) or post-normalization (False).",
    )
    parser.add_argument(
        "--add-prenet",
        type=str2bool,
        default=False,
        help="Whether to add a PreNet after the inputs.",
    )
# VALL-E & F
parser.add_argument(
"--prefix-mode",
type=int,
default=1,
help="The mode for how to prefix VALL-E NAR Decoder, "
"0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
)
    parser.add_argument(
        "--share-embedding",
        type=str2bool,
        default=True,
        help="Share the parameters of the output projection layer with "
        "the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=str2bool,
        default=False,
        help="Whether to prepend <BOS> to the acoustic tokens fed to the AR decoder.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of audio/semantic quantization layers.",
    )
    # Transformer
    parser.add_argument(
        "--scaling-xformers",
        type=str2bool,
        default=False,
        help="Apply reworked Conformer scaling on Transformers.",
    )


def get_model(params) -> nn.Module:
    """Build the model selected by --model-name from parsed arguments."""
    if params.model_name.lower() in ["vall-f", "vallf"]:
        model = VALLF(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    elif params.model_name.lower() in ["vall-e", "valle"]:
        model = VALLE(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    else:
        assert params.model_name in ["Transformer"]
        model = Transformer(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            scaling_xformers=params.scaling_xformers,
        )
    return model
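

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# The relative imports above mean this file must run as a package module,
# e.g. `python -m <package>.<this_module>`; the package path is an assumption.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Build a VALL-E family model")
    add_model_arguments(parser)
    args = parser.parse_args()  # e.g. --model-name VALL-E --nhead 16
    model = get_model(args)
    # Report which architecture was built and its parameter count.
    num_params = sum(p.numel() for p in model.parameters())
    print(f"{model.__class__.__name__}: {num_params} parameters")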