{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "17b3cbf8",
   "metadata": {},
   "source": [
    "# Context-Biasing for ASR models with CTC-based Word Spotter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1156d1d1",
   "metadata": {},
   "source": [
    "This tutorial aims to show how to improve the recognition accuracy of specific words in NeMo Framework\n",
    "for CTC and Trasducer (RNN-T) ASR models by using the fast context-biasing method with CTC-based Word Spotter.\n",
    "\n",
    "## Tutorial content:\n",
    "* Intro in the context-biasing problem\n",
    "* Description of the proposed CTC-based Words Spotter (CTC-WS) method\n",
    "* Practical part 1 (base):\n",
    "    * Download data set and ASR models\n",
    "    * Build context-biasing list\n",
    "    * Evaluate recognition results with and without context-biasing\n",
    "    * Improve context-biasing results with alternative transcriptions\n",
    "* Practical part 2 (advanced):\n",
    "    * Visualization of context-biasing graph\n",
    "    * Running CTC-based Word Spotter only\n",
    "    * Merge greedy decoding results with spotted context-biasing candidates\n",
    "    * Results analysis\n",
    "* Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "431edfbf",
   "metadata": {},
   "source": [
    "## Context-biasing: intro\n",
    "\n",
    "ASR models often struggle to recognize words that were absent or had few examples in the training data.\n",
    "This problem is especially acute due to the emergence of new names and titles in a rapidly developing world.\n",
    "The users need to be able to recognize these new words.\n",
    "Context-biasing methods attempt to solve this problem by assuming that we have a list of words and phrases (context-biasing list) in advance\n",
    "for which we want to improve recognition accuracy.\n",
    "\n",
    "One of the directions of context-biasing methods is based on the `deep fusion` approach.\n",
    "These methods require intervention into the ASR model and its training process.\n",
    "The main disadvantage of these methods is that they require a lot of computational resources and time to train the model.\n",
    "\n",
    "Another direction is methods based on the `shallow fusion` approach. In this case, only the decoding process is modified.\n",
    "During the beam-search decoding, the hypothesis is rescored depending on whether the current word is present in the context-biasing list or external language model.\n",
    "The beam-search decoding may be computationally expensive, especially for the models with a large vocabulary and context-biasing list.\n",
    "This problem is considerably worsened in the case of the Transducer (RNN-T) model since beam-search decoding involves multiple Decoder (Prediction) and Joint networks calculations.\n",
    "Moreover, the context-biasing recognition is limited by the model prediction pool biased toward training data. In the case of rare or new words, the model may not have a hypothesis for the desired word from the context-biasing list whose probability we want to amplify."
   ]
  },
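  {
   "cell_type": "markdown",
   "id": "7f3a1c2e",
   "metadata": {},
   "source": [
    "The toy sketch below (our illustration, not the NeMo implementation) shows the shallow-fusion idea from the paragraph above: during beam search, a hypothesis score receives a bonus when its last word occurs in the context-biasing list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e2b4d5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy illustration of shallow-fusion rescoring (a sketch, not the NeMo API):\n",
    "# a hypothesis gets a score bonus if its last word is in the context-biasing list\n",
    "def rescore(hyp_score: float, last_word: str, cb_list: set, cb_bonus: float = 3.0) -> float:\n",
    "    return hyp_score + (cb_bonus if last_word in cb_list else 0.0)\n",
    "\n",
    "toy_cb_list = {\"nvidia\", \"gpu\"}\n",
    "print(rescore(-4.5, \"nvidia\", toy_cb_list))  # biased word: -4.5 + 3.0 = -1.5\n",
    "print(rescore(-4.5, \"video\", toy_cb_list))   # regular word: -4.5"
   ]
  },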
  {
   "cell_type": "markdown",
   "id": "ae0bfd60",
   "metadata": {},
   "source": [
    "## CTC-based Word Spotter\n",
    "\n",
    "\n",
    "This tutorial considers a fast context-biasing method using a CTC-based Word Spotter (CTC-WS).\n",
    "The method involves decoding CTC log probabilities with a context graph built for words and phrases from the context-biasing list.\n",
    "The spotted context-biasing candidates (with their scores and time intervals) are compared by scores with words from the greedy\n",
    "CTC decoding results to improve recognition accuracy and pretend false accepts of context-biasing (Figure 1).  \n",
    "  \n",
    "  \n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7e45bf2",
   "metadata": {},
   "source": [
    "<figure markdown>\n",
    "  <img src=\"https://github.com/NVIDIA/NeMo/releases/download/v1.22.0/asset-post-v1.22.0-ctcws_scheme_2.png\" alt=\"CTC-WS\" style=\"width: 60%;\" height=\"auto\"> <!-- Adjust the width as needed -->\n",
    "  <figcaption><b>Figure 1.</b> <i> High-level representation of the proposed context-biasing method with CTC-WS in case of CTC model. Detected words (gpu, nvidia, cuda) are compared with words from the greedy CTC results in the overlapping intervals according to the accumulated scores to prevent false accept replacement. </i></figcaption>\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba163f41",
   "metadata": {},
   "source": [
    "\n",
    "<!-- <img width=\"500px\" height=\"auto\"\n",
    "     src=\"https://github.com/NVIDIA/NeMo/releases/download/v1.22.0/asset-post-v1.22.0-ctcws_scheme_2.png\"\n",
    "     alt=\"CTC-WS2\"\n",
    "     style=\"float: right; margin-left: 20px;\"> -->\n",
    "     \n",
    "A [Hybrid Transducer-CTC](https://arxiv.org/abs/2312.17279) model (a shared encoder trained together with CTC and Transducer output heads) enables the use of the CTC-WS method for the Transducer model.\n",
    "Context-biasing candidates obtained by CTC-WS are also filtered by the scores with greedy CTC predictions and then merged with greedy Transducer results.\n",
    "\n",
    "The CTC-WS method allows using pretrained NeMo models (`CTC` or `Hybrid Transducer-CTC`) for context-biasing recognition without model retraining (Figure 2).\n",
    "The method shows inspired results for context-biasing with only a little additional work time and computational resources.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c05b16d8",
   "metadata": {},
   "source": [
    "<figure markdown>\n",
    "  <img src=\"https://github.com/NVIDIA/NeMo/releases/download/v1.22.0/asset-post-v1.22.0-ctcws_scheme_1.png\" alt=\"CTC-WS\" style=\"width: 65%;\" align=\"center\"> <!-- Adjust the width as needed -->\n",
    "  <figcaption><b>Figure 2.</b> <i> Scheme of the context-biasing method with CTC-based Word Spotter. CTC-WS uses CTC log probabilities to detect context-biasing candidates. Obtained candidates are filtered by CTC word alignment and then merged with CTC or RNN-T word alignment to get the final text result. </i></figcaption>\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac0ec822",
   "metadata": {},
   "source": [
    "# Installing dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c86a4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "BRANCH = 'main'\n",
    "\n",
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "# either provide a path to local NeMo repository with NeMo already installed or git clone\n",
    "\n",
    "# option #1: local path to NeMo repo with NeMo already installed\n",
    "NEMO_DIR_PATH = os.path.dirname(os.path.dirname(os.path.abspath('')))\n",
    "\n",
    "# check if Google Colab is being used\n",
    "try:\n",
    "    import google.colab\n",
    "    IN_COLAB = True\n",
    "except (ImportError, ModuleNotFoundError):\n",
    "    IN_COLAB = False\n",
    "\n",
    "# option #2: download NeMo repo\n",
    "if IN_COLAB or not os.path.exists(os.path.join(NEMO_DIR_PATH, \"nemo\")):\n",
    "    ## Install dependencies\n",
    "    !apt-get install sox libsndfile1 ffmpeg\n",
    "\n",
    "    !git clone -b $BRANCH https://github.com/NVIDIA/NeMo\n",
    "    %cd NeMo\n",
    "    !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
    "    NEMO_DIR_PATH = os.path.abspath('')\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, NEMO_DIR_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5260d4fa",
   "metadata": {},
   "source": [
    "## Practical part 1 (base)\n",
    "In this part, we will consider the base usage of the CTC-WS method for pretrained NeMo models.\n",
    "\n",
    "### Data preparation.\n",
    "We will use a subset of the GTC data set. The data set contains 10 audio files with NVIDIA GTC talks. \n",
    "The primary data set feature is the computer science and engineering domain, which has a large number of unique terms and product names (NVIDIA, GPU, GeForce, Ray Tracing, Omniverse, teraflops, etc.), which is good fit for the context-biasing task. All the text data is normalized and lowercased."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "637f2c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download data\n",
    "!wget https://asr-tutorial-data.s3.eu-north-1.amazonaws.com/context_biasing_data.gz\n",
    "!tar -xvzf context_biasing_data.gz\n",
    "!apt-get update && apt-get upgrade -y && apt-get install tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6baefc80",
   "metadata": {},
   "outputs": [],
   "source": [
    "!tree context_biasing_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09fe748b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest\n",
    "\n",
    "# data is already stored in nemo data manifest format\n",
    "test_nemo_manifest = \"./context_biasing_data/gtc_data_subset_10f.json\"\n",
    "test_data = read_manifest(test_nemo_manifest)\n",
    "\n",
    "for idx, item in enumerate(test_data):\n",
    "    print(f\"[{idx}]: {item['text']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64ab4764",
   "metadata": {},
   "outputs": [],
   "source": [
    "import librosa\n",
    "import IPython.display as ipd\n",
    "\n",
    "# load and listen to the audio file example\n",
    "example_file = test_data[0]['audio_filepath']\n",
    "audio, sample_rate = librosa.load(example_file)\n",
    "\n",
    "file_id = 0\n",
    "print(f\"[TEXT {file_id}]: {test_data[file_id]['text']}\\n\")\n",
    "ipd.Audio(example_file, rate=sample_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a85ea8ec",
   "metadata": {},
   "source": [
    "### Load ASR models\n",
    "\n",
    "For testing the CTC-WS method, we will use the following NeMo models:\n",
    " - (CTC): [stt_en_fastconformer_ctc_large](https://huggingface.co/nvidia/stt_en_fastconformer_ctc_large) - a large fast-conformer model trained on English ASR data\n",
    " - (Hybrid Transducer-CTC): [stt_en_fastconformer_hybrid_large_streaming_multi](https://huggingface.co/nvidia/stt_en_fastconformer_hybrid_large_streaming_multi) - a large fast-conformer model trained jointly with CTC and Transducer heads on English ASR data. The model is streaming, which means it can process audio in real time. It can cause a slight WER degradation in comparison with the first offline model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d34ee0ba",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from nemo.collections.asr.models import EncDecCTCModelBPE, EncDecHybridRNNTCTCBPEModel\n",
    "\n",
    "# ctc model\n",
    "ctc_model_name = \"stt_en_fastconformer_ctc_large\"\n",
    "ctc_model = EncDecCTCModelBPE.from_pretrained(model_name=ctc_model_name)\n",
    "\n",
    "# hybrid transducer-ctc model\n",
    "hybrid_ctc_rnnt_model_name = \"stt_en_fastconformer_hybrid_large_streaming_multi\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "082208cd",
   "metadata": {},
   "source": [
    "### Transcribe \n",
    "Let's transcribe test data and analyze the regontion accuracy of specific words "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74436885",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "test_audio_files = [item['audio_filepath'] for item in test_data]\n",
    "recog_results = ctc_model.transcribe(test_audio_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b993d650",
   "metadata": {},
   "source": [
    "### Compute per-word recognition statisctic"
   ]
  },
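  {
   "cell_type": "markdown",
   "id": "3c9d7e21",
   "metadata": {},
   "source": [
    "The statistics below rely on the `texterrors` package for word-level alignment. It may not be installed together with NeMo, so install it first if needed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d8f0a34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install texterrors for word-level alignment (skip if already installed)\n",
    "!pip install texterrors"
   ]
  },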
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70f5714b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import texterrors\n",
    "\n",
    "word_dict = {} # {word: [num_of_occurances, num_of_correct_recognition]}\n",
    "eps = \"<eps>\"\n",
    "ref_text = [item['text'] for item in test_data]\n",
    "\n",
    "for idx, ref in enumerate(ref_text):\n",
    "    ref = ref.split()\n",
    "    hyp = recog_results[idx].text.split()\n",
    "    texterrors_ali = texterrors.align_texts(ref, hyp, False)\n",
    "    ali = []\n",
    "    for i in range(len(texterrors_ali[0])):\n",
    "        ali.append((texterrors_ali[0][i], texterrors_ali[1][i]))\n",
    "\n",
    "    for pair in ali:\n",
    "        word_ref, word_hyp = pair\n",
    "        if word_ref == eps:\n",
    "            continue\n",
    "        if word_ref in word_dict:\n",
    "            word_dict[word_ref][0] += 1\n",
    "        else:\n",
    "            word_dict[word_ref] = [1, 0]\n",
    "        if word_ref == word_hyp:\n",
    "            word_dict[word_ref][1] += 1\n",
    "\n",
    "word_candidats = {}\n",
    "\n",
    "for word in word_dict:\n",
    "    gt = word_dict[word][0]\n",
    "    tp = word_dict[word][1]\n",
    "    if tp/gt < 1.0:\n",
    "        word_candidats[word] = [gt, round(tp/gt, 2)]\n",
    "        \n",
    "# print obtained per-word statistic\n",
    "word_candidats_sorted = sorted(word_candidats.items(), key=lambda x:x[1][0], reverse=True)\n",
    "max_word_len = max([len(x[0]) for x in word_candidats_sorted])\n",
    "for item in word_candidats_sorted:\n",
    "    print(f\"{item[0]:<{max_word_len}} {item[1][0]}/{item[1][1]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27a9f88b",
   "metadata": {},
   "source": [
    "## Create a context-biasing list\n",
    "\n",
    "Now, we need to select the words, recognition of which we want to improve by CTC-WS context-biasing.\n",
    "Usually, we select only nontrivial words with the lowest recognition accuracy.\n",
    "Such words should have a character length >= 3 because short words in a context-biasing list may produce high false-positive recognition.\n",
    "In this toy example, we will select all the words that look like names with a recognition accuracy less than 1.0.\n",
    "\n",
    "The structure of the context-biasing file is:\n",
    "\n",
    "WORD1_TRANSCRIPTION1  \n",
    "WORD2_TRANSCRIPTION1   \n",
    "...\n",
    "\n",
    "TRANSCRIPTION here is a word spelling. We need this structure to add alternative transcriptions (spellings) for some word. We will cover such a case further."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c27848f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "cb_words = [\"gpu\", \"nvidia\", \"nvidia's\", \"nvlink\", \"omniverse\", \"cunumeric\", \"numpy\", \"dgx\", \"dgxs\", \"dlss\",\n",
    "            \"cpu\", \"tsmc\", \"culitho\", \"xlabs\", \"tensorrt\", \"tensorflow\", \"pytorch\", \"aws\", \"chatgpt\", \"pcie\"]\n",
    "\n",
    "# write context-biasing file \n",
    "cb_list_file = \"context_biasing_data/context_biasing_list.txt\"\n",
    "with open(cb_list_file, \"w\", encoding=\"utf-8\") as fn:\n",
    "    for word in cb_words:\n",
    "        fn.write(f\"{word}_{word}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0e8c800",
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat {cb_list_file}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c44fc910",
   "metadata": {},
   "source": [
    "## Run context-biasing evaluation\n",
    "\n",
    "The main script for CTC-WS context-biasing in NeMo is:\\\n",
    "`{NEMO_DIR_PATH}/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py`\n",
    "\n",
    "Context-biasing is managed by `apply_context_biasing` parameter [true or false].  \n",
    "Other important context-biasing parameters are:\n",
    "- `beam_threshold` - threshold for CTC-WS beam pruning\n",
    "- `context_score` - per token weight for context biasing\n",
    "- `ctc_ali_token_weight` - per token weight for CTC alignment (prevents false acceptances of context-biasing words) \n",
    "\n",
    "All the context-biasing parameters are selected according to the default values in the script.  \n",
    "You can tune them according to your data and ASR model (list all the values in the [] separated by commas)  \n",
    "for example: `beam_threshold=[7.0,8.0,9.0]`, `context_score=[3.0,4.0,5.0]`, `ctc_ali_token_weight=[0.5,0.6,0.7]`.  \n",
    "The script will run the recognition with all the combinations of the parameters and will select the best one based on WER value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a2d32e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create directory with experimental results\n",
    "import os\n",
    "\n",
    "exp_dir = \"exp\"\n",
    "if not os.path.isdir(exp_dir):\n",
    "    os.makedirs(exp_dir)\n",
    "else:\n",
    "    print(f\"Directory '{exp_dir}' already exists\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "116f2abe",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ctc model (no context-biasing)\n",
    "\n",
    "!python {NEMO_DIR_PATH}/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \\\n",
    "            nemo_model_file={ctc_model_name} \\\n",
    "            input_manifest={test_nemo_manifest} \\\n",
    "            preds_output_folder={exp_dir} \\\n",
    "            decoder_type=\"ctc\" \\\n",
    "            acoustic_batch_size=64 \\\n",
    "            apply_context_biasing=false \\\n",
    "            context_file={cb_list_file} \\\n",
    "            beam_threshold=[7.0] \\\n",
    "            context_score=[3.0] \\\n",
    "            ctc_ali_token_weight=[0.5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "674d0af1",
   "metadata": {},
   "source": [
    "The results must be:\n",
    "\n",
    "`Precision`: 1.0000 (1/1) fp:0 (fp - false positive recognition)  \n",
    "`Recall`:    0.0333 (1/30)  \n",
    "`Fscore`:    0.0645  \n",
    "`Greedy WER/CER` = 35.68%/8.16%\n",
    "\n",
    "The model could recognize 1 out of 30 words from the context-biasing list.\n",
    "Let's enable context-biasing during decoding:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "239da41d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ctc model (with context biasing)\n",
    "!python {NEMO_DIR_PATH}/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \\\n",
    "            nemo_model_file={ctc_model_name} \\\n",
    "            input_manifest={test_nemo_manifest} \\\n",
    "            preds_output_folder={exp_dir} \\\n",
    "            decoder_type=\"ctc\" \\\n",
    "            acoustic_batch_size=64 \\\n",
    "            apply_context_biasing=true \\\n",
    "            context_file={cb_list_file} \\\n",
    "            beam_threshold=[7.0] \\\n",
    "            context_score=[3.0] \\\n",
    "            ctc_ali_token_weight=[0.5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faa1e73c",
   "metadata": {},
   "source": [
    "Now, recognition results are much better:\n",
    "\n",
    "`Precision`: 1.0000 (21/21) fp:0  \n",
    "`Recall`:    0.7000 (21/30)  \n",
    "`Fscore`:    0.8235  \n",
    "`Greedy WER/CER` = 17.09%/4.43%\n",
    "\n",
    "But we are still able to recognize only 21 out of 30 specific words.\\\n",
    "You can see that unrecognized words are mostly abbreviations (`dgxs`, `dlss`, `gpu`, `aws`, etc.) or compound words (`culitho`).\\\n",
    "The ASR models tends to recognize such words as a sequence of characters (`\"aws\" -> \"a w s\"`) or subwords (`\"culitho\" -> \"cu litho\"`).\\\n",
    "We can try to improve the recognition of such words by adding alternative transcriptions to the context-biasing list."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d72b6391",
   "metadata": {},
   "source": [
    "### Alternative transcriptions\n",
    "\n",
    "wordninja is used to split compound words into simple words according to the default word dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e00263",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install wordninja"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46fe91e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wordninja\n",
    "\n",
    "cb_list_file_modified = cb_list_file + \".abbr_and_ninja\"\n",
    "\n",
    "with open(cb_list_file, \"r\", encoding=\"utf-8\") as fn1, \\\n",
    "    open(cb_list_file_modified, \"w\", encoding=\"utf-8\") as fn2:\n",
    "\n",
    "    for line in fn1:\n",
    "        word = line.strip().split(\"_\")[0]\n",
    "        new_line = f\"{word}_{word}\"\n",
    "        # split all the short words into characters\n",
    "        if len(word) <= 4 and len(word.split()) == 1:\n",
    "            abbr = ' '.join(list(word))\n",
    "            new_line += f\"_{abbr}\"\n",
    "        # split the long words into the simple words (not for phrases)\n",
    "        new_segmentation = wordninja.split(word)\n",
    "        if word != new_segmentation[0]:\n",
    "            new_segmentation = ' '.join(new_segmentation)\n",
    "            new_line += f\"_{new_segmentation}\"\n",
    "        fn2.write(f\"{new_line}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da69da45",
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat {cb_list_file_modified}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a21cbf4",
   "metadata": {},
   "source": [
    "Run context-biasing with modified context-biasing list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913a0f5e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ctc models (with context biasing and modified cb list)\n",
    "!python {NEMO_DIR_PATH}/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \\\n",
    "            nemo_model_file={ctc_model_name} \\\n",
    "            input_manifest={test_nemo_manifest} \\\n",
    "            preds_output_folder={exp_dir} \\\n",
    "            decoder_type=\"ctc\" \\\n",
    "            acoustic_batch_size=64 \\\n",
    "            apply_context_biasing=true \\\n",
    "            context_file={cb_list_file_modified} \\\n",
    "            beam_threshold=[7.0] \\\n",
    "            context_score=[3.0] \\\n",
    "            ctc_ali_token_weight=[0.5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "654751ed",
   "metadata": {},
   "source": [
    "Now, the recognition results are:\n",
    "\n",
    "`Precision`: 1.0000 (28/28) fp:1  \n",
    "`Recall`:    0.9333 (28/30)  \n",
    "`Fscore`:    0.9655  \n",
    "`Greedy WER/CER` = 7.04%/2.93%\n",
    "\n",
    "As you can see, that adding alternative transcriptions to the cb_list file improved the recognition accuracy of the context-biasing words. However, we still miss 2 words. Unfortunately, this algorithm is not a silver bullet.\n",
    "\n",
    "In some cases, you can improve results by adding alternative transcriptions manually based on the recognition errors of your ASR model for the specific words (for example, `\"nvidia\" -> \"n video\"`). "
   ]
  },
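  {
   "cell_type": "markdown",
   "id": "4d7e9f67",
   "metadata": {},
   "source": [
    "A minimal sketch of such a manual edit follows. It extends the `nvidia` line with the hypothetical extra spelling `n video` (pick spellings based on your model's actual errors) and writes to a separate file so the cells below are unaffected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c3b5a78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# manually extend a word's line with one more alternative transcription;\n",
    "# write the result to a separate file so later cells keep using the original list\n",
    "cb_list_file_manual = cb_list_file_modified + \".manual\"\n",
    "\n",
    "with open(cb_list_file_modified, \"r\", encoding=\"utf-8\") as fn1, \\\n",
    "    open(cb_list_file_manual, \"w\", encoding=\"utf-8\") as fn2:\n",
    "    for line in fn1.read().splitlines():\n",
    "        if line.split(\"_\")[0] == \"nvidia\":\n",
    "            line += \"_n video\"  # hypothetical extra spelling\n",
    "        fn2.write(line + \"\\n\")"
   ]
  },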
  {
   "cell_type": "markdown",
   "id": "b96c4023",
   "metadata": {},
   "source": [
    "### Hybrid Transducer-CTC model\n",
    "The CTC-WS context-biasing method for Transducer (RNN-T) models is supported only for Hybrid Transducer-CTC model.  \n",
    "To use Transducer head of the Hybrid Transducer-CTC model, we need to set `decoder_type=\"rnnt\"`.  \n",
    "Other parameters are the same as for the CTC model because the context-biasing is applied only on the CTC part of the model. Spotted context-biasing words will have been merged with greedy decoding results of the Transducer head.\n",
    "\n",
    "We can use already prepared context-biasing list because the CTC and Hybrid Transducer-CTC models have almost the same BPE tokenizer."
   ]
  },
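  {
   "cell_type": "markdown",
   "id": "9a4e6b12",
   "metadata": {},
   "source": [
    "As an optional sanity check (it downloads and loads the hybrid model, which the evaluation script otherwise does internally), you can compare how the two BPE tokenizers segment the context-biasing words:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b5c7d23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: compare BPE segmentations of the context-biasing words\n",
    "hybrid_model = EncDecHybridRNNTCTCBPEModel.from_pretrained(model_name=hybrid_ctc_rnnt_model_name)\n",
    "\n",
    "for word in cb_words:\n",
    "    ctc_ids = ctc_model.tokenizer.text_to_ids(word)\n",
    "    hybrid_ids = hybrid_model.tokenizer.text_to_ids(word)\n",
    "    print(f\"{word}: ctc={ctc_ids} hybrid={hybrid_ids} match={ctc_ids == hybrid_ids}\")"
   ]
  },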
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "456e47df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Transducer model (no context-biasing)\n",
    "!python {NEMO_DIR_PATH}/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \\\n",
    "            nemo_model_file={hybrid_ctc_rnnt_model_name} \\\n",
    "            input_manifest={test_nemo_manifest} \\\n",
    "            preds_output_folder={exp_dir} \\\n",
    "            decoder_type=\"rnnt\" \\\n",
    "            acoustic_batch_size=64 \\\n",
    "            apply_context_biasing=false \\\n",
    "            context_file={cb_list_file_modified} \\\n",
    "            beam_threshold=[7.0] \\\n",
    "            context_score=[3.0] \\\n",
    "            ctc_ali_token_weight=[0.5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "773e11f1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Transducer model (with context-biasing)\n",
    "!python {NEMO_DIR_PATH}/scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py \\\n",
    "            nemo_model_file={hybrid_ctc_rnnt_model_name} \\\n",
    "            input_manifest={test_nemo_manifest} \\\n",
    "            preds_output_folder={exp_dir} \\\n",
    "            decoder_type=\"rnnt\" \\\n",
    "            acoustic_batch_size=64 \\\n",
    "            apply_context_biasing=true \\\n",
    "            context_file={cb_list_file_modified} \\\n",
    "            beam_threshold=[7.0] \\\n",
    "            context_score=[3.0] \\\n",
    "            ctc_ali_token_weight=[0.5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45a91385",
   "metadata": {},
   "source": [
    "CTC-WS context-biasing works for Transducer model as well as for CTC (`F-score improvenment: 0.3784 -> 0.9286`). Differences in the nature of offline and online models may cause differences in results (usually, online models have a tendency to predict tokens earlier what can affect the difference between the timestamps of CTC and RNN-T models). "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1968e7bc",
   "metadata": {},
   "source": [
    "## Practical part 2 (advanced)\n",
    "In this section, we will consider the context-biasing process more deeply:\n",
    "- Visualization of the context-biasing graph\n",
    "- Running CTC-WS with the context-biasing graph\n",
    "- Merge the obtained spotted words with greedy decoding results\n",
    "- Analysis of the results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277104b5",
   "metadata": {},
   "source": [
    "### Build a context graph (for visualization only)\n",
    "The context graph consists of a composition of a prefix tree (Trie) with the CTC transition topology for words and phrases from the context-biasing list. We use a BPE tokenizer from the target ASR model for word segmentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55a36a27-919c-4d64-9163-b0b2c9dca15e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install graphviz from source in case of local run (not Google Colab)\n",
    "# this may take about 5-10 minutes\n",
    "# make sure that env variables have been set\n",
    "\n",
    "if not IN_COLAB:\n",
    "\n",
    "    os.environ['DEBIAN_FRONTEND'] = 'noninteractive'\n",
    "    os.environ['TZ'] = 'Etc/UTC'\n",
    "\n",
    "    !echo $DEBIAN_FRONTEND\n",
    "    !echo $TZ\n",
    "\n",
    "    !{NEMO_DIR_PATH}/scripts/installers/install_graphviz.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "904ea41b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts import context_biasing\n",
    "\n",
    "# get bpe tokenization\n",
    "cb_words_small = [\"nvidia\", \"gpu\", \"nvlink\", \"numpy\"]\n",
    "context_transcripts = []\n",
    "for word in cb_words_small:\n",
    "    # use text_to_tokens method for viasualization only\n",
    "    word_tokenization = ctc_model.tokenizer.text_to_tokens(word)\n",
    "    print(f\"{word}: {word_tokenization}\")\n",
    "    context_transcripts.append([word, [word_tokenization]])\n",
    "\n",
    "# build context graph\n",
    "context_graph = context_biasing.ContextGraphCTC(blank_id=\"⊘\")\n",
    "context_graph.add_to_graph(context_transcripts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7fab1e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_graph.draw()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04a6f4be",
   "metadata": {},
   "source": [
    "### Build a real context graph (for decoding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ba2d8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get bpe tokenization\n",
    "context_transcripts = []\n",
    "for word in cb_words:\n",
    "    word_tokenization = [ctc_model.tokenizer.text_to_ids(x) for x in word]\n",
    "    context_transcripts.append([word, word_tokenization])\n",
    "\n",
    "# build context graph\n",
    "context_graph = context_biasing.ContextGraphCTC(blank_id=ctc_model.decoding.blank_id)\n",
    "context_graph.add_to_graph(context_transcripts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71e0e86b",
   "metadata": {},
   "source": [
    "### Run CTC-based Word Spotter\n",
    "\n",
    "The CTC-WS task is to search for words by decoding CTC log probabilities using the context graph. As a result, we obtain a list of detected words with exact start/end frames in the audio file and their overall scores. The relatively small size of the context graph and hypotheses pruning methods allow this algorithm to work very quickly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2c942e-e8df-4c09-a7de-87606ae453c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade jupyter ipywidgets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2bc370b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# get ctc logprobs\n",
    "audio_file_paths = [item['audio_filepath'] for item in test_data]\n",
    "\n",
    "with torch.no_grad():\n",
    "    ctc_model.eval()\n",
    "    ctc_model.encoder.freeze()\n",
    "    device = next(ctc_model.parameters()).device\n",
    "    hyp_results = ctc_model.transcribe(audio_file_paths, batch_size=10, return_hypotheses=True)\n",
    "    ctc_logprobs = [hyp.alignments.cpu().numpy() for hyp in hyp_results]\n",
    "    blank_idx = ctc_model.decoding.blank_id\n",
    "    \n",
    "# run ctc-based word spotter\n",
    "ws_results = {}\n",
    "for idx, logits in tqdm(\n",
    "    enumerate(ctc_logprobs), desc=f\"Eval CTC-based Word Spotter...\", total=len(ctc_logprobs)\n",
    "):\n",
    "    ws_results[audio_file_paths[idx]] = context_biasing.run_word_spotter(\n",
    "        logits,\n",
    "        context_graph,\n",
    "        ctc_model,\n",
    "        blank_idx=blank_idx,\n",
    "        beam_threshold=7.0,\n",
    "        cb_weight=3.0,\n",
    "        ctc_ali_token_weight=0.5,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bd6645c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print CTC-WS hypotheses for the first audio file\n",
    "ws_results[audio_file_paths[0]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "245a66f0",
   "metadata": {},
   "source": [
    "### Merge CTC-WS words with greedy CTC decoding results\n",
    "\n",
    "Use `print_stats=True` to get more information about spotted words and greedy CTC word alignment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "423b2b9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "target_transcripts = [item['text'] for item in test_data]\n",
    "\n",
    "# merge spotted words with greedy results\n",
    "for idx, logprobs in enumerate(ctc_logprobs):\n",
    "    greedy_predicts = np.argmax(logprobs, axis=1)\n",
    "    if ws_results[audio_file_paths[idx]]:\n",
    "        # make new text by mearging alignment with ctc-ws predictions:\n",
    "        print(\"\\n\" + \"********\" * 10)\n",
    "        print(f\"File name: {audio_file_paths[idx]}\")\n",
    "        pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps(\n",
    "            greedy_predicts,\n",
    "            ctc_model,\n",
    "            ws_results[audio_file_paths[idx]],\n",
    "            decoder_type=\"ctc\",\n",
    "            blank_idx=blank_idx,\n",
    "            print_stats=True,\n",
    "        )\n",
    "        print(f\"[raw text]: {raw_text}\")\n",
    "        print(f\"[hyp text]: {pred_text}\")\n",
    "        print(f\"[ref text]: {target_transcripts[idx]}\")\n",
    "    else:\n",
    "        # if no spotted words, use standard greedy predictions\n",
    "        pred_text = ctc_model.wer.decoding.ctc_decoder_predictions_tensor(greedy_predicts)[0].text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb8b5f51",
   "metadata": {},
   "source": [
    "In these logs, you can find detailed context-biasing statistics about each audio file:\n",
    "- Audio file name\n",
    "- Greedy word alignment\n",
    "- List of spotted words\n",
    "- Text results:\n",
    "    - Greedy decoding (raw text)\n",
    "    - Text after applying context-biasing (hyp text)\n",
    "    - Ground truth transcription (ref text)\n",
    "    \n",
    "These statistics can be helpful in case of problems with context-biasing word recognition. For example, Transducer models sometimes recognize tokens 1-2 frames earlier than CTC models. To solve this problem, you can shift the start frame of the detected word left in the CTC-WS sources."
   ]
  },
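  {
   "cell_type": "markdown",
   "id": "6e0f8c45",
   "metadata": {},
   "source": [
    "A minimal sketch of such a post-hoc shift is shown below. It assumes the spotted-word hypotheses expose a `start_frame` attribute (see the hypotheses printed earlier); check the fields in your NeMo version before relying on it, and re-run the merge cell above afterward to see the effect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f1a9b56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical post-hoc experiment: shift the start frame of every spotted word\n",
    "# one frame to the left (assumes a `start_frame` attribute on the hypotheses)\n",
    "frame_shift = 1\n",
    "for hyps in ws_results.values():\n",
    "    for hyp in hyps:\n",
    "        hyp.start_frame = max(0, hyp.start_frame - frame_shift)"
   ]
  },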
  {
   "cell_type": "markdown",
   "id": "11220db2",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This tutorial demonstrates how to use the CTC-WS context-biasing technique to improve the recognition accuracy of specific words in the case of CTC and Transducer (RNN-T) ASR models. The tutorial includes the methodology for creating the context-biasing list, improving recognition accuracy of abbreviations and compound words, visualization of the context-biasing process, and results analysis.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}