|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import asdict |
|
|
|
import pytorch_lightning as pl |
|
|
|
from nemo.collections.asr.models import EncDecCTCModel, configs |
|
from nemo.core.config import modelPT, optimizers, schedulers |
|
from nemo.utils.exp_manager import exp_manager |
|
|
|
""" |
|
python speech_to_text_structured_v2.py |
|
""" |
|
|
|
|
|
# Output vocabulary handed to the model config builder: space, the lowercase
# letters a-z, then the apostrophe (28 single-character entries, in this order).
LABELS = list(" abcdefghijklmnopqrstuvwxyz'")
|
|
|
# Novograd optimizer parameters; passed to the config builder's set_optim() in main().
optim_cfg = optimizers.NovogradParams(lr=0.01, betas=(0.8, 0.5), weight_decay=0.001)
|
|
|
# Cosine-annealing scheduler parameters; both warmup options are explicitly left as None.
sched_cfg = schedulers.CosineAnnealingParams(warmup_steps=None, warmup_ratio=None, min_lr=0.0)
|
|
|
|
|
|
|
def main():
    """Assemble a NeMo config in code and fit an ``EncDecCTCModel`` with it.

    A ``NemoConfig`` is created, its model section is produced by
    ``EncDecCTCModelConfigBuilder`` (builder name ``'quartznet_15x5'``) using the
    module-level ``LABELS``, ``optim_cfg`` and ``sched_cfg``, and the trainer and
    experiment-manager sections are adjusted in place.  The trainer is then
    constructed, registered with ``exp_manager`` and used to fit the model.

    Both manifest paths are set to empty strings here; presumably they are
    placeholders to be replaced with real manifest files before running.
    """
    # Top-level config object holding the model, trainer and exp_manager sections.
    cfg = modelPT.NemoConfig(name='Custom QuartzNet')

    # Build the ASR model config with the shared labels / optimizer / scheduler.
    cfg_builder = configs.EncDecCTCModelConfigBuilder(name='quartznet_15x5')
    cfg_builder.set_labels(LABELS)
    cfg_builder.set_optim(cfg=optim_cfg, sched_cfg=sched_cfg)
    asr_cfg = cfg_builder.build()

    # Attach the generated model config to the top-level config.
    cfg.model = asr_cfg

    # Dataset overrides: train first, then validation.
    for dataset_cfg in (asr_cfg.train_ds, asr_cfg.validation_ds):
        dataset_cfg.manifest_filepath = ""

    # Trainer overrides.
    cfg.trainer.devices = 1
    cfg.trainer.max_epochs = 5

    # The experiment is named after the top-level config.
    cfg.exp_manager.name = cfg.name

    # The config sections are dataclasses, so they are converted to plain dicts
    # before being handed to the trainer and to exp_manager.
    trainer_kwargs = asdict(cfg.trainer)
    trainer = pl.Trainer(**trainer_kwargs)
    exp_manager(trainer, asdict(cfg.exp_manager))

    model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)
|
|
|
|
|
# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|