File size: 1,850 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b81d26
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from dataclasses import asdict, dataclass, field
from typing import Dict, List

from coqpit import MISSING

from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig


@dataclass
class BaseEncoderConfig(BaseTrainingConfig):
    """Defines parameters for a Generic Encoder model.

    Extends ``BaseTrainingConfig`` with the audio, dataset, model, optimizer,
    logging and data-loader settings listed below. Fields whose default is
    ``MISSING`` have no usable default and are expected to be set by the user.

    Args:
        model (str): Model identifier. Defaults to ``None``.
        audio (BaseAudioConfig): Audio processing configuration. Its
            ``num_mels`` must match ``model_params["input_dim"]``
            (enforced in ``check_values``).
        datasets (List[BaseDatasetConfig]): Dataset configurations. Defaults
            to a list holding one default ``BaseDatasetConfig``.
        model_params (Dict): Model hyper-parameters. Defaults describe an
            ``"lstm"`` model with ``input_dim=80``, ``proj_dim=256``,
            ``lstm_dim=768``, ``num_lstm_layers=3`` and
            ``use_lstm_with_projection=True``.
        audio_augmentation (Dict): Augmentation settings; empty by default.
        epochs (int): Defaults to 5000.
        loss (str): Loss name. Defaults to ``"angleproto"``.
        grad_clip (float): Defaults to 3.0.
        lr (float): Learning rate. Defaults to 0.0001.
        optimizer (str): Optimizer name. Defaults to ``"radam"``.
        optimizer_params (Dict): Keyword arguments for the optimizer.
        lr_decay (bool): Defaults to ``False``.
        warmup_steps (int): Defaults to 4000.
        tb_model_param_stats (bool): Defaults to ``False``.
        steps_plot_stats (int): Defaults to 10.
        save_step (int): Defaults to 1000.
        print_step (int): Defaults to 20.
        run_eval (bool): Defaults to ``False``.
        num_classes_in_batch (int): Required (``MISSING``).
        num_utter_per_class (int): Required (``MISSING``).
        eval_num_classes_in_batch (int): Defaults to ``None``.
        eval_num_utter_per_class (int): Defaults to ``None``.
        num_loader_workers (int): Required (``MISSING``).
        voice_len (float): Defaults to 1.6.
            NOTE(review): presumably a duration in seconds -- confirm.
    """

    model: str = None
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # model params
    model_params: Dict = field(
        default_factory=lambda: {
            "model_name": "lstm",
            "input_dim": 80,
            "proj_dim": 256,
            "lstm_dim": 768,
            "num_lstm_layers": 3,
            "use_lstm_with_projection": True,
        }
    )

    audio_augmentation: Dict = field(default_factory=lambda: {})

    # training params
    epochs: int = 5000
    loss: str = "angleproto"
    grad_clip: float = 3.0
    lr: float = 0.0001
    optimizer: str = "radam"
    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
    lr_decay: bool = False
    warmup_steps: int = 4000

    # logging params
    tb_model_param_stats: bool = False
    steps_plot_stats: int = 10
    save_step: int = 1000
    print_step: int = 20
    run_eval: bool = False

    # data loader
    # NOTE(review): the names suggest batches are built from N classes with
    # M utterances each -- confirm against the data loader.
    num_classes_in_batch: int = MISSING
    num_utter_per_class: int = MISSING
    eval_num_classes_in_batch: int = None
    eval_num_utter_per_class: int = None

    num_loader_workers: int = MISSING
    voice_len: float = 1.6

    def check_values(self):
        """Validate the configuration.

        Runs the parent class checks, then verifies that the model input
        dimension equals the number of mel bands in the audio config.

        Raises:
            AssertionError: If ``model_params["input_dim"]`` differs from
                ``audio.num_mels``.
        """
        super().check_values()
        # Read the value directly; converting the whole config with asdict()
        # would deep-copy every nested field just to look up one key.
        input_dim = self.model_params["input_dim"]
        assert (
            input_dim == self.audio.num_mels
        ), " [!] model input dimension must be equal to melspectrogram dimension."