{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3570803-9bfa-4e97-9891-5ae0759eb8ca",
   "metadata": {},
   "source": [
    "# Hybrid ASR-TTS Models Tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50fc294f-f319-4465-8f90-a28b49843e60",
   "metadata": {},
   "source": [
    "This tutorial is intended to introduce you to using ASR-TTS Hybrid Models, also known as `ASRWithTTSModel`, to finetune existing ASR models using an integrated text-to-mel-spectrogram generator. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2a01ca5-bd48-4d82-a97d-5b07a7b27ca0",
   "metadata": {},
   "source": [
    "## ASR-TTS Models: Description"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b32467a9-c458-4590-aff7-e8d1e91b0870",
   "metadata": {},
   "source": [
    "### Problem\n",
    "\n",
    "Adapting ASR models to a new text domain is a challenging task. Modern end-to-end systems can require several hundreds and thousands of hours to perform recognition with high accuracy. Acquiring audio-text paired data for a specific domain can be prohibitively expensive. Text-only data, on the other side, is widely available. \n",
    "\n",
    "One of the approaches for efficient adaptation is synthesizing audio data from text and using such data for training the ASR model conventionally. We modify this approach, incorporating TTS and ASR systems into a single model. We use only a lightweight multi-speaker text-to-mel-spectrogram generator (without vocoder) with an optional enhancer that mitigates the mismatch between natural and synthetic spectrograms.\n",
    "\n",
    "### Architecture\n",
    "\n",
    "<img width=\"400px\" height=\"auto\"\n",
    "     src=\"https://github.com/NVIDIA/NeMo/blob/stable/docs/source/asr/images/hybrid_asr_tts_model.png?raw=true\"\n",
    "     alt=\"ASR-TTS model architecture\"\n",
    "     style=\"float: right; margin-left: 20px;\">\n",
    "\n",
    "`ASRWithTTSModel` is a transparent wrapper for three models:\n",
    "- ASR model (`EncDecCTCModelBPE`, `EncDecRNNTBPEModel` or `EncDecHybridRNNTCTCBPEModel` are supported)\n",
    "- frozen text-to-mel-spectrogram model (currently, only `FastPitch` model is supported)\n",
    "- optional frozen enhancer model\n",
    "\n",
    "The architecture is shown in the figure. \n",
    "\n",
    "The model can take text or audio as input during training. In the case of audio input, a mel spectrogram is extracted as usual and passed to the ASR neural network. In the case of textual input, the mel spectrogram generator produces a spectrogram on the fly from the text. The spectrogram is improved by the enhancer (if present) and fed into the ASR model. \n",
    "\n",
    "### Capabilities and Limitations\n",
    "\n",
    "This approach can be used to finetune the pretrained ASR model using text-only data. Training new models from scratch is also possible. The text should contain phrases and sentences and be split into sentences (~45 words maximum, corresponding to ~16.7 seconds of synthesized audio). Using only separate words is not recommended since this doesn't allow to adapt ASR model adapts to recognize new words in context. \n",
    "\n",
    "Mixing audio-text pairs with text-only data from the original domain is recommended to preserve performance on the original data. \n",
    "Also, fusing BatchNorm (see parameters below) is recommended for the best performance when using a large proportion of text compared to the amount of audio-text pairs in finetuning process.\n",
    "\n",
    "\n",
    "### Implementation Details and Experiments\n",
    "\n",
    "Further details about implementation and experiments can be found in the paper [Text-only domain adaptation for end-to-end ASR using integrated text-to-mel-spectrogram generator](https://arxiv.org/abs/2302.14036)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2702d081-c675-4a96-8263-6059e310d048",
   "metadata": {},
   "source": [
    "## Example: Finetuning ASR Model Using Text-Only Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30fe41a3-f36c-4803-a7f0-4260fb111478",
   "metadata": {},
   "source": [
    "In this example, we will finetune a pretrained small Conformer-CTC model using text-only data from the AN4 dataset. [AN4 dataset](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/datasets.html#an4-dataset) is a small dataset that consists of sentences of people spelling out addresses, names, and other entities.\n",
    "\n",
    "The model is pretrained on LibriSpeech data and performs poorly on AN4 data (`~17.7%` WER on test data).\n",
    "We will use only text from the train part to construct text-only training data for our model and will achieve a good performance on the test part of the AN4 dataset (`~2%` WER)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "923819bb-7822-412a-8f9b-98c76c70e0bb",
   "metadata": {},
   "source": [
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run the following cell to set up dependencies.\n",
    "\n",
    "NOTE: The user is responsible for checking the content of datasets and the applicable licenses and determining if they are suitable for the intended use."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4685a9da-b3f8-4b95-ba74-64a114223233",
   "metadata": {},
   "source": [
    "### Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d22d241-6c46-492c-99db-3bd69777243c",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import google.colab\n",
    "\n",
    "    IN_COLAB = True\n",
    "except (ImportError, ModuleNotFoundError):\n",
    "    IN_COLAB = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc38a961-8822-4685-89ae-ab6f591f9c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "BRANCH = 'main'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd60b1c4-7b1d-421d-9d63-95d7458bbcbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "\n",
    "if IN_COLAB:\n",
    "    ## Install dependencies\n",
    "    !pip install wget\n",
    "    !apt-get install sox libsndfile1 ffmpeg\n",
    "    !pip install text-unidecode\n",
    "\n",
    "    ## Install NeMo\n",
    "    !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08f99618-6f83-44b3-bc8e-f7df04fc471c",
   "metadata": {},
   "source": [
    "### Import necessary libraries and utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f780b1-9b72-4acf-bcf0-64e1ce84e76d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import string\n",
    "import tempfile\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "import lightning.pytorch as pl\n",
    "import torch\n",
    "from tqdm.auto import tqdm\n",
    "import wget\n",
    "\n",
    "from nemo.collections.asr.models import EncDecCTCModelBPE\n",
    "from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel\n",
    "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n",
    "from nemo.collections.tts.models import FastPitchModel, SpectrogramEnhancerModel\n",
    "from nemo.utils.notebook_utils import download_an4\n",
    "\n",
    "try:\n",
    "    from nemo_text_processing.text_normalization.normalize import Normalizer\n",
    "except ModuleNotFoundError:\n",
    "    raise ModuleNotFoundError(\n",
    "        \"The package `nemo_text_processing` was not installed in this environment. Please refer to\"\n",
    "        \" https://github.com/NVIDIA/NeMo-text-processing and install this package before using \"\n",
    "        \"this script\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca928d36-fb0d-439b-bac0-299e98a72d02",
   "metadata": {},
   "source": [
    "### Prepare Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "702e8e92-17b2-4f34-a2d9-c72b94501bf5",
   "metadata": {},
   "source": [
    "Download and preprocess AN4 data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62c7cfec-aa98-4fc5-8b31-23ee1d59f311",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASETS_DIR = Path(\"./datasets\")  # directory for data\n",
    "CHECKPOINTS_DIR = Path(\"./checkpoints/\")  # directory for checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "659db73e-dcd7-455c-8140-20e104d6ac00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create directories if necessary\n",
    "DATASETS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36830e7f-5293-4401-8c56-780127b47385",
   "metadata": {},
   "outputs": [],
   "source": [
    "download_an4(data_dir=f\"{DATASETS_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e77f5062-9acb-4f39-b811-a5b11dd6f76f",
   "metadata": {},
   "outputs": [],
   "source": [
    "AN4_DATASET = DATASETS_DIR / \"an4\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "403b63b0-8aab-43aa-a455-31f588d1772f",
   "metadata": {},
   "source": [
    "### Construct text-only training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35654ee1-3869-4289-bd52-15818c0ccf69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read original training data\n",
    "an4_train_data = read_manifest(AN4_DATASET / \"train_manifest.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a17f583c-2a5c-4faf-84bd-eb04c2921e01",
   "metadata": {},
   "source": [
    "Text-only manifest should contain three fields:\n",
    "- `text`: target text for the ASR model\n",
    "- `tts_text`: text to use as a source for the TTS model (unnormalized)\n",
    "- `tts_text_normalized`: text to use as a source for TTS model (normalized)\n",
    "\n",
    "If `tts_text_normalized` is not present, `tts_text` will be used, and normalization will be done when loading the dataset.\n",
    "It is highly recommended to normalize the text and manually create the `tts_text_normalized` field since current normalizers are unsuitable for processing a large amount of text on the fly."
   ]
  },
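  {
   "cell_type": "markdown",
   "id": "text-manifest-record-example-md",
   "metadata": {},
   "source": [
    "For illustration, each text-only manifest record is one JSON object per line. A minimal sketch is shown below; the values are hypothetical and only demonstrate the intended contents of each field."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "text-manifest-record-example-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical example of a single text-only manifest record (values are illustrative only)\n",
    "example_record = {\n",
    "    \"text\": \"it costs ten dollars\",  # target text for the ASR model\n",
    "    \"tts_text\": \"It costs $10.\",  # unnormalized source text for the TTS model\n",
    "    \"tts_text_normalized\": \"It costs ten dollars.\",  # normalized source text for the TTS model\n",
    "}\n",
    "print(example_record)"
   ]
  },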
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5938a8c2-e239-4a45-a716-dc11a981aec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fill `text` and `tts_text` fields with the source data\n",
    "textonly_data = []\n",
    "for record in an4_train_data:\n",
    "    text = record[\"text\"]\n",
    "    textonly_data.append({\"text\": text, \"tts_text\": text})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f6a5735-a5c2-4a8b-8116-bfc535a2c299",
   "metadata": {},
   "outputs": [],
   "source": [
    "WHITELIST_URL = (\n",
    "    \"https://raw.githubusercontent.com/NVIDIA/NeMo-text-processing/main/\"\n",
    "    \"nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv\"\n",
    ")\n",
    "\n",
    "\n",
    "def get_normalizer() -> Normalizer:\n",
    "    with tempfile.TemporaryDirectory() as data_dir:\n",
    "        whitelist_path = Path(data_dir) / \"lj_speech.tsv\"\n",
    "        if not whitelist_path.exists():\n",
    "            wget.download(WHITELIST_URL, out=str(data_dir))\n",
    "\n",
    "        normalizer = Normalizer(\n",
    "            lang=\"en\",\n",
    "            input_case=\"cased\",\n",
    "            whitelist=str(whitelist_path),\n",
    "            overwrite_cache=True,\n",
    "            cache_dir=None,\n",
    "        )\n",
    "    return normalizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd0253aa-d7f1-47ee-a142-099b71241270",
   "metadata": {},
   "source": [
    "Сonstruct the `tts_text_normalized` field by applying an English normalizer to the text.\n",
    "\n",
    "AN4 data doesn't contain numbers, currency, and other entities, so the normalizer is used here only for demonstration purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27bb29d5-d44d-4026-98f8-5f0b1241b39a",
   "metadata": {},
   "outputs": [],
   "source": [
    "normalizer = get_normalizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9400e6d3-ba92-442a-8dd4-117e95dce2ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "for record in tqdm(textonly_data):\n",
    "    record[\"tts_text_normalized\"] = normalizer.normalize(\n",
    "        record[\"tts_text\"], verbose=False, punct_pre_process=True, punct_post_process=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a934b0-9b58-4bad-bb9a-ab78d81c3859",
   "metadata": {},
   "source": [
    "Save manifest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1833ac15-1750-4468-88bc-2343fbabe4d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "write_manifest(AN4_DATASET / \"train_text_manifest.json\", textonly_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa3a2371-8c78-4dd1-9605-a668adf52b4a",
   "metadata": {},
   "source": [
    "### Save pretrained checkpoints"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eb14117-8b8b-4170-ab8c-ce496522a361",
   "metadata": {},
   "source": [
    "Firstly we will load pretrained models from NGC and save them as `nemo` checkpoints. \n",
    "Our hybrid model will be constructed from these checkpoints.\n",
    "We will use:\n",
    "- small Conformer-CTC ASR model trained on LibriSpeech data (for finetuning)\n",
    "- multi-speaker TTS FastPitch model is trained on LibriTTS data. Spectrogram parameters for this model are the same as those used in the ASR model\n",
    "- enhancer, which is trained adversarially on the output of the TTS model and natural spectrograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c5c75a-b6e0-4b3c-ad26-a07b483d84e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "ASR_MODEL_PATH = CHECKPOINTS_DIR / \"stt_en_conformer_ctc_small_ls.nemo\"\n",
    "TTS_MODEL_PATH = CHECKPOINTS_DIR / \"fastpitch.nemo\"\n",
    "ENHANCER_MODEL_PATH = CHECKPOINTS_DIR / \"enhancer.nemo\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40976e22-7a7b-42b2-86a1-9eaaef4c1c22",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# asr model: stt_en_conformer_ctc_small_ls\n",
    "asr_model = EncDecCTCModelBPE.from_pretrained(model_name=\"stt_en_conformer_ctc_small_ls\")\n",
    "asr_model.save_to(f\"{ASR_MODEL_PATH}\")\n",
    "\n",
    "# tts model: tts_en_fastpitch_for_asr_finetuning\n",
    "tts_model = FastPitchModel.from_pretrained(model_name=\"tts_en_fastpitch_for_asr_finetuning\")\n",
    "tts_model.save_to(f\"{TTS_MODEL_PATH}\")\n",
    "\n",
    "# enhancer model: tts_en_spectrogram_enhancer_for_asr_finetuning\n",
    "enhancer_model = SpectrogramEnhancerModel.from_pretrained(model_name=\"tts_en_spectrogram_enhancer_for_asr_finetuning\")\n",
    "enhancer_model.save_to(f\"{ENHANCER_MODEL_PATH}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32d1e242-0ab0-43bf-aaa0-997d284c2c1b",
   "metadata": {},
   "source": [
    "### Construct hybrid ASR-TTS model "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2210eb07-6d44-44e0-a0ad-866f1e89873a",
   "metadata": {},
   "source": [
    "#### Config Parameters\n",
    "\n",
    "`Hybrid ASR-TTS model` consists of three parts:\n",
    "\n",
    "* ASR model (``EncDecCTCModelBPE``, ``EncDecRNNTBPEModel`` or ``EncDecHybridRNNTCTCBPEModel``)\n",
    "* TTS Mel Spectrogram Generator (currently, only `FastPitch` model is supported)\n",
    "* Enhancer model (optional)\n",
    "\n",
    "Also, the config allows to specify a text-only dataset.\n",
    "\n",
    "Main parts of the config:\n",
    "\n",
    "* ASR model\n",
    "    * ``asr_model_path``: path to the ASR model checkpoint (`.nemo`) file, loaded only once, then the config of the ASR model is stored in the ``asr_model`` field\n",
    "    * ``asr_model_type``: needed only when training from scratch. ``rnnt_bpe`` corresponds to ``EncDecRNNTBPEModel``, ``ctc_bpe`` to ``EncDecCTCModelBPE``, ``hybrid_rnnt_ctc_bpe`` to ``EncDecHybridRNNTCTCBPEModel``\n",
    "    * ``asr_model_fuse_bn``: fusing BatchNorm in the pretrained ASR model, can improve quality in finetuning scenario\n",
    "* TTS model\n",
    "    * ``tts_model_path``: path to the pretrained TTS model checkpoint (`.nemo`) file, loaded only once, then the config of the model is stored in the ``tts_model`` field\n",
    "* Enhancer model\n",
    "    * ``enhancer_model_path``: optional path to the enhancer model. Loaded only once, the config is stored in the ``enhancer_model`` field\n",
    "* ``train_ds``\n",
    "    * ``text_data``: properties related to text-only data\n",
    "        * ``manifest_filepath``: path (or paths) to text-only dataset manifests\n",
    "        * ``speakers_filepath``: path (or paths) to the text file containing speaker ids for the multi-speaker TTS model (speakers are sampled randomly during training)\n",
    "        * ``min_words`` and ``max_words``: parameters to filter text-only manifests by the number of words\n",
    "        * ``tokenizer_workers``: number of workers for initial tokenization (when loading the data). ``num_CPUs / num_GPUs`` is a recommended value.\n",
    "    * ``asr_tts_sampling_technique``, ``asr_tts_sampling_temperature``, ``asr_tts_sampling_probabilities``: sampling parameters for text-only and audio-text data (if both specified). Correspond to ``sampling_technique``, ``sampling_temperature``, and ``sampling_probabilities`` parameters of the `nemo.collections.common.data.dataset.ConcatDataset`.\n",
    "    * all other components are similar to conventional ASR models\n",
    "* ``validation_ds`` and ``test_ds`` correspond to the underlying ASR model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6dd499-d388-4ee3-9a01-d739b16e6ad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load config\n",
    "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6701dc8-cb3b-44cc-aab5-fb6e2c1dadb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = OmegaConf.load(\"./configs/hybrid_asr_tts.yaml\")"
   ]
  },
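  {
   "cell_type": "markdown",
   "id": "inspect-text-data-config-md",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can print the part of the loaded config that describes the text-only dataset (it contains the `manifest_filepath`, `speakers_filepath`, and other fields listed above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inspect-text-data-config-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the text-only dataset section of the config before filling it in\n",
    "print(OmegaConf.to_yaml(config.model.train_ds.text_data))"
   ]
  },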
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c13b3c96-4074-415f-95d2-17569886bfcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d090c5d-44a7-401a-a753-b8779b1c1e0b",
   "metadata": {},
   "source": [
    "We will use all available speakers (sampled uniformly)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c41e5e8-d926-4b83-8725-bae5a82121cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "TTS_SPEAKERS_PATH = Path(\"./checkpoints/speakers.txt\")\n",
    "\n",
    "with open(TTS_SPEAKERS_PATH, \"w\", encoding=\"utf-8\") as f:\n",
    "    for speaker_id in range(tts_model.cfg.n_speakers):\n",
    "        print(speaker_id, file=f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c07c07c-cb15-4a1c-80bf-20eaffaa65d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "config.model.asr_model_path = ASR_MODEL_PATH\n",
    "config.model.tts_model_path = TTS_MODEL_PATH\n",
    "config.model.enhancer_model_path = ENHANCER_MODEL_PATH\n",
    "\n",
    "# fuse BathNorm automatically in Conformer for better performance\n",
    "config.model.asr_model_fuse_bn = True\n",
    "\n",
    "# training data\n",
    "# constructed dataset\n",
    "config.model.train_ds.text_data.manifest_filepath = str(AN4_DATASET / \"train_text_manifest.json\")\n",
    "# speakers for TTS model\n",
    "config.model.train_ds.text_data.speakers_filepath = f\"{TTS_SPEAKERS_PATH}\"\n",
    "config.model.train_ds.manifest_filepath = None  # audio-text pairs - we don't use them here\n",
    "config.model.train_ds.batch_size = 8\n",
    "\n",
    "# validation data\n",
    "config.model.validation_ds.manifest_filepath = str(AN4_DATASET / \"test_manifest.json\")\n",
    "config.model.validation_ds.batch_size = 8\n",
    "\n",
    "config.trainer.max_epochs = NUM_EPOCHS\n",
    "\n",
    "config.trainer.devices = 1\n",
    "config.trainer.strategy = 'auto'  # use 1 device, no need for ddp strategy\n",
    "\n",
    "OmegaConf.resolve(config)"
   ]
  },
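  {
   "cell_type": "markdown",
   "id": "mixing-audio-text-pairs-note",
   "metadata": {},
   "source": [
    "Above we set `model.train_ds.manifest_filepath` to `None` because this tutorial trains on text-only data. If you also have audio-text pairs (e.g., from the original domain), they can be mixed with the text-only data. A hedged sketch of the relevant settings is shown below; the path is a placeholder, and the sampling parameters follow `nemo.collections.common.data.dataset.ConcatDataset` as described in the config parameters section above.\n",
    "\n",
    "```python\n",
    "# hypothetical settings for mixing audio-text pairs with text-only data\n",
    "config.model.train_ds.manifest_filepath = \"<path to audio-text manifest>\"\n",
    "config.model.train_ds.asr_tts_sampling_technique = \"round-robin\"  # or \"random\" / \"temperature\"\n",
    "config.model.train_ds.asr_tts_sampling_temperature = 5  # used only with \"temperature\" sampling\n",
    "config.model.train_ds.asr_tts_sampling_probabilities = None  # used only with \"random\" sampling\n",
    "```"
   ]
  },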
  {
   "cell_type": "markdown",
   "id": "8ae6cb2e-f571-4b53-8897-bb8ba0fc1146",
   "metadata": {},
   "source": [
    "#### Construct trainer and ASRWithTTSModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac4ae885-dec4-4ce9-8f69-a1f35d04b08c",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(**config.trainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f815762-b08d-4d3c-8fd3-61afa511eab4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hybrid_model = ASRWithTTSModel(config.model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca2c1bf2-28a9-4902-9c73-d96e04b21a46",
   "metadata": {},
   "source": [
    "#### Validate the model\n",
    "\n",
    "Expect `~17.7%` WER on the AN4 test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffa5f5c6-0609-4f46-aa0c-747319035417",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.validate(hybrid_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "701ee9c7-91a1-4917-bf7d-ab26b625c7bf",
   "metadata": {},
   "source": [
    "#### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f79761c9-b882-4f14-911f-4a960ff81554",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer.fit(hybrid_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eac18c7c-bdcb-40ad-9c50-37f89fb4aa2a",
   "metadata": {},
   "source": [
    "#### Validate the model after training\n",
    "\n",
    "Expect `~2%` WER on the AN4 test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd927e87-13fb-4b61-8b4a-a6850780f605",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer.validate(hybrid_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d25a77d-35ed-44b5-9ef5-318afa321acf",
   "metadata": {},
   "source": [
    "### Save final model. Extract pure ASR model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f53ebd3-b89a-47e4-a0a5-ed3a3572f7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save full model: the model can be further used for finetuning\n",
    "hybrid_model.save_to(\"checkpoints/finetuned_hybrid_model.nemo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0560c2c-af28-4d8f-b36d-c18ec6a482a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract the resulting ASR model from the hybrid model\n",
    "hybrid_model.save_asr_model_to(\"checkpoints/finetuned_asr_model.nemo\")"
   ]
  },
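  {
   "cell_type": "markdown",
   "id": "use-extracted-asr-model-md",
   "metadata": {},
   "source": [
    "The extracted checkpoint is a regular ASR model and can be used without the TTS components. The following is a minimal usage sketch; the exact `transcribe` signature may differ between NeMo versions, and here we simply take a few audio files from the AN4 test manifest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "use-extracted-asr-model-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# restore the pure ASR model from the extracted checkpoint\n",
    "restored_asr_model = EncDecCTCModelBPE.restore_from(\"checkpoints/finetuned_asr_model.nemo\")\n",
    "\n",
    "# transcribe a few utterances from the AN4 test set\n",
    "test_records = read_manifest(AN4_DATASET / \"test_manifest.json\")\n",
    "audio_paths = [record[\"audio_filepath\"] for record in test_records[:3]]\n",
    "print(restored_asr_model.transcribe(audio_paths))"
   ]
  },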
  {
   "cell_type": "markdown",
   "id": "2de58fbb-50be-42cd-9095-01cacfdb6931",
   "metadata": {},
   "source": [
    "## Using Scripts (examples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86655198-b1fc-4615-958c-7c01f3cbd024",
   "metadata": {},
   "source": [
    "`<NeMo_git_root>/examples/asr/asr_with_tts/` contains scripts for finetuning existing models and training new models from scratch."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5837536-8280-475c-a581-caaee00edfca",
   "metadata": {},
   "source": [
    "### Finetuning Existing Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84df9aeb-3b5e-41fc-a8d0-dfc660e71375",
   "metadata": {},
   "source": [
    "To finetune existing ASR model using text-only data use `<NeMo_git_root>/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py` script with the corresponding config `<NeMo_git_root>/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml`.\n",
    "\n",
    "Please specify paths to all the required models (ASR, TTS, and Enhancer checkpoints), along with `train_ds.text_data.manifest_filepath` and `train_ds.text_data.speakers_filepath`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78b9028c-02ce-4af4-b510-a431f4a2f62b",
   "metadata": {},
   "source": [
    "```shell\n",
    "python speech_to_text_bpe_with_text_finetune.py \\\n",
    "    model.asr_model_path=<path to ASR model> \\\n",
    "    model.tts_model_path=<path to compatible TTS model> \\\n",
    "    model.enhancer_model_path=<optional path to enhancer model> \\\n",
    "    model.asr_model_fuse_bn=<true recommended if ConformerEncoder with BatchNorm, false otherwise> \\\n",
    "    model.train_ds.manifest_filepath=<path to manifest with audio-text pairs or null> \\\n",
    "    model.train_ds.text_data.manifest_filepath=<path(s) to manifest with train text> \\\n",
    "    model.train_ds.text_data.speakers_filepath=<path(s) to speakers list> \\\n",
    "    model.train_ds.text_data.tokenizer_workers=4 \\\n",
    "    model.validation_ds.manifest_filepath=<path to validation manifest> \\\n",
    "    model.train_ds.batch_size=<batch_size>\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b17c097-a3b1-49a3-8f54-f07b94218d0b",
   "metadata": {},
   "source": [
    "### Training a New Model from Scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d75b928-57b3-4180-bd09-37e018eef7ef",
   "metadata": {},
   "source": [
    "```shell\n",
    "python speech_to_text_bpe_with_text.py \\\n",
    "    # (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \\\n",
    "    ++asr_model_type=<rnnt_bpe, ctc_bpe or hybrid_rnnt_ctc_bpe> \\\n",
    "    ++tts_model_path=<path to compatible tts model> \\\n",
    "    ++enhancer_model_path=<optional path to enhancer model> \\\n",
    "    model.tokenizer.dir=<path to tokenizer> \\\n",
    "    model.tokenizer.type=\"bpe\" \\\n",
    "    model.train_ds.manifest_filepath=<path(s) to manifest with audio-text pairs or null> \\\n",
    "    ++model.train_ds.text_data.manifest_filepath=<path(s) to manifests with train text> \\\n",
    "    ++model.train_ds.text_data.speakers_filepath=<path(s) to speakers list> \\\n",
    "    ++model.train_ds.text_data.min_words=1 \\\n",
    "    ++model.train_ds.text_data.max_words=45 \\\n",
    "    ++model.train_ds.text_data.tokenizer_workers=4 \\\n",
    "    model.validation_ds.manifest_filepath=<path(s) to val/test manifest> \\\n",
    "    model.train_ds.batch_size=<batch size> \\\n",
    "    trainer.max_epochs=<num epochs> \\\n",
    "    trainer.num_nodes=<number of nodes> \\\n",
    "    trainer.accumulate_grad_batches=<grad accumultion> \\\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01c17712-ae8d-49cb-ade1-ded168676e27",
   "metadata": {},
   "source": [
    "## Training TTS Models for ASR Finetuning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "422dc3b2-d29f-4ed0-b4d2-6d32b35dfb7b",
   "metadata": {},
   "source": [
    "### TTS Model (FastPitch)\n",
    "\n",
    "TTS model for the purpose of ASR model finetuning should be trained with the same mel spectrogram parameters as used in the ASR model. The typical parameters are `10ms` hop length, `25ms` window length, and the highest band of 8kHz (for 16kHz data). Other parameters are the same as for common multi-speaker TTS models.\n",
    "\n",
    "Mainly we observed two differences specific to TTS models for ASR:\n",
    "- adding more speakers and more data improves the final ASR model quality (but not the perceptual quality of the TTS model)\n",
    "- training for more epochs can also improve the quality of the ASR system (but MSE loss used for the TTS model can be higher than optimal on validation data)\n",
    "\n",
    "Use script `<NeMo_git_root>/examples/tts/fastpitch.py` to train a FastPitch model.\n",
    "More details about the FastPitch model can be found in the [documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/tts/models.html#fastpitch). \n",
    "\n",
    "### Enhancer\n",
    "Use script `<NeMo_git_root>/examples/tts/spectrogram_enhancer.py` to train an Enhancer model. More details can be found in the \n",
    "[documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/tts/models.html).\n",
    "\n",
    "### Models Used in This Tutorial\n",
    "\n",
    "Some details about the models used in this tutorial can be found on [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning).\n",
    "\n",
    "The system is also described in detail in the paper in the paper [Text-only domain adaptation for end-to-end ASR using integrated text-to-mel-spectrogram generator](https://arxiv.org/abs/2302.14036)."
   ]
  },
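  {
   "cell_type": "markdown",
   "id": "spectrogram-params-check-md",
   "metadata": {},
   "source": [
    "As a small worked example, the typical frame parameters mentioned above can be converted to samples for 16kHz audio; these are the values the ASR and TTS preprocessors are expected to share."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "spectrogram-params-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert the typical spectrogram frame parameters to samples at a 16 kHz sampling rate\n",
    "sample_rate = 16000  # Hz\n",
    "hop_length_sec = 0.010  # 10 ms hop length\n",
    "window_length_sec = 0.025  # 25 ms window length\n",
    "\n",
    "print(f\"hop length: {int(sample_rate * hop_length_sec)} samples\")  # 160 samples\n",
    "print(f\"window length: {int(sample_rate * window_length_sec)} samples\")  # 400 samples\n",
    "print(f\"highest mel band: {sample_rate // 2} Hz\")  # 8000 Hz (Nyquist frequency)"
   ]
  },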
  {
   "cell_type": "markdown",
   "id": "9a9a6cd3-4bdc-4b6e-b4b1-3bfd50fd01b3",
   "metadata": {},
   "source": [
    "## Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2890c61-e4b7-47aa-a086-bc483ae7141f",
   "metadata": {},
   "source": [
    "The tutorial demonstrated the main concepts related to hybrid ASR-TTS models to finetune ASR models and train new ones from scratch. \n",
    "The ability to achieve good text-only adaptation results is demonstrated by finetuning a small Conformer model on text-only data from the AN4 dataset."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml38",
   "language": "python",
   "name": "ml38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}