rristo committed on
Commit 46455cd
1 Parent(s): d51f5bc

initial commit

Files changed (36)
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. README.md +65 -0
  4. build.bat +1 -0
  5. build.sh +1 -0
  6. docker/Dockerfile +188 -0
  7. err2020/audio/emt16k.wav +0 -0
  8. err2020/conformer_ctc3/__init__.py +0 -0
  9. err2020/conformer_ctc3/__pycache__/__init__.cpython-39.pyc +0 -0
  10. err2020/conformer_ctc3/__pycache__/asr_datamodule.cpython-39.pyc +0 -0
  11. err2020/conformer_ctc3/__pycache__/conformer.cpython-39.pyc +0 -0
  12. err2020/conformer_ctc3/__pycache__/decode.cpython-39.pyc +0 -0
  13. err2020/conformer_ctc3/__pycache__/encoder_interface.cpython-39.pyc +0 -0
  14. err2020/conformer_ctc3/__pycache__/model.cpython-39.pyc +0 -0
  15. err2020/conformer_ctc3/__pycache__/optim.cpython-39.pyc +0 -0
  16. err2020/conformer_ctc3/__pycache__/scaling.cpython-39.pyc +0 -0
  17. err2020/conformer_ctc3/__pycache__/train.cpython-39.pyc +0 -0
  18. err2020/conformer_ctc3/asr_datamodule.py +458 -0
  19. err2020/conformer_ctc3/conformer.py +1598 -0
  20. err2020/conformer_ctc3/decode.py +1052 -0
  21. err2020/conformer_ctc3/encoder_interface.py +43 -0
  22. err2020/conformer_ctc3/exp/jit_trace.pt +3 -0
  23. err2020/conformer_ctc3/export.py +292 -0
  24. err2020/conformer_ctc3/jit_pretrained.py +413 -0
  25. err2020/conformer_ctc3/lstmp.py +102 -0
  26. err2020/conformer_ctc3/model.py +122 -0
  27. err2020/conformer_ctc3/optim.py +320 -0
  28. err2020/conformer_ctc3/pretrained.py +461 -0
  29. err2020/conformer_ctc3/scaling.py +1015 -0
  30. err2020/conformer_ctc3/test_model.py +82 -0
  31. err2020/conformer_ctc3/train.py +1109 -0
  32. err2020/conformer_ctc3_usage.ipynb +500 -0
  33. err2020/data/lang_bpe_500/bpe.model +3 -0
  34. requirements.txt +1 -0
  35. run.bat +9 -0
  36. run.sh +11 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ err2020/audio/oden_kypsis16k.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ **/.ipynb_checkpoints/
+ .idea/
README.md CHANGED
@@ -1,3 +1,68 @@
  ---
+ language:
+ - et
  license: apache-2.0
+ metrics:
+ - wer
+ model-index:
+ - name: conformer-ctc et
+   results:
+   - task:
+       name: Automatic Speech Recognition
+       type: automatic-speech-recognition
+     dataset:
+       name: ERR2020
+       args: et
+     metrics:
+     - name: Wer
+       type: wer
+       value: 12.1
  ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # conformer-ctc et
+
+ An Estonian ASR model trained on the ERR2020 dataset with the icefall conformer-ctc3 recipe (https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc3).
+ - WER on ERR2020: 12.1
+ - WER on Mozilla Common Voice 11: 24.5
+
+ Usage:
+ - clone this repo (`git clone https://huggingface.co/rristo/icefall_conformer_ctc3_et`)
+ - go to the repo (`cd icefall_conformer_ctc3_et`)
+ - build the Docker image with the required libraries (`build.sh` or `build.bat`)
+ - run the Docker container (`run.sh` or `run.bat`); this mounts the current directory
+ - run the notebook `err2020/conformer_ctc3_usage.ipynb` for example usage (a minimal sketch also follows below)
+ - audio is currently expected to be in .wav format
+
+ ## Model description
41
+
42
+ ASR model for Estonian, uses Estonian Public Broadcasting data ERR2020 data (around 230 hours of audio)
43
+
44
+ ## Intended uses & limitations
45
+
46
+ Pretty much a toy model, trained on limited amount of data. Might not work well on data out of domain
47
+ (especially spontaneous/noisy data).
48
+
49
+ ## Training and evaluation data
50
+
51
+ Trained on ERR2020 data, evaluated on ERR2020 and mozilla commonvoice test data.
52
+
53
+ ## Training procedure
54
+
55
+ Used Icefall conformer-ctc3 based recipe (https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc3)
56
+
57
+ ### Training results
58
+
59
+
60
+ TODO
61
+
62
+ ### Framework versions
63
+
64
+ - icefall
65
+ - k2
66
+ - kaldifeat==1.24
67
+ - lhotse==1.15.0
68
+ - torch==2.0.0
build.bat ADDED
@@ -0,0 +1 @@
+ docker build -t icefall -f docker/Dockerfile .
build.sh ADDED
@@ -0,0 +1 @@
+ docker build -t icefall -f docker/Dockerfile .
docker/Dockerfile ADDED
@@ -0,0 +1,188 @@
+ #!/bin/bash
+
+ # ==================================================================
+ # Initial setup
+ # ------------------------------------------------------------------
+
+ # Ubuntu 20.04 as base image
+ FROM ubuntu:20.04
+ RUN yes| unminimize
+
+ # Set ENV variables
+ ENV LANG C.UTF-8
+ ENV SHELL=/bin/bash
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ ENV APT_INSTALL="apt-get install -y --no-install-recommends"
+ ENV PIP_INSTALL="python3 -m pip --no-cache-dir install --upgrade"
+ ENV GIT_CLONE="git clone --depth 10"
+
+
+ # ==================================================================
+ # Tools
+ # ------------------------------------------------------------------
+
+ RUN apt-get update && \
+     $APT_INSTALL \
+     apt-utils \
+     gcc \
+     make \
+     pkg-config \
+     apt-transport-https \
+     build-essential \
+     ca-certificates \
+     wget \
+     rsync \
+     git \
+     vim \
+     mlocate \
+     libssl-dev \
+     curl \
+     openssh-client \
+     unzip \
+     unrar \
+     zip \
+     csvkit \
+     emacs \
+     joe \
+     jq \
+     dialog \
+     man-db \
+     manpages \
+     manpages-dev \
+     manpages-posix \
+     manpages-posix-dev \
+     nano \
+     iputils-ping \
+     sudo \
+     ffmpeg \
+     libsm6 \
+     libxext6 \
+     libboost-all-dev \
+     cifs-utils \
+     software-properties-common
+
+
+ #RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
+ #RUN bash Miniconda3-latest-Linux-x86_64.sh -p /miniconda -b
+ #RUN rm Miniconda3-latest-Linux-x86_64.sh
+ #ENV PATH=/miniconda/bin:${PATH}
+ #RUN conda update -y conda
+
+ ## conda
+ #RUN conda install -c anaconda -y python=3.9.7
+
+
+ # ==================================================================
+ # Python
+ # ------------------------------------------------------------------
+
+ #Based on https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa
+
+ # Adding repository for python3.9
+ RUN add-apt-repository ppa:deadsnakes/ppa -y && \
+
+     # Installing python3.9
+     $APT_INSTALL \
+     python3.9 \
+     python3.9-dev \
+     python3.9-venv \
+     python3-distutils-extra
+
+ # Add symlink so python and python3 commands use same python3.9 executable
+ RUN ln -s /usr/bin/python3.9 /usr/local/bin/python3 && \
+     ln -s /usr/bin/python3.9 /usr/local/bin/python
+
+ # Installing pip
+ RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9
+ ENV PATH=$PATH:/root/.local/bin
+
+ RUN pip install torch==2.0.0+cpu torchvision==0.15.1+cpu torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cpu
+
+
+ # ==================================================================
+ # JupyterLab
+ # ------------------------------------------------------------------
+
+ # Based on https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html#pip
+
+ RUN $PIP_INSTALL jupyterlab==3.4.6
+
+ # ==================================================================
+ # Additional Python Packages
+ # ------------------------------------------------------------------
+
+ RUN $PIP_INSTALL \
+     numpy==1.23.4 \
+     scipy==1.9.2 \
+     pandas==1.5.0 \
+     cloudpickle==2.2.0 \
+     scikit-image==0.19.3 \
+     scikit-learn==1.1.2 \
+     matplotlib==3.6.1 \
+     ipython==8.5.0 \
+     ipykernel==6.16.0 \
+     ipywidgets==8.0.2 \
+     cython==0.29.32 \
+     tqdm==4.64.1 \
+     pillow==9.2.0 \
+     seaborn==0.12.0 \
+     future==0.18.2 \
+     jsonify==0.5 \
+     opencv-python==4.6.0.66 \
+     awscli==1.25.91 \
+     jupyterlab-snippets==0.4.1
+
+ # ==================================================================
+ # CMake
+ # ------------------------------------------------------------------
+
+ RUN git clone https://github.com/Kitware/CMake ~/cmake && \
+     cd ~/cmake && \
+     ./bootstrap && \
+     make -j"$(nproc)" install
+
+
+ # ==================================================================
+ # Node.js and Jupyter Notebook Extensions
+ # ------------------------------------------------------------------
+
+ RUN curl -sL https://deb.nodesource.com/setup_16.x | bash && \
+     $APT_INSTALL nodejs && \
+     $PIP_INSTALL jupyter_contrib_nbextensions jupyterlab-git && \
+     jupyter contrib nbextension install --user
+
+
+ # ==================================================================
+ # Icefall stuff
+ # ------------------------------------------------------------------
+
+ #k2
+ RUN cd /opt && \
+     git clone https://github.com/k2-fsa/k2.git && \
+     cd k2 && \
+     mkdir build-cpu && \
+     cd build-cpu && \
+     cmake -DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=Debug .. && \
+     make -j5
+
+ ENV PYTHONPATH "${PYTHONPATH}:/opt/k2/build-cpu/../k2/python"
+ ENV PYTHONPATH "${PYTHONPATH}:/opt/k2/build-cpu/lib"
+
+ #icefall
+ RUN mkdir /opt/install/
+ COPY requirements.txt /opt/install/requirements.txt
+ RUN pip3 install -r /opt/install/requirements.txt
+ RUN cd /opt && git clone https://github.com/k2-fsa/icefall
+ RUN cd /opt/icefall && pip install -r requirements.txt
+ ENV PYTHONPATH "${PYTHONPATH}:/opt/icefall/"
+ RUN pip install kaldifeat
+ RUN mkdir /opt/notebooks
+
+ # ==================================================================
+ # Startup
+ # ------------------------------------------------------------------
+
+ EXPOSE 8888 6006
+ WORKDIR /opt/notebooks
+ CMD jupyter lab --allow-root --ip=0.0.0.0 --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True
err2020/audio/emt16k.wav ADDED
Binary file (408 kB).
 
err2020/conformer_ctc3/__init__.py ADDED
File without changes
err2020/conformer_ctc3/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (140 Bytes).
 
err2020/conformer_ctc3/__pycache__/asr_datamodule.cpython-39.pyc ADDED
Binary file (9.95 kB).
 
err2020/conformer_ctc3/__pycache__/conformer.cpython-39.pyc ADDED
Binary file (43.6 kB).
 
err2020/conformer_ctc3/__pycache__/decode.cpython-39.pyc ADDED
Binary file (24.7 kB).
 
err2020/conformer_ctc3/__pycache__/encoder_interface.cpython-39.pyc ADDED
Binary file (1.34 kB).
 
err2020/conformer_ctc3/__pycache__/model.cpython-39.pyc ADDED
Binary file (3.65 kB).
 
err2020/conformer_ctc3/__pycache__/optim.cpython-39.pyc ADDED
Binary file (9.99 kB).
 
err2020/conformer_ctc3/__pycache__/scaling.cpython-39.pyc ADDED
Binary file (30.5 kB).
 
err2020/conformer_ctc3/__pycache__/train.cpython-39.pyc ADDED
Binary file (24.7 kB).
 
err2020/conformer_ctc3/asr_datamodule.py ADDED
@@ -0,0 +1,458 @@
1
+ # Copyright 2021 Piotr Żelasko
2
+ # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import argparse
20
+ import inspect
21
+ import logging
22
+ from functools import lru_cache
23
+ from pathlib import Path
24
+ from typing import Any, Dict, Optional
25
+
26
+ import torch
27
+ from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
28
+ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
29
+ CutConcatenate,
30
+ CutMix,
31
+ DynamicBucketingSampler,
32
+ K2SpeechRecognitionDataset,
33
+ PrecomputedFeatures,
34
+ SingleCutSampler,
35
+ SpecAugment,
36
+ )
37
+ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
38
+ AudioSamples,
39
+ OnTheFlyFeatures,
40
+ )
41
+ from lhotse.utils import fix_random_seed
42
+ from torch.utils.data import DataLoader
43
+
44
+ from icefall.utils import str2bool
45
+
46
+
47
+ class _SeedWorkers:
48
+ def __init__(self, seed: int):
49
+ self.seed = seed
50
+
51
+ def __call__(self, worker_id: int):
52
+ fix_random_seed(self.seed + worker_id)
53
+
54
+
55
+ class LibriSpeechAsrDataModule:
56
+ """
57
+ DataModule for k2 ASR experiments.
58
+ It assumes there is always one train and valid dataloader,
59
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
60
+ and test-other).
61
+
62
+ It contains all the common data pipeline modules used in ASR
63
+ experiments, e.g.:
64
+ - dynamic batch size,
65
+ - bucketing samplers,
66
+ - cut concatenation,
67
+ - augmentation,
68
+ - on-the-fly feature extraction
69
+
70
+ This class should be derived for specific corpora used in ASR tasks.
71
+ """
72
+
73
+ def __init__(self, args: argparse.Namespace):
74
+ self.args = args
75
+
76
+ @classmethod
77
+ def add_arguments(cls, parser: argparse.ArgumentParser):
78
+ group = parser.add_argument_group(
79
+ title="ASR data related options",
80
+ description="These options are used for the preparation of "
81
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
82
+ "effective batch sizes, sampling strategies, applied data "
83
+ "augmentations, etc.",
84
+ )
85
+ group.add_argument(
86
+ "--full-libri",
87
+ type=str2bool,
88
+ default=True,
89
+ help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
90
+ )
91
+ group.add_argument(
92
+ "--manifest-dir",
93
+ type=Path,
94
+ default=Path("data/fbank"),
95
+ help="Path to directory with train/valid/test cuts.",
96
+ )
97
+ group.add_argument(
98
+ "--max-duration",
99
+ type=int,
100
+ default=200.0,
101
+ help="Maximum pooled recordings duration (seconds) in a "
102
+ "single batch. You can reduce it if it causes CUDA OOM.",
103
+ )
104
+ group.add_argument(
105
+ "--bucketing-sampler",
106
+ type=str2bool,
107
+ default=True,
108
+ help="When enabled, the batches will come from buckets of "
109
+ "similar duration (saves padding frames).",
110
+ )
111
+ group.add_argument(
112
+ "--num-buckets",
113
+ type=int,
114
+ default=30,
115
+ help="The number of buckets for the DynamicBucketingSampler"
116
+ "(you might want to increase it for larger datasets).",
117
+ )
118
+ group.add_argument(
119
+ "--concatenate-cuts",
120
+ type=str2bool,
121
+ default=False,
122
+ help="When enabled, utterances (cuts) will be concatenated "
123
+ "to minimize the amount of padding.",
124
+ )
125
+ group.add_argument(
126
+ "--duration-factor",
127
+ type=float,
128
+ default=1.0,
129
+ help="Determines the maximum duration of a concatenated cut "
130
+ "relative to the duration of the longest cut in a batch.",
131
+ )
132
+ group.add_argument(
133
+ "--gap",
134
+ type=float,
135
+ default=1.0,
136
+ help="The amount of padding (in seconds) inserted between "
137
+ "concatenated cuts. This padding is filled with noise when "
138
+ "noise augmentation is used.",
139
+ )
140
+ group.add_argument(
141
+ "--on-the-fly-feats",
142
+ type=str2bool,
143
+ default=False,
144
+ help="When enabled, use on-the-fly cut mixing and feature "
145
+ "extraction. Will drop existing precomputed feature manifests "
146
+ "if available.",
147
+ )
148
+ group.add_argument(
149
+ "--shuffle",
150
+ type=str2bool,
151
+ default=True,
152
+ help="When enabled (=default), the examples will be "
153
+ "shuffled for each epoch.",
154
+ )
155
+ group.add_argument(
156
+ "--drop-last",
157
+ type=str2bool,
158
+ default=True,
159
+ help="Whether to drop last batch. Used by sampler.",
160
+ )
161
+ group.add_argument(
162
+ "--return-cuts",
163
+ type=str2bool,
164
+ default=True,
165
+ help="When enabled, each batch will have the "
166
+ "field: batch['supervisions']['cut'] with the cuts that "
167
+ "were used to construct it.",
168
+ )
169
+
170
+ group.add_argument(
171
+ "--num-workers",
172
+ type=int,
173
+ default=2,
174
+ help="The number of training dataloader workers that "
175
+ "collect the batches.",
176
+ )
177
+
178
+ group.add_argument(
179
+ "--enable-spec-aug",
180
+ type=str2bool,
181
+ default=True,
182
+ help="When enabled, use SpecAugment for training dataset.",
183
+ )
184
+
185
+ group.add_argument(
186
+ "--spec-aug-time-warp-factor",
187
+ type=int,
188
+ default=80,
189
+ help="Used only when --enable-spec-aug is True. "
190
+ "It specifies the factor for time warping in SpecAugment. "
191
+ "Larger values mean more warping. "
192
+ "A value less than 1 means to disable time warp.",
193
+ )
194
+
195
+ group.add_argument(
196
+ "--enable-musan",
197
+ type=str2bool,
198
+ default=True,
199
+ help="When enabled, select noise from MUSAN and mix it"
200
+ "with training dataset. ",
201
+ )
202
+
203
+ group.add_argument(
204
+ "--input-strategy",
205
+ type=str,
206
+ default="PrecomputedFeatures",
207
+ help="AudioSamples or PrecomputedFeatures",
208
+ )
209
+
210
+ def train_dataloaders(
211
+ self,
212
+ cuts_train: CutSet,
213
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
214
+ ) -> DataLoader:
215
+ """
216
+ Args:
217
+ cuts_train:
218
+ CutSet for training.
219
+ sampler_state_dict:
220
+ The state dict for the training sampler.
221
+ """
222
+ transforms = []
223
+ if self.args.enable_musan:
224
+ logging.info("Enable MUSAN")
225
+ logging.info("About to get Musan cuts")
226
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
227
+ transforms.append(
228
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
229
+ )
230
+ else:
231
+ logging.info("Disable MUSAN")
232
+
233
+ if self.args.concatenate_cuts:
234
+ logging.info(
235
+ f"Using cut concatenation with duration factor "
236
+ f"{self.args.duration_factor} and gap {self.args.gap}."
237
+ )
238
+ # Cut concatenation should be the first transform in the list,
239
+ # so that if we e.g. mix noise in, it will fill the gaps between
240
+ # different utterances.
241
+ transforms = [
242
+ CutConcatenate(
243
+ duration_factor=self.args.duration_factor, gap=self.args.gap
244
+ )
245
+ ] + transforms
246
+
247
+ input_transforms = []
248
+ if self.args.enable_spec_aug:
249
+ logging.info("Enable SpecAugment")
250
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
251
+ # Set the value of num_frame_masks according to Lhotse's version.
252
+ # In different Lhotse's versions, the default of num_frame_masks is
253
+ # different.
254
+ num_frame_masks = 10
255
+ num_frame_masks_parameter = inspect.signature(
256
+ SpecAugment.__init__
257
+ ).parameters["num_frame_masks"]
258
+ if num_frame_masks_parameter.default == 1:
259
+ num_frame_masks = 2
260
+ logging.info(f"Num frame mask: {num_frame_masks}")
261
+ input_transforms.append(
262
+ SpecAugment(
263
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
264
+ num_frame_masks=num_frame_masks,
265
+ features_mask_size=27,
266
+ num_feature_masks=2,
267
+ frames_mask_size=100,
268
+ )
269
+ )
270
+ else:
271
+ logging.info("Disable SpecAugment")
272
+
273
+ logging.info("About to create train dataset")
274
+ train = K2SpeechRecognitionDataset(
275
+ input_strategy=eval(self.args.input_strategy)(),
276
+ cut_transforms=transforms,
277
+ input_transforms=input_transforms,
278
+ return_cuts=self.args.return_cuts,
279
+ )
280
+
281
+ if self.args.on_the_fly_feats:
282
+ # NOTE: the PerturbSpeed transform should be added only if we
283
+ # remove it from data prep stage.
284
+ # Add on-the-fly speed perturbation; since originally it would
285
+ # have increased epoch size by 3, we will apply prob 2/3 and use
286
+ # 3x more epochs.
287
+ # Speed perturbation probably should come first before
288
+ # concatenation, but in principle the transforms order doesn't have
289
+ # to be strict (e.g. could be randomized)
290
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
291
+ # Drop feats to be on the safe side.
292
+ train = K2SpeechRecognitionDataset(
293
+ cut_transforms=transforms,
294
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
295
+ input_transforms=input_transforms,
296
+ return_cuts=self.args.return_cuts,
297
+ )
298
+
299
+ if self.args.bucketing_sampler:
300
+ logging.info("Using DynamicBucketingSampler.")
301
+ train_sampler = DynamicBucketingSampler(
302
+ cuts_train,
303
+ max_duration=self.args.max_duration,
304
+ shuffle=self.args.shuffle,
305
+ num_buckets=self.args.num_buckets,
306
+ drop_last=self.args.drop_last,
307
+ )
308
+ else:
309
+ logging.info("Using SingleCutSampler.")
310
+ train_sampler = SingleCutSampler(
311
+ cuts_train,
312
+ max_duration=self.args.max_duration,
313
+ shuffle=self.args.shuffle,
314
+ )
315
+ logging.info("About to create train dataloader")
316
+
317
+ if sampler_state_dict is not None:
318
+ logging.info("Loading sampler state dict")
319
+ train_sampler.load_state_dict(sampler_state_dict)
320
+
321
+ # 'seed' is derived from the current random state, which will have
322
+ # previously been set in the main process.
323
+ seed = torch.randint(0, 100000, ()).item()
324
+ worker_init_fn = _SeedWorkers(seed)
325
+
326
+ train_dl = DataLoader(
327
+ train,
328
+ sampler=train_sampler,
329
+ batch_size=None,
330
+ num_workers=self.args.num_workers,
331
+ persistent_workers=False,
332
+ worker_init_fn=worker_init_fn,
333
+ )
334
+
335
+ return train_dl
336
+
337
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
338
+ transforms = []
339
+ if self.args.concatenate_cuts:
340
+ transforms = [
341
+ CutConcatenate(
342
+ duration_factor=self.args.duration_factor, gap=self.args.gap
343
+ )
344
+ ] + transforms
345
+
346
+ logging.info("About to create dev dataset")
347
+ if self.args.on_the_fly_feats:
348
+ validate = K2SpeechRecognitionDataset(
349
+ cut_transforms=transforms,
350
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
351
+ return_cuts=self.args.return_cuts,
352
+ )
353
+ else:
354
+ validate = K2SpeechRecognitionDataset(
355
+ cut_transforms=transforms,
356
+ return_cuts=self.args.return_cuts,
357
+ )
358
+ valid_sampler = DynamicBucketingSampler(
359
+ cuts_valid,
360
+ max_duration=self.args.max_duration,
361
+ shuffle=False,
362
+ )
363
+ logging.info("About to create dev dataloader")
364
+ valid_dl = DataLoader(
365
+ validate,
366
+ sampler=valid_sampler,
367
+ batch_size=None,
368
+ num_workers=2,
369
+ persistent_workers=False,
370
+ )
371
+
372
+ return valid_dl
373
+
374
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
375
+ logging.debug("About to create test dataset")
376
+ test = K2SpeechRecognitionDataset(
377
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
378
+ if self.args.on_the_fly_feats
379
+ else eval(self.args.input_strategy)(),
380
+ return_cuts=self.args.return_cuts,
381
+ )
382
+ sampler = DynamicBucketingSampler(
383
+ cuts,
384
+ max_duration=self.args.max_duration,
385
+ shuffle=False,
386
+ )
387
+ logging.debug("About to create test dataloader")
388
+ test_dl = DataLoader(
389
+ test,
390
+ batch_size=None,
391
+ sampler=sampler,
392
+ num_workers=self.args.num_workers,
393
+ )
394
+ return test_dl
395
+
396
+ @lru_cache()
397
+ def train_clean_100_cuts(self) -> CutSet:
398
+ logging.info("About to get train-clean-100 cuts")
399
+ return load_manifest_lazy(
400
+ self.args.manifest_dir / "err2020_cuts_train.jsonl.gz"
401
+ )
402
+
403
+ # @lru_cache()
404
+ # def train_clean_360_cuts(self) -> CutSet:
405
+ # logging.info("About to get train-clean-360 cuts")
406
+ # return load_manifest_lazy(
407
+ # self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
408
+ # )
409
+
410
+ # @lru_cache()
411
+ # def train_other_500_cuts(self) -> CutSet:
412
+ # logging.info("About to get train-other-500 cuts")
413
+ # return load_manifest_lazy(
414
+ # self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
415
+ # )
416
+
417
+ @lru_cache()
418
+ def train_all_shuf_cuts(self) -> CutSet:
419
+ logging.info(
420
+ "About to get the shuffled train-clean-100, \
421
+ train-clean-360 and train-other-500 cuts"
422
+ )
423
+ return load_manifest_lazy(
424
+ # self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
425
+ self.args.manifest_dir / "err2020_cuts_train-all-shuf.jsonl.gz"
426
+ )
427
+
428
+ @lru_cache()
429
+ def dev_clean_cuts(self) -> CutSet:
430
+ logging.info("About to get dev-clean cuts")
431
+ return load_manifest_lazy(
432
+ # self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
433
+ self.args.manifest_dir / "err2020_cuts_validation.jsonl.gz"
434
+ )
435
+
436
+ # @lru_cache()
437
+ # def dev_other_cuts(self) -> CutSet:
438
+ # logging.info("About to get dev-other cuts")
439
+ # return load_manifest_lazy(
440
+ # # self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
441
+ # self.args.manifest_dir / "err2020_cuts_validation.jsonl.gz"
442
+ # )
443
+
444
+ @lru_cache()
445
+ def test_clean_cuts(self) -> CutSet:
446
+ logging.info("About to get test-clean cuts")
447
+ return load_manifest_lazy(
448
+ # self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
449
+ self.args.manifest_dir / "err2020_cuts_test.jsonl.gz"
450
+ )
451
+
452
+ # @lru_cache()
453
+ # def test_other_cuts(self) -> CutSet:
454
+ # logging.info("About to get test-other cuts")
455
+ # return load_manifest_lazy(
456
+ # self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
457
+ # self.args.manifest_dir / "err2020_cuts_test.jsonl.gz"
458
+ # )
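The data module above is driven entirely by the argparse flags registered in `add_arguments()`. A hedged sketch of how a training or decoding script might wire it up (manifest filenames follow the `err2020_cuts_*.jsonl.gz` paths used above; the batch keys come from lhotse's `K2SpeechRecognitionDataset`; everything else is illustrative):

```python
# Illustrative only: assumes it runs from err2020/conformer_ctc3/ so that
# asr_datamodule.py is importable, and that data/fbank holds the manifests.
import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
# Disable MUSAN mixing so no musan_cuts.jsonl.gz is needed for this sketch.
args = parser.parse_args(
    ["--manifest-dir", "data/fbank", "--max-duration", "200", "--enable-musan", "false"]
)

datamodule = LibriSpeechAsrDataModule(args)
train_cuts = datamodule.train_clean_100_cuts()  # err2020_cuts_train.jsonl.gz
valid_cuts = datamodule.dev_clean_cuts()        # err2020_cuts_validation.jsonl.gz

train_dl = datamodule.train_dataloaders(train_cuts)
valid_dl = datamodule.valid_dataloaders(valid_cuts)

for batch in train_dl:
    features = batch["inputs"]            # (N, T, 80) fbank features
    supervisions = batch["supervisions"]  # texts, start frames, cuts, ...
    break
```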
err2020/conformer_ctc3/conformer.py ADDED
@@ -0,0 +1,1598 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import copy
19
+ import math
20
+ import warnings
21
+ from typing import List, Optional, Tuple
22
+
23
+ import torch
24
+ from encoder_interface import EncoderInterface
25
+ from scaling import (
26
+ ActivationBalancer,
27
+ BasicNorm,
28
+ DoubleSwish,
29
+ ScaledConv1d,
30
+ ScaledConv2d,
31
+ ScaledLinear,
32
+ )
33
+ from torch import Tensor, nn
34
+
35
+ from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask
36
+
37
+
38
+ class Conformer(EncoderInterface):
39
+ """
40
+ Args:
41
+ num_features (int): Number of input features
42
+ subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
43
+ d_model (int): attention dimension, also the output dimension
44
+ nhead (int): number of head
45
+ dim_feedforward (int): feedforward dimention
46
+ num_encoder_layers (int): number of encoder layers
47
+ dropout (float): dropout rate
48
+ layer_dropout (float): layer-dropout rate.
49
+ cnn_module_kernel (int): Kernel size of convolution module
50
+ vgg_frontend (bool): whether to use vgg frontend.
51
+ dynamic_chunk_training (bool): whether to use dynamic chunk training, if
52
+ you want to train a streaming model, this is expected to be True.
53
+ When setting True, it will use a masking strategy to make the attention
54
+ see only limited left and right context.
55
+ short_chunk_threshold (float): a threshold to determinize the chunk size
56
+ to be used in masking training, if the randomly generated chunk size
57
+ is greater than ``max_len * short_chunk_threshold`` (max_len is the
58
+ max sequence length of current batch) then it will use
59
+ full context in training (i.e. with chunk size equals to max_len).
60
+ This will be used only when dynamic_chunk_training is True.
61
+ short_chunk_size (int): see docs above, if the randomly generated chunk
62
+ size equals to or less than ``max_len * short_chunk_threshold``, the
63
+ chunk size will be sampled uniformly from 1 to short_chunk_size.
64
+ This also will be used only when dynamic_chunk_training is True.
65
+ num_left_chunks (int): the left context (in chunks) attention can see, the
66
+ chunk size is decided by short_chunk_threshold and short_chunk_size.
67
+ A minus value means seeing full left context.
68
+ This also will be used only when dynamic_chunk_training is True.
69
+ causal (bool): Whether to use causal convolution in conformer encoder
70
+ layer. This MUST be True when using dynamic_chunk_training.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ num_features: int,
76
+ subsampling_factor: int = 4,
77
+ d_model: int = 256,
78
+ nhead: int = 4,
79
+ dim_feedforward: int = 2048,
80
+ num_encoder_layers: int = 12,
81
+ dropout: float = 0.1,
82
+ layer_dropout: float = 0.075,
83
+ cnn_module_kernel: int = 31,
84
+ dynamic_chunk_training: bool = False,
85
+ short_chunk_threshold: float = 0.75,
86
+ short_chunk_size: int = 25,
87
+ num_left_chunks: int = -1,
88
+ causal: bool = False,
89
+ ) -> None:
90
+ super(Conformer, self).__init__()
91
+
92
+ self.num_features = num_features
93
+ self.subsampling_factor = subsampling_factor
94
+ if subsampling_factor != 4:
95
+ raise NotImplementedError("Support only 'subsampling_factor=4'.")
96
+
97
+ # self.encoder_embed converts the input of shape (N, T, num_features)
98
+ # to the shape (N, T//subsampling_factor, d_model).
99
+ # That is, it does two things simultaneously:
100
+ # (1) subsampling: T -> T//subsampling_factor
101
+ # (2) embedding: num_features -> d_model
102
+ self.encoder_embed = Conv2dSubsampling(num_features, d_model)
103
+
104
+ self.encoder_layers = num_encoder_layers
105
+ self.d_model = d_model
106
+ self.cnn_module_kernel = cnn_module_kernel
107
+ self.causal = causal
108
+ self.dynamic_chunk_training = dynamic_chunk_training
109
+ self.short_chunk_threshold = short_chunk_threshold
110
+ self.short_chunk_size = short_chunk_size
111
+ self.num_left_chunks = num_left_chunks
112
+
113
+ self.encoder_pos = RelPositionalEncoding(d_model, dropout)
114
+
115
+ encoder_layer = ConformerEncoderLayer(
116
+ d_model,
117
+ nhead,
118
+ dim_feedforward,
119
+ dropout,
120
+ layer_dropout,
121
+ cnn_module_kernel,
122
+ causal,
123
+ )
124
+ self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
125
+ self._init_state: List[torch.Tensor] = [torch.empty(0)]
126
+
127
+ def forward(
128
+ self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
129
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
130
+ """
131
+ Args:
132
+ x:
133
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
134
+ x_lens:
135
+ A tensor of shape (batch_size,) containing the number of frames in
136
+ `x` before padding.
137
+ warmup:
138
+ A floating point value that gradually increases from 0 throughout
139
+ training; when it is >= 1.0 we are "fully warmed up". It is used
140
+ to turn modules on sequentially.
141
+ Returns:
142
+ Return a tuple containing 2 tensors:
143
+ - embeddings: its shape is (batch_size, output_seq_len, d_model)
144
+ - lengths, a tensor of shape (batch_size,) containing the number
145
+ of frames in `embeddings` before padding.
146
+ """
147
+ x = self.encoder_embed(x)
148
+ x, pos_emb = self.encoder_pos(x)
149
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
150
+
151
+ # Caution: We assume the subsampling factor is 4!
152
+
153
+ # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
154
+ #
155
+ # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
156
+ lengths = (((x_lens - 1) >> 1) - 1) >> 1
157
+
158
+ if not is_jit_tracing():
159
+ assert x.size(0) == lengths.max().item()
160
+
161
+ src_key_padding_mask = make_pad_mask(lengths)
162
+
163
+ if self.dynamic_chunk_training:
164
+ assert (
165
+ self.causal
166
+ ), "Causal convolution is required for streaming conformer."
167
+ max_len = x.size(0)
168
+ chunk_size = torch.randint(1, max_len, (1,)).item()
169
+ if chunk_size > (max_len * self.short_chunk_threshold):
170
+ chunk_size = max_len
171
+ else:
172
+ chunk_size = chunk_size % self.short_chunk_size + 1
173
+
174
+ mask = ~subsequent_chunk_mask(
175
+ size=x.size(0),
176
+ chunk_size=chunk_size,
177
+ num_left_chunks=self.num_left_chunks,
178
+ device=x.device,
179
+ )
180
+ x = self.encoder(
181
+ x,
182
+ pos_emb,
183
+ mask=mask,
184
+ src_key_padding_mask=src_key_padding_mask,
185
+ warmup=warmup,
186
+ ) # (T, N, C)
187
+ else:
188
+ x = self.encoder(
189
+ x,
190
+ pos_emb,
191
+ mask=None,
192
+ src_key_padding_mask=src_key_padding_mask,
193
+ warmup=warmup,
194
+ ) # (T, N, C)
195
+
196
+ x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
197
+ return x, lengths
198
+
199
+ @torch.jit.export
200
+ def get_init_state(
201
+ self, left_context: int, device: torch.device
202
+ ) -> List[torch.Tensor]:
203
+ """Return the initial cache state of the model.
204
+
205
+ Args:
206
+ left_context: The left context size (in frames after subsampling).
207
+
208
+ Returns:
209
+ Return the initial state of the model, it is a list containing two
210
+ tensors, the first one is the cache for attentions which has a shape
211
+ of (num_encoder_layers, left_context, encoder_dim), the second one
212
+ is the cache of conv_modules which has a shape of
213
+ (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
214
+
215
+ NOTE: the returned tensors are on the given device.
216
+ """
217
+ if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
218
+ # Note: It is OK to share the init state as it is
219
+ # not going to be modified by the model
220
+ return self._init_state
221
+
222
+ init_states: List[torch.Tensor] = [
223
+ torch.zeros(
224
+ (
225
+ self.encoder_layers,
226
+ left_context,
227
+ self.d_model,
228
+ ),
229
+ device=device,
230
+ ),
231
+ torch.zeros(
232
+ (
233
+ self.encoder_layers,
234
+ self.cnn_module_kernel - 1,
235
+ self.d_model,
236
+ ),
237
+ device=device,
238
+ ),
239
+ ]
240
+
241
+ self._init_state = init_states
242
+
243
+ return init_states
244
+
245
+ @torch.jit.export
246
+ def streaming_forward(
247
+ self,
248
+ x: torch.Tensor,
249
+ x_lens: torch.Tensor,
250
+ states: Optional[List[Tensor]] = None,
251
+ processed_lens: Optional[Tensor] = None,
252
+ left_context: int = 64,
253
+ right_context: int = 4,
254
+ chunk_size: int = 16,
255
+ simulate_streaming: bool = False,
256
+ warmup: float = 1.0,
257
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
258
+ """
259
+ Args:
260
+ x:
261
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
262
+ x_lens:
263
+ A tensor of shape (batch_size,) containing the number of frames in
264
+ `x` before padding.
265
+ states:
266
+ The decode states for previous frames which contains the cached data.
267
+ It has two elements, the first element is the attn_cache which has
268
+ a shape of (encoder_layers, left_context, batch, attention_dim),
269
+ the second element is the conv_cache which has a shape of
270
+ (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
271
+ Note: states will be modified in this function.
272
+ processed_lens:
273
+ How many frames (after subsampling) have been processed for each sequence.
274
+ left_context:
275
+ How many previous frames the attention can see in current chunk.
276
+ Note: It's not that each individual frame has `left_context` frames
277
+ of left context, some have more.
278
+ right_context:
279
+ How many future frames the attention can see in current chunk.
280
+ Note: It's not that each individual frame has `right_context` frames
281
+ of right context, some have more.
282
+ chunk_size:
283
+ The chunk size for decoding, this will be used to simulate streaming
284
+ decoding using masking.
285
+ simulate_streaming:
286
+ If setting True, it will use a masking strategy to simulate streaming
287
+ fashion (i.e. every chunk data only see limited left context and
288
+ right context). The whole sequence is supposed to be send at a time
289
+ When using simulate_streaming.
290
+ warmup:
291
+ A floating point value that gradually increases from 0 throughout
292
+ training; when it is >= 1.0 we are "fully warmed up". It is used
293
+ to turn modules on sequentially.
294
+ Returns:
295
+ Return a tuple containing 2 tensors:
296
+ - logits, its shape is (batch_size, output_seq_len, output_dim)
297
+ - logit_lens, a tensor of shape (batch_size,) containing the number
298
+ of frames in `logits` before padding.
299
+ - decode_states, the updated states including the information
300
+ of current chunk.
301
+ """
302
+
303
+ # x: [N, T, C]
304
+ # Caution: We assume the subsampling factor is 4!
305
+
306
+ # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
307
+ #
308
+ # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
309
+ lengths = (((x_lens - 1) >> 1) - 1) >> 1
310
+
311
+ if not simulate_streaming:
312
+ assert states is not None
313
+ assert processed_lens is not None
314
+ assert (
315
+ len(states) == 2
316
+ and states[0].shape
317
+ == (self.encoder_layers, left_context, x.size(0), self.d_model)
318
+ and states[1].shape
319
+ == (
320
+ self.encoder_layers,
321
+ self.cnn_module_kernel - 1,
322
+ x.size(0),
323
+ self.d_model,
324
+ )
325
+ ), f"""The length of states MUST be equal to 2, and the shape of
326
+ first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
327
+ given {states[0].shape}. the shape of second element should be
328
+ {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
329
+ given {states[1].shape}."""
330
+
331
+ lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
332
+
333
+ src_key_padding_mask = make_pad_mask(lengths)
334
+
335
+ processed_mask = torch.arange(left_context, device=x.device).expand(
336
+ x.size(0), left_context
337
+ )
338
+ processed_lens = processed_lens.view(x.size(0), 1)
339
+ processed_mask = (processed_lens <= processed_mask).flip(1)
340
+
341
+ src_key_padding_mask = torch.cat(
342
+ [processed_mask, src_key_padding_mask], dim=1
343
+ )
344
+
345
+ embed = self.encoder_embed(x)
346
+
347
+ # cut off 1 frame on each size of embed as they see the padding
348
+ # value which causes a training and decoding mismatch.
349
+ embed = embed[:, 1:-1, :]
350
+
351
+ embed, pos_enc = self.encoder_pos(embed, left_context)
352
+ embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
353
+
354
+ x, states = self.encoder.chunk_forward(
355
+ embed,
356
+ pos_enc,
357
+ src_key_padding_mask=src_key_padding_mask,
358
+ warmup=warmup,
359
+ states=states,
360
+ left_context=left_context,
361
+ right_context=right_context,
362
+ ) # (T, B, F)
363
+ if right_context > 0:
364
+ x = x[0:-right_context, ...]
365
+ lengths -= right_context
366
+ else:
367
+ assert states is None
368
+ states = [] # just to make torch.script.jit happy
369
+ # this branch simulates streaming decoding using mask as we are
370
+ # using in training time.
371
+ src_key_padding_mask = make_pad_mask(lengths)
372
+ x = self.encoder_embed(x)
373
+ x, pos_emb = self.encoder_pos(x)
374
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
375
+
376
+ assert x.size(0) == lengths.max().item()
377
+
378
+ if chunk_size < 0:
379
+ # use full attention
380
+ chunk_size = x.size(0)
381
+ left_context = -1
382
+
383
+ num_left_chunks = -1
384
+ if left_context >= 0:
385
+ assert left_context % chunk_size == 0
386
+ num_left_chunks = left_context // chunk_size
387
+
388
+ mask = ~subsequent_chunk_mask(
389
+ size=x.size(0),
390
+ chunk_size=chunk_size,
391
+ num_left_chunks=num_left_chunks,
392
+ device=x.device,
393
+ )
394
+ x = self.encoder(
395
+ x,
396
+ pos_emb,
397
+ mask=mask,
398
+ src_key_padding_mask=src_key_padding_mask,
399
+ warmup=warmup,
400
+ ) # (T, N, C)
401
+
402
+ x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
403
+
404
+ return x, lengths, states
405
+
406
+
407
+ class ConformerEncoderLayer(nn.Module):
408
+ """
409
+ ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
410
+ See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
411
+
412
+ Args:
413
+ d_model: the number of expected features in the input (required).
414
+ nhead: the number of heads in the multiheadattention models (required).
415
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
416
+ dropout: the dropout value (default=0.1).
417
+ cnn_module_kernel (int): Kernel size of convolution module.
418
+ causal (bool): Whether to use causal convolution in conformer encoder
419
+ layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
420
+
421
+ Examples::
422
+ >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
423
+ >>> src = torch.rand(10, 32, 512)
424
+ >>> pos_emb = torch.rand(32, 19, 512)
425
+ >>> out = encoder_layer(src, pos_emb)
426
+ """
427
+
428
+ def __init__(
429
+ self,
430
+ d_model: int,
431
+ nhead: int,
432
+ dim_feedforward: int = 2048,
433
+ dropout: float = 0.1,
434
+ layer_dropout: float = 0.075,
435
+ cnn_module_kernel: int = 31,
436
+ causal: bool = False,
437
+ ) -> None:
438
+ super(ConformerEncoderLayer, self).__init__()
439
+
440
+ self.layer_dropout = layer_dropout
441
+
442
+ self.d_model = d_model
443
+
444
+ self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
445
+
446
+ self.feed_forward = nn.Sequential(
447
+ ScaledLinear(d_model, dim_feedforward),
448
+ ActivationBalancer(channel_dim=-1),
449
+ DoubleSwish(),
450
+ nn.Dropout(dropout),
451
+ ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
452
+ )
453
+
454
+ self.feed_forward_macaron = nn.Sequential(
455
+ ScaledLinear(d_model, dim_feedforward),
456
+ ActivationBalancer(channel_dim=-1),
457
+ DoubleSwish(),
458
+ nn.Dropout(dropout),
459
+ ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
460
+ )
461
+
462
+ self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
463
+
464
+ self.norm_final = BasicNorm(d_model)
465
+
466
+ # try to ensure the output is close to zero-mean (or at least, zero-median).
467
+ self.balancer = ActivationBalancer(
468
+ channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
469
+ )
470
+
471
+ self.dropout = nn.Dropout(dropout)
472
+
473
+ def forward(
474
+ self,
475
+ src: Tensor,
476
+ pos_emb: Tensor,
477
+ src_key_padding_mask: Optional[Tensor] = None,
478
+ src_mask: Optional[Tensor] = None,
479
+ warmup: float = 1.0,
480
+ ) -> Tensor:
481
+ """
482
+ Pass the input through the encoder layer.
483
+
484
+ Args:
485
+ src: the sequence to the encoder layer (required).
486
+ pos_emb: Positional embedding tensor (required).
487
+ src_key_padding_mask: the mask for the src keys per batch (optional).
488
+ src_mask: the mask for the src sequence (optional).
489
+ warmup: controls selective bypass of of layers; if < 1.0, we will
490
+ bypass layers more frequently.
491
+ Shape:
492
+ src: (S, N, E).
493
+ pos_emb: (N, 2*S-1, E)
494
+ src_mask: (S, S).
495
+ src_key_padding_mask: (N, S).
496
+ S is the source sequence length, N is the batch size, E is the feature number
497
+ """
498
+ src_orig = src
499
+
500
+ warmup_scale = min(0.1 + warmup, 1.0)
501
+ # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
502
+ # completely bypass it.
503
+ if self.training:
504
+ alpha = (
505
+ warmup_scale
506
+ if torch.rand(()).item() <= (1.0 - self.layer_dropout)
507
+ else 0.1
508
+ )
509
+ else:
510
+ alpha = 1.0
511
+
512
+ # macaron style feed forward module
513
+ src = src + self.dropout(self.feed_forward_macaron(src))
514
+
515
+ # multi-headed self-attention module
516
+ src_att = self.self_attn(
517
+ src,
518
+ src,
519
+ src,
520
+ pos_emb=pos_emb,
521
+ attn_mask=src_mask,
522
+ key_padding_mask=src_key_padding_mask,
523
+ )[0]
524
+
525
+ src = src + self.dropout(src_att)
526
+
527
+ # convolution module
528
+ conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
529
+ src = src + self.dropout(conv)
530
+
531
+ # feed forward module
532
+ src = src + self.dropout(self.feed_forward(src))
533
+
534
+ src = self.norm_final(self.balancer(src))
535
+
536
+ if alpha != 1.0:
537
+ src = alpha * src + (1 - alpha) * src_orig
538
+
539
+ return src
540
+
541
+ @torch.jit.export
542
+ def chunk_forward(
543
+ self,
544
+ src: Tensor,
545
+ pos_emb: Tensor,
546
+ states: List[Tensor],
547
+ src_mask: Optional[Tensor] = None,
548
+ src_key_padding_mask: Optional[Tensor] = None,
549
+ warmup: float = 1.0,
550
+ left_context: int = 0,
551
+ right_context: int = 0,
552
+ ) -> Tuple[Tensor, List[Tensor]]:
553
+ """
554
+ Pass the input through the encoder layer.
555
+
556
+ Args:
557
+ src: the sequence to the encoder layer (required).
558
+ pos_emb: Positional embedding tensor (required).
559
+ states:
560
+ The decode states for previous frames which contains the cached data.
561
+ It has two elements, the first element is the attn_cache which has
562
+ a shape of (left_context, batch, attention_dim),
563
+ the second element is the conv_cache which has a shape of
564
+ (cnn_module_kernel-1, batch, conv_dim).
565
+ Note: states will be modified in this function.
566
+ src_mask: the mask for the src sequence (optional).
567
+ src_key_padding_mask: the mask for the src keys per batch (optional).
568
+ warmup: controls selective bypass of of layers; if < 1.0, we will
569
+ bypass layers more frequently.
570
+ left_context:
571
+ How many previous frames the attention can see in current chunk.
572
+ Note: It's not that each individual frame has `left_context` frames
573
+ of left context, some have more.
574
+ right_context:
575
+ How many future frames the attention can see in current chunk.
576
+ Note: It's not that each individual frame has `right_context` frames
577
+ of right context, some have more.
578
+
579
+ Shape:
580
+ src: (S, N, E).
581
+ pos_emb: (N, 2*(S+left_context)-1, E).
582
+ src_mask: (S, S).
583
+ src_key_padding_mask: (N, S).
584
+ S is the source sequence length, N is the batch size, E is the feature number
585
+ """
586
+
587
+ assert not self.training
588
+ assert len(states) == 2
589
+ assert states[0].shape == (left_context, src.size(1), src.size(2))
590
+
591
+ # macaron style feed forward module
592
+ src = src + self.dropout(self.feed_forward_macaron(src))
593
+
594
+ # We put the attention cache this level (i.e. before linear transformation)
595
+ # to save memory consumption, when decoding in streaming fashion, the
596
+ # batch size would be thousands (for 32GB machine), if we cache key & val
597
+ # separately, it needs extra several GB memory.
598
+ # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
599
+ # separately) if needed.
600
+ key = torch.cat([states[0], src], dim=0)
601
+ val = key
602
+ if right_context > 0:
603
+ states[0] = key[
604
+ -(left_context + right_context) : -right_context, ... # noqa
605
+ ]
606
+ else:
607
+ states[0] = key[-left_context:, ...]
608
+
609
+ # multi-headed self-attention module
610
+ src_att = self.self_attn(
611
+ src,
612
+ key,
613
+ val,
614
+ pos_emb=pos_emb,
615
+ attn_mask=src_mask,
616
+ key_padding_mask=src_key_padding_mask,
617
+ left_context=left_context,
618
+ )[0]
619
+
620
+ src = src + self.dropout(src_att)
621
+
622
+ # convolution module
623
+ conv, conv_cache = self.conv_module(src, states[1], right_context)
624
+ states[1] = conv_cache
625
+
626
+ src = src + self.dropout(conv)
627
+
628
+ # feed forward module
629
+ src = src + self.dropout(self.feed_forward(src))
630
+
631
+ src = self.norm_final(self.balancer(src))
632
+
633
+ return src, states
634
+
635
+
636
+ class ConformerEncoder(nn.Module):
637
+ r"""ConformerEncoder is a stack of N encoder layers
638
+
639
+ Args:
640
+ encoder_layer: an instance of the ConformerEncoderLayer() class (required).
641
+ num_layers: the number of sub-encoder-layers in the encoder (required).
642
+
643
+ Examples::
644
+ >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
645
+ >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
646
+ >>> src = torch.rand(10, 32, 512)
647
+ >>> pos_emb = torch.rand(32, 19, 512)
648
+ >>> out = conformer_encoder(src, pos_emb)
649
+ """
650
+
651
+ def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
652
+ super().__init__()
653
+ self.layers = nn.ModuleList(
654
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
655
+ )
656
+ self.num_layers = num_layers
657
+
658
+ def forward(
659
+ self,
660
+ src: Tensor,
661
+ pos_emb: Tensor,
662
+ src_key_padding_mask: Optional[Tensor] = None,
663
+ mask: Optional[Tensor] = None,
664
+ warmup: float = 1.0,
665
+ ) -> Tensor:
666
+ r"""Pass the input through the encoder layers in turn.
667
+
668
+ Args:
669
+ src: the sequence to the encoder (required).
670
+ pos_emb: Positional embedding tensor (required).
671
+ src_key_padding_mask: the mask for the src keys per batch (optional).
672
+ mask: the mask for the src sequence (optional).
673
+ warmup: controls selective bypass of of layers; if < 1.0, we will
674
+ bypass layers more frequently.
675
+
676
+ Shape:
677
+ src: (S, N, E).
678
+ pos_emb: (N, 2*S-1, E)
679
+ mask: (S, S).
680
+ src_key_padding_mask: (N, S).
681
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
682
+
683
+ """
684
+ output = src
685
+
686
+ for layer_index, mod in enumerate(self.layers):
687
+ output = mod(
688
+ output,
689
+ pos_emb,
690
+ src_mask=mask,
691
+ src_key_padding_mask=src_key_padding_mask,
692
+ warmup=warmup,
693
+ )
694
+
695
+ return output
696
+
697
+ @torch.jit.export
698
+ def chunk_forward(
699
+ self,
700
+ src: Tensor,
701
+ pos_emb: Tensor,
702
+ states: List[Tensor],
703
+ mask: Optional[Tensor] = None,
704
+ src_key_padding_mask: Optional[Tensor] = None,
705
+ warmup: float = 1.0,
706
+ left_context: int = 0,
707
+ right_context: int = 0,
708
+ ) -> Tuple[Tensor, List[Tensor]]:
709
+ r"""Pass the input through the encoder layers in turn.
710
+
711
+ Args:
712
+ src: the sequence to the encoder (required).
713
+ pos_emb: Positional embedding tensor (required).
714
+ states:
715
+ The decode states for previous frames which contains the cached data.
716
+ It has two elements, the first element is the attn_cache which has
717
+ a shape of (encoder_layers, left_context, batch, attention_dim),
718
+ the second element is the conv_cache which has a shape of
719
+ (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
720
+ Note: states will be modified in this function.
721
+ mask: the mask for the src sequence (optional).
722
+ src_key_padding_mask: the mask for the src keys per batch (optional).
723
+ warmup: controls selective bypass of layers; if < 1.0, we will
724
+ bypass layers more frequently.
725
+ left_context:
726
+ How many previous frames the attention can see in current chunk.
727
+ Note: It's not that each individual frame has `left_context` frames
728
+ of left context, some have more.
729
+ right_context:
730
+ How many future frames the attention can see in current chunk.
731
+ Note: It's not that each individual frame has `right_context` frames
732
+ of right context, some have more.
733
+ Shape:
734
+ src: (S, N, E).
735
+ pos_emb: (N, 2*(S+left_context)-1, E).
736
+ mask: (S, S).
737
+ src_key_padding_mask: (N, S).
738
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
739
+
740
+ """
741
+ assert not self.training
742
+ assert len(states) == 2
743
+ assert states[0].shape == (
744
+ self.num_layers,
745
+ left_context,
746
+ src.size(1),
747
+ src.size(2),
748
+ )
749
+ assert states[1].size(0) == self.num_layers
750
+
751
+ output = src
752
+
753
+ for layer_index, mod in enumerate(self.layers):
754
+ cache = [states[0][layer_index], states[1][layer_index]]
755
+ output, cache = mod.chunk_forward(
756
+ output,
757
+ pos_emb,
758
+ states=cache,
759
+ src_mask=mask,
760
+ src_key_padding_mask=src_key_padding_mask,
761
+ warmup=warmup,
762
+ left_context=left_context,
763
+ right_context=right_context,
764
+ )
765
+ states[0][layer_index] = cache[0]
766
+ states[1][layer_index] = cache[1]
767
+
768
+ return output, states
769
+
770
+
771
+ class RelPositionalEncoding(torch.nn.Module):
772
+ """Relative positional encoding module.
773
+
774
+ See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
775
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
776
+
777
+ Args:
778
+ d_model: Embedding dimension.
779
+ dropout_rate: Dropout rate.
780
+ max_len: Maximum input length.
781
+
782
+ """
783
+
784
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
785
+ """Construct an PositionalEncoding object."""
786
+ super(RelPositionalEncoding, self).__init__()
787
+ if is_jit_tracing():
788
+ # 10k frames correspond to ~100k ms, i.e., about 100 seconds.
789
+ # It assumes that the maximum input won't have more than
790
+ # 10k frames.
791
+ #
792
+ # TODO(fangjun): Use torch.jit.script() for this module
793
+ max_len = 10000
794
+
795
+ self.d_model = d_model
796
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
797
+ self.pe = None
798
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
799
+
800
+ def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
801
+ """Reset the positional encodings."""
802
+ x_size_1 = x.size(1) + left_context
803
+ if self.pe is not None:
804
+ # self.pe contains both positive and negative parts
805
+ # the length of self.pe is 2 * input_len - 1
806
+ if self.pe.size(1) >= x_size_1 * 2 - 1:
807
+ # Note: TorchScript doesn't implement operator== for torch.Device
808
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
809
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
810
+ return
811
+ # Suppose `i` means the position of the query vector and `j` means the
812
+ # position of the key vector. We use positive relative positions when keys
813
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
814
+ pe_positive = torch.zeros(x_size_1, self.d_model)
815
+ pe_negative = torch.zeros(x_size_1, self.d_model)
816
+ position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
817
+ div_term = torch.exp(
818
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
819
+ * -(math.log(10000.0) / self.d_model)
820
+ )
821
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
822
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
823
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
824
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
825
+
826
+ # Reverse the order of positive indices and concat both positive and
827
+ # negative indices. This is used to support the shifting trick
828
+ # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
829
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
830
+ pe_negative = pe_negative[1:].unsqueeze(0)
831
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
832
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
833
+
834
+ def forward(
835
+ self,
836
+ x: torch.Tensor,
837
+ left_context: int = 0,
838
+ ) -> Tuple[Tensor, Tensor]:
839
+ """Add positional encoding.
840
+
841
+ Args:
842
+ x (torch.Tensor): Input tensor (batch, time, `*`).
843
+ left_context (int): left context (in frames) used during streaming decoding.
844
+ this is used only in real streaming decoding, in other circumstances,
845
+ it MUST be 0.
846
+
847
+ Returns:
848
+ torch.Tensor: Encoded tensor (batch, time, `*`).
849
+ torch.Tensor: Positional embedding tensor (batch, 2*time-1, `*`).
850
+
851
+ """
852
+ self.extend_pe(x, left_context)
853
+ x_size_1 = x.size(1) + left_context
854
+ pos_emb = self.pe[
855
+ :,
856
+ self.pe.size(1) // 2
857
+ - x_size_1
858
+ + 1 : self.pe.size(1) // 2 # noqa E203
859
+ + x.size(1),
860
+ ]
861
+ return self.dropout(x), self.dropout(pos_emb)
862
+
863
+
864
+ class RelPositionMultiheadAttention(nn.Module):
865
+ r"""Multi-Head Attention layer with relative position encoding
866
+
867
+ See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
868
+
869
+ Args:
870
+ embed_dim: total dimension of the model.
871
+ num_heads: parallel attention heads.
872
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
873
+
874
+ Examples::
875
+
876
+ >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
877
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
878
+ """
879
+
880
+ def __init__(
881
+ self,
882
+ embed_dim: int,
883
+ num_heads: int,
884
+ dropout: float = 0.0,
885
+ ) -> None:
886
+ super(RelPositionMultiheadAttention, self).__init__()
887
+ self.embed_dim = embed_dim
888
+ self.num_heads = num_heads
889
+ self.dropout = dropout
890
+ self.head_dim = embed_dim // num_heads
891
+ assert (
892
+ self.head_dim * num_heads == self.embed_dim
893
+ ), "embed_dim must be divisible by num_heads"
894
+
895
+ self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
896
+ self.out_proj = ScaledLinear(
897
+ embed_dim, embed_dim, bias=True, initial_scale=0.25
898
+ )
899
+
900
+ # linear transformation for positional encoding.
901
+ self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
902
+ # these two learnable bias are used in matrix c and matrix d
903
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
904
+ self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
905
+ self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
906
+ self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
907
+ self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
908
+ self._reset_parameters()
909
+
910
+ def _pos_bias_u(self):
911
+ return self.pos_bias_u * self.pos_bias_u_scale.exp()
912
+
913
+ def _pos_bias_v(self):
914
+ return self.pos_bias_v * self.pos_bias_v_scale.exp()
915
+
916
+ def _reset_parameters(self) -> None:
917
+ nn.init.normal_(self.pos_bias_u, std=0.01)
918
+ nn.init.normal_(self.pos_bias_v, std=0.01)
919
+
920
+ def forward(
921
+ self,
922
+ query: Tensor,
923
+ key: Tensor,
924
+ value: Tensor,
925
+ pos_emb: Tensor,
926
+ key_padding_mask: Optional[Tensor] = None,
927
+ need_weights: bool = False,
928
+ attn_mask: Optional[Tensor] = None,
929
+ left_context: int = 0,
930
+ ) -> Tuple[Tensor, Optional[Tensor]]:
931
+ r"""
932
+ Args:
933
+ query, key, value: map a query and a set of key-value pairs to an output.
934
+ pos_emb: Positional embedding tensor
935
+ key_padding_mask: if provided, specified padding elements in the key will
936
+ be ignored by the attention. When given a binary mask and a value is True,
937
+ the corresponding value on the attention layer will be ignored. When given
938
+ a byte mask and a value is non-zero, the corresponding value on the attention
939
+ layer will be ignored
940
+ need_weights: output attn_output_weights.
941
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
942
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
943
+ left_context (int): left context (in frames) used during streaming decoding.
944
+ this is used only in real streaming decoding, in other circumstances,
945
+ it MUST be 0.
946
+
947
+ Shape:
948
+ - Inputs:
949
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
950
+ the embedding dimension.
951
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
952
+ the embedding dimension.
953
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
954
+ the embedding dimension.
955
+ - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
956
+ the embedding dimension.
957
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
958
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero
959
+ positions will be unchanged. If a BoolTensor is provided, the positions with the
960
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
961
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
962
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
963
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
964
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
965
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
966
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
967
+ is provided, it will be added to the attention weight.
968
+
969
+ - Outputs:
970
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
971
+ E is the embedding dimension.
972
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
973
+ L is the target sequence length, S is the source sequence length.
974
+ """
975
+ return self.multi_head_attention_forward(
976
+ query,
977
+ key,
978
+ value,
979
+ pos_emb,
980
+ self.embed_dim,
981
+ self.num_heads,
982
+ self.in_proj.get_weight(),
983
+ self.in_proj.get_bias(),
984
+ self.dropout,
985
+ self.out_proj.get_weight(),
986
+ self.out_proj.get_bias(),
987
+ training=self.training,
988
+ key_padding_mask=key_padding_mask,
989
+ need_weights=need_weights,
990
+ attn_mask=attn_mask,
991
+ left_context=left_context,
992
+ )
993
+
994
+ def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
995
+ """Compute relative positional encoding.
996
+
997
+ Args:
998
+ x: Input tensor (batch, head, time1, 2*time1-1+left_context).
999
+ time1 means the length of query vector.
1000
+ left_context (int): left context (in frames) used during streaming decoding.
1001
+ this is used only in real streaming decoding, in other circumstances,
1002
+ it MUST be 0.
1003
+
1004
+ Returns:
1005
+ Tensor: tensor of shape (batch, head, time1, time2)
1006
+ (note: time2 has the same value as time1, but it is for
1007
+ the key, while time1 is for the query).
1008
+ """
1009
+ (batch_size, num_heads, time1, n) = x.shape
1010
+
1011
+ time2 = time1 + left_context
1012
+ if not is_jit_tracing():
1013
+ assert (
1014
+ n == left_context + 2 * time1 - 1
1015
+ ), f"{n} == {left_context} + 2 * {time1} - 1"
1016
+
1017
+ if is_jit_tracing():
1018
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
1019
+ cols = torch.arange(time2)
1020
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
1021
+ indexes = rows + cols
1022
+
1023
+ x = x.reshape(-1, n)
1024
+ x = torch.gather(x, dim=1, index=indexes)
1025
+ x = x.reshape(batch_size, num_heads, time1, time2)
1026
+ return x
1027
+ else:
1028
+ # Note: TorchScript requires explicit arg for stride()
1029
+ batch_stride = x.stride(0)
1030
+ head_stride = x.stride(1)
1031
+ time1_stride = x.stride(2)
1032
+ n_stride = x.stride(3)
1033
+ return x.as_strided(
1034
+ (batch_size, num_heads, time1, time2),
1035
+ (batch_stride, head_stride, time1_stride - n_stride, n_stride),
1036
+ storage_offset=n_stride * (time1 - 1),
1037
+ )
1038
+
1039
+ def multi_head_attention_forward(
1040
+ self,
1041
+ query: Tensor,
1042
+ key: Tensor,
1043
+ value: Tensor,
1044
+ pos_emb: Tensor,
1045
+ embed_dim_to_check: int,
1046
+ num_heads: int,
1047
+ in_proj_weight: Tensor,
1048
+ in_proj_bias: Tensor,
1049
+ dropout_p: float,
1050
+ out_proj_weight: Tensor,
1051
+ out_proj_bias: Tensor,
1052
+ training: bool = True,
1053
+ key_padding_mask: Optional[Tensor] = None,
1054
+ need_weights: bool = False,
1055
+ attn_mask: Optional[Tensor] = None,
1056
+ left_context: int = 0,
1057
+ ) -> Tuple[Tensor, Optional[Tensor]]:
1058
+ r"""
1059
+ Args:
1060
+ query, key, value: map a query and a set of key-value pairs to an output.
1061
+ pos_emb: Positional embedding tensor
1062
+ embed_dim_to_check: total dimension of the model.
1063
+ num_heads: parallel attention heads.
1064
+ in_proj_weight, in_proj_bias: input projection weight and bias.
1065
+ dropout_p: probability of an element to be zeroed.
1066
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
1067
+ training: apply dropout if is ``True``.
1068
+ key_padding_mask: if provided, specified padding elements in the key will
1069
+ be ignored by the attention. This is a binary mask. When the value is True,
1070
+ the corresponding value on the attention layer will be filled with -inf.
1071
+ need_weights: output attn_output_weights.
1072
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
1073
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
1074
+ left_context (int): left context (in frames) used during streaming decoding.
1075
+ this is used only in real streaming decoding, in other circumstances,
1076
+ it MUST be 0.
1077
+
1078
+ Shape:
1079
+ Inputs:
1080
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
1081
+ the embedding dimension.
1082
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
1083
+ the embedding dimension.
1084
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
1085
+ the embedding dimension.
1086
+ - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
1087
+ length, N is the batch size, E is the embedding dimension.
1088
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
1089
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
1090
+ will be unchanged. If a BoolTensor is provided, the positions with the
1091
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
1092
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
1093
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
1094
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
1095
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
1096
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
1097
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
1098
+ is provided, it will be added to the attention weight.
1099
+
1100
+ Outputs:
1101
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
1102
+ E is the embedding dimension.
1103
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
1104
+ L is the target sequence length, S is the source sequence length.
1105
+ """
1106
+
1107
+ tgt_len, bsz, embed_dim = query.size()
1108
+ if not is_jit_tracing():
1109
+ assert embed_dim == embed_dim_to_check
1110
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
1111
+
1112
+ head_dim = embed_dim // num_heads
1113
+ if not is_jit_tracing():
1114
+ assert (
1115
+ head_dim * num_heads == embed_dim
1116
+ ), "embed_dim must be divisible by num_heads"
1117
+
1118
+ scaling = float(head_dim) ** -0.5
1119
+
1120
+ if torch.equal(query, key) and torch.equal(key, value):
1121
+ # self-attention
1122
+ q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
1123
+ 3, dim=-1
1124
+ )
1125
+
1126
+ elif torch.equal(key, value):
1127
+ # encoder-decoder attention
1128
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
1129
+ _b = in_proj_bias
1130
+ _start = 0
1131
+ _end = embed_dim
1132
+ _w = in_proj_weight[_start:_end, :]
1133
+ if _b is not None:
1134
+ _b = _b[_start:_end]
1135
+ q = nn.functional.linear(query, _w, _b)
1136
+
1137
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
1138
+ _b = in_proj_bias
1139
+ _start = embed_dim
1140
+ _end = None
1141
+ _w = in_proj_weight[_start:, :]
1142
+ if _b is not None:
1143
+ _b = _b[_start:]
1144
+ k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
1145
+
1146
+ else:
1147
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
1148
+ _b = in_proj_bias
1149
+ _start = 0
1150
+ _end = embed_dim
1151
+ _w = in_proj_weight[_start:_end, :]
1152
+ if _b is not None:
1153
+ _b = _b[_start:_end]
1154
+ q = nn.functional.linear(query, _w, _b)
1155
+
1156
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
1157
+ _b = in_proj_bias
1158
+ _start = embed_dim
1159
+ _end = embed_dim * 2
1160
+ _w = in_proj_weight[_start:_end, :]
1161
+ if _b is not None:
1162
+ _b = _b[_start:_end]
1163
+ k = nn.functional.linear(key, _w, _b)
1164
+
1165
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
1166
+ _b = in_proj_bias
1167
+ _start = embed_dim * 2
1168
+ _end = None
1169
+ _w = in_proj_weight[_start:, :]
1170
+ if _b is not None:
1171
+ _b = _b[_start:]
1172
+ v = nn.functional.linear(value, _w, _b)
1173
+
1174
+ if attn_mask is not None:
1175
+ assert (
1176
+ attn_mask.dtype == torch.float32
1177
+ or attn_mask.dtype == torch.float64
1178
+ or attn_mask.dtype == torch.float16
1179
+ or attn_mask.dtype == torch.uint8
1180
+ or attn_mask.dtype == torch.bool
1181
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
1182
+ attn_mask.dtype
1183
+ )
1184
+ if attn_mask.dtype == torch.uint8:
1185
+ warnings.warn(
1186
+ "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
1187
+ )
1188
+ attn_mask = attn_mask.to(torch.bool)
1189
+
1190
+ if attn_mask.dim() == 2:
1191
+ attn_mask = attn_mask.unsqueeze(0)
1192
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
1193
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
1194
+ elif attn_mask.dim() == 3:
1195
+ if list(attn_mask.size()) != [
1196
+ bsz * num_heads,
1197
+ query.size(0),
1198
+ key.size(0),
1199
+ ]:
1200
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
1201
+ else:
1202
+ raise RuntimeError(
1203
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
1204
+ )
1205
+ # attn_mask's dim is 3 now.
1206
+
1207
+ # convert ByteTensor key_padding_mask to bool
1208
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
1209
+ warnings.warn(
1210
+ "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
1211
+ )
1212
+ key_padding_mask = key_padding_mask.to(torch.bool)
1213
+
1214
+ q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
1215
+ k = k.contiguous().view(-1, bsz, num_heads, head_dim)
1216
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
1217
+
1218
+ src_len = k.size(0)
1219
+
1220
+ if key_padding_mask is not None and not is_jit_tracing():
1221
+ assert key_padding_mask.size(0) == bsz, "{} == {}".format(
1222
+ key_padding_mask.size(0), bsz
1223
+ )
1224
+ assert key_padding_mask.size(1) == src_len, "{} == {}".format(
1225
+ key_padding_mask.size(1), src_len
1226
+ )
1227
+
1228
+ q = q.transpose(0, 1) # (batch, time1, head, d_k)
1229
+
1230
+ pos_emb_bsz = pos_emb.size(0)
1231
+ if not is_jit_tracing():
1232
+ assert pos_emb_bsz in (1, bsz) # actually it is 1
1233
+
1234
+ p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
1235
+ # (batch, 2*time1-1, head, d_k) --> (batch, head, d_k, 2*time1-1)
1236
+ p = p.permute(0, 2, 3, 1)
1237
+
1238
+ q_with_bias_u = (q + self._pos_bias_u()).transpose(
1239
+ 1, 2
1240
+ ) # (batch, head, time1, d_k)
1241
+
1242
+ q_with_bias_v = (q + self._pos_bias_v()).transpose(
1243
+ 1, 2
1244
+ ) # (batch, head, time1, d_k)
1245
+
1246
+ # compute attention score
1247
+ # first compute matrix a and matrix c
1248
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
1249
+ k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
1250
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
1251
+
1252
+ # compute matrix b and matrix d
1253
+ matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1)
1254
+ matrix_bd = self.rel_shift(matrix_bd, left_context)
1255
+
1256
+ attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2)
1257
+
1258
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
1259
+
1260
+ if not is_jit_tracing():
1261
+ assert list(attn_output_weights.size()) == [
1262
+ bsz * num_heads,
1263
+ tgt_len,
1264
+ src_len,
1265
+ ]
1266
+
1267
+ if attn_mask is not None:
1268
+ if attn_mask.dtype == torch.bool:
1269
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
1270
+ else:
1271
+ attn_output_weights += attn_mask
1272
+
1273
+ if key_padding_mask is not None:
1274
+ attn_output_weights = attn_output_weights.view(
1275
+ bsz, num_heads, tgt_len, src_len
1276
+ )
1277
+ attn_output_weights = attn_output_weights.masked_fill(
1278
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
1279
+ float("-inf"),
1280
+ )
1281
+ attn_output_weights = attn_output_weights.view(
1282
+ bsz * num_heads, tgt_len, src_len
1283
+ )
1284
+
1285
+ attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
1286
+
1287
+ # If we are using dynamic_chunk_training and setting a limited
1288
+ # num_left_chunks, the attention may only see the padding values which
1289
+ # will also be masked out by `key_padding_mask`; in such circumstances,
1290
+ # the whole column of `attn_output_weights` will be `-inf`
1291
+ # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
1292
+ # positions to avoid invalid loss value below.
1293
+ if (
1294
+ attn_mask is not None
1295
+ and attn_mask.dtype == torch.bool
1296
+ and key_padding_mask is not None
1297
+ ):
1298
+ if attn_mask.size(0) != 1:
1299
+ attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
1300
+ combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
1301
+ else:
1302
+ # attn_mask.shape == (1, tgt_len, src_len)
1303
+ combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
1304
+ 1
1305
+ ).unsqueeze(2)
1306
+
1307
+ attn_output_weights = attn_output_weights.view(
1308
+ bsz, num_heads, tgt_len, src_len
1309
+ )
1310
+ attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
1311
+ attn_output_weights = attn_output_weights.view(
1312
+ bsz * num_heads, tgt_len, src_len
1313
+ )
1314
+
1315
+ attn_output_weights = nn.functional.dropout(
1316
+ attn_output_weights, p=dropout_p, training=training
1317
+ )
1318
+
1319
+ attn_output = torch.bmm(attn_output_weights, v)
1320
+
1321
+ if not is_jit_tracing():
1322
+ assert list(attn_output.size()) == [
1323
+ bsz * num_heads,
1324
+ tgt_len,
1325
+ head_dim,
1326
+ ]
1327
+
1328
+ attn_output = (
1329
+ attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
1330
+ )
1331
+ attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
1332
+
1333
+ if need_weights:
1334
+ # average attention weights over heads
1335
+ attn_output_weights = attn_output_weights.view(
1336
+ bsz, num_heads, tgt_len, src_len
1337
+ )
1338
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
1339
+ else:
1340
+ return attn_output, None
1341
+
1342
+
1343
+ class ConvolutionModule(nn.Module):
1344
+ """ConvolutionModule in Conformer model.
1345
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
1346
+
1347
+ Args:
1348
+ channels (int): The number of channels of conv layers.
1349
+ kernel_size (int): Kernel size of conv layers.
1350
+ bias (bool): Whether to use bias in conv layers (default=True).
1351
+ causal (bool): Whether to use causal convolution.
1352
+ """
1353
+
1354
+ def __init__(
1355
+ self,
1356
+ channels: int,
1357
+ kernel_size: int,
1358
+ bias: bool = True,
1359
+ causal: bool = False,
1360
+ ) -> None:
1361
+ """Construct an ConvolutionModule object."""
1362
+ super(ConvolutionModule, self).__init__()
1363
+ # kernel_size should be an odd number for 'SAME' padding
1364
+ assert (kernel_size - 1) % 2 == 0
1365
+ self.causal = causal
1366
+
1367
+ self.pointwise_conv1 = ScaledConv1d(
1368
+ channels,
1369
+ 2 * channels,
1370
+ kernel_size=1,
1371
+ stride=1,
1372
+ padding=0,
1373
+ bias=bias,
1374
+ )
1375
+
1376
+ # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
1377
+ # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
1378
+ # but sometimes, for some reason, for layer 0 the rms ends up being very large,
1379
+ # between 50 and 100 for different channels. This will cause very peaky and
1380
+ # sparse derivatives for the sigmoid gating function, which will tend to make
1381
+ # the loss function not learn effectively. (for most layers the average absolute values
1382
+ # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
1383
+ # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
1384
+ # layers, which likely breaks down as 0.5 for the "linear" half and
1385
+ # 0.2 to 0.3 for the part that goes into the sigmoid.) The idea is that if we
1386
+ # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
1387
+ # it will be in a better position to start learning something, i.e. to latch onto
1388
+ # the correct range.
1389
+ self.deriv_balancer1 = ActivationBalancer(
1390
+ channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
1391
+ )
1392
+
1393
+ self.lorder = kernel_size - 1
1394
+ padding = (kernel_size - 1) // 2
1395
+ if self.causal:
1396
+ padding = 0
1397
+
1398
+ self.depthwise_conv = ScaledConv1d(
1399
+ channels,
1400
+ channels,
1401
+ kernel_size,
1402
+ stride=1,
1403
+ padding=padding,
1404
+ groups=channels,
1405
+ bias=bias,
1406
+ )
1407
+
1408
+ self.deriv_balancer2 = ActivationBalancer(
1409
+ channel_dim=1, min_positive=0.05, max_positive=1.0
1410
+ )
1411
+
1412
+ self.activation = DoubleSwish()
1413
+
1414
+ self.pointwise_conv2 = ScaledConv1d(
1415
+ channels,
1416
+ channels,
1417
+ kernel_size=1,
1418
+ stride=1,
1419
+ padding=0,
1420
+ bias=bias,
1421
+ initial_scale=0.25,
1422
+ )
1423
+
1424
+ def forward(
1425
+ self,
1426
+ x: Tensor,
1427
+ cache: Optional[Tensor] = None,
1428
+ right_context: int = 0,
1429
+ src_key_padding_mask: Optional[Tensor] = None,
1430
+ ) -> Tuple[Tensor, Tensor]:
1431
+ """Compute convolution module.
1432
+
1433
+ Args:
1434
+ x: Input tensor (#time, batch, channels).
1435
+ cache: The cache of depthwise_conv, only used in real streaming
1436
+ decoding.
1437
+ right_context:
1438
+ How many future frames the attention can see in current chunk.
1439
+ Note: It's not that each individual frame has `right_context` frames
1440
+ of right context, some have more.
1441
+ src_key_padding_mask: the mask for the src keys per batch (optional).
1442
+
1443
+ Returns:
1444
+ If cache is None return the output tensor (#time, batch, channels).
1445
+ If cache is not None, return a tuple of Tensor, the first one is
1446
+ the output tensor (#time, batch, channels), the second one is the
1447
+ new cache for next chunk (#kernel_size - 1, batch, channels).
1448
+
1449
+ """
1450
+ # exchange the temporal dimension and the feature dimension
1451
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
1452
+
1453
+ # GLU mechanism
1454
+ x = self.pointwise_conv1(x) # (batch, 2*channels, time)
1455
+
1456
+ x = self.deriv_balancer1(x)
1457
+ x = nn.functional.glu(x, dim=1) # (batch, channels, time)
1458
+
1459
+ # 1D Depthwise Conv
1460
+ if src_key_padding_mask is not None:
1461
+ x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
1462
+ if self.causal and self.lorder > 0:
1463
+ if cache is None:
1464
+ # Make depthwise_conv causal by
1465
+ # manually padding self.lorder zeros to the left
1466
+ x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
1467
+ else:
1468
+ assert not self.training, "Cache should be None in training time"
1469
+ assert cache.size(0) == self.lorder
1470
+ x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
1471
+ if right_context > 0:
1472
+ cache = x.permute(2, 0, 1)[
1473
+ -(self.lorder + right_context) : (-right_context), # noqa
1474
+ ...,
1475
+ ]
1476
+ else:
1477
+ cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
1478
+ x = self.depthwise_conv(x)
1479
+
1480
+ x = self.deriv_balancer2(x)
1481
+ x = self.activation(x)
1482
+
1483
+ x = self.pointwise_conv2(x) # (batch, channel, time)
1484
+
1485
+ # torch.jit.script requires return types be the same as annotated above
1486
+ if cache is None:
1487
+ cache = torch.empty(0)
1488
+
1489
+ return x.permute(2, 0, 1), cache
1490
+
1491
+
1492
+ class Conv2dSubsampling(nn.Module):
1493
+ """Convolutional 2D subsampling (to 1/4 length).
1494
+
1495
+ Convert an input of shape (N, T, idim) to an output
1496
+ with shape (N, T', odim), where
1497
+ T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
1498
+
1499
+ It is based on
1500
+ https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
1501
+ """
1502
+
1503
+ def __init__(
1504
+ self,
1505
+ in_channels: int,
1506
+ out_channels: int,
1507
+ layer1_channels: int = 8,
1508
+ layer2_channels: int = 32,
1509
+ layer3_channels: int = 128,
1510
+ ) -> None:
1511
+ """
1512
+ Args:
1513
+ in_channels:
1514
+ Number of channels in. The input shape is (N, T, in_channels).
1515
+ Caution: It requires: T >=7, in_channels >=7
1516
+ out_channels:
1517
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
1518
+ layer1_channels:
1519
+ Number of channels in layer1
1520
+ layer2_channels:
1521
+ Number of channels in layer2
1522
+ """
1523
+ assert in_channels >= 7
1524
+ super().__init__()
1525
+
1526
+ self.conv = nn.Sequential(
1527
+ ScaledConv2d(
1528
+ in_channels=1,
1529
+ out_channels=layer1_channels,
1530
+ kernel_size=3,
1531
+ padding=1,
1532
+ ),
1533
+ ActivationBalancer(channel_dim=1),
1534
+ DoubleSwish(),
1535
+ ScaledConv2d(
1536
+ in_channels=layer1_channels,
1537
+ out_channels=layer2_channels,
1538
+ kernel_size=3,
1539
+ stride=2,
1540
+ ),
1541
+ ActivationBalancer(channel_dim=1),
1542
+ DoubleSwish(),
1543
+ ScaledConv2d(
1544
+ in_channels=layer2_channels,
1545
+ out_channels=layer3_channels,
1546
+ kernel_size=3,
1547
+ stride=2,
1548
+ ),
1549
+ ActivationBalancer(channel_dim=1),
1550
+ DoubleSwish(),
1551
+ )
1552
+ self.out = ScaledLinear(
1553
+ layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
1554
+ )
1555
+ # set learn_eps=False because out_norm is preceded by `out`, and `out`
1556
+ # itself has learned scale, so the extra degree of freedom is not
1557
+ # needed.
1558
+ self.out_norm = BasicNorm(out_channels, learn_eps=False)
1559
+ # constrain median of output to be close to zero.
1560
+ self.out_balancer = ActivationBalancer(
1561
+ channel_dim=-1, min_positive=0.45, max_positive=0.55
1562
+ )
1563
+
1564
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1565
+ """Subsample x.
1566
+
1567
+ Args:
1568
+ x:
1569
+ Its shape is (N, T, idim).
1570
+
1571
+ Returns:
1572
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
1573
+ """
1574
+ # On entry, x is (N, T, idim)
1575
+ x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
1576
+ x = self.conv(x)
1577
+ # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
1578
+ b, c, t, f = x.size()
1579
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
1580
+ # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
1581
+ x = self.out_norm(x)
1582
+ x = self.out_balancer(x)
1583
+ return x
1584
+
1585
+
1586
+ if __name__ == "__main__":
1587
+ torch.set_num_threads(1)
1588
+ torch.set_num_interop_threads(1)
1589
+ feature_dim = 50
1590
+ c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
1591
+ batch_size = 5
1592
+ seq_len = 20
1593
+ # Just make sure the forward pass runs.
1594
+ f = c(
1595
+ torch.randn(batch_size, seq_len, feature_dim),
1596
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
1597
+ warmup=0.5,
1598
+ )
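The smoke test above only checks that the forward pass runs. As a side note on the subsampling arithmetic used throughout this file, the docstring claim T' = ((T-1)//2 - 1)//2 for Conv2dSubsampling follows from one padded kernel-3 convolution followed by two unpadded kernel-3, stride-2 convolutions. Below is a minimal sketch that checks this formula; it uses plain torch.nn.Conv2d instead of the ScaledConv2d defined above, and the channel sizes and feature dimension are illustrative only.

```python
import torch
import torch.nn as nn

# Mirror the time-axis layout of Conv2dSubsampling: one kernel-3 conv with
# padding=1 keeps T, then two kernel-3, stride-2 convs without padding each
# roughly halve it, giving T' = ((T - 1) // 2 - 1) // 2.
# Activations and balancers are omitted since they do not change the shape.
conv = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.Conv2d(8, 32, kernel_size=3, stride=2),
    nn.Conv2d(32, 128, kernel_size=3, stride=2),
)

for T in (7, 20, 100, 1003):
    x = torch.randn(1, 1, T, 80)  # (N, C=1, T, idim); idim=80 is illustrative
    t_out = conv(x).shape[2]
    assert t_out == ((T - 1) // 2 - 1) // 2, (T, t_out)
```

This factor-of-4 reduction is also why the decoding code below divides frame indices by params.subsampling_factor when building supervision segments and converting token positions back to timestamps.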
err2020/conformer_ctc3/decode.py ADDED
@@ -0,0 +1,1052 @@
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
4
+ # Zengwei Yao)
5
+ #
6
+ # See ../../../../LICENSE for clarification regarding multiple authors
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """
20
+ Usage:
21
+ (1) decode in non-streaming mode (take ctc-decoding as an example)
22
+ ./conformer_ctc3/decode.py \
23
+ --epoch 30 \
24
+ --avg 15 \
25
+ --exp-dir ./conformer_ctc3/exp \
26
+ --max-duration 600 \
27
+ --decoding-method ctc-decoding
28
+
29
+ (2) decode in streaming mode (take ctc-decoding as an example)
30
+ ./conformer_ctc3/decode.py \
31
+ --epoch 30 \
32
+ --avg 15 \
33
+ --simulate-streaming 1 \
34
+ --causal-convolution 1 \
35
+ --decode-chunk-size 16 \
36
+ --left-context 64 \
37
+ --exp-dir ./conformer_ctc3/exp \
38
+ --max-duration 600 \
39
+ --decoding-method ctc-decoding
40
+
41
+ To evaluate symbol delay, you should:
42
+ (1) Generate cuts with word-time alignments:
43
+ ./add_alignments.sh
44
+ (2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
45
+ For example:
46
+ ./conformer_ctc3/decode.py \
47
+ --epoch 30 \
48
+ --avg 15 \
49
+ --exp-dir ./conformer_ctc3/exp \
50
+ --max-duration 600 \
51
+ --decoding-method ctc-decoding \
52
+ --simulate-streaming 1 \
53
+ --causal-convolution 1 \
54
+ --decode-chunk-size 16 \
55
+ --left-context 64 \
56
+ --manifest-dir data/fbank_ali
57
+ Note: It supports calculating symbol delay with following decoding methods:
58
+ - ctc-decoding
59
+ - 1best
60
+ """
61
+
62
+
63
+ import argparse
64
+ import logging
65
+ import math
66
+ from collections import defaultdict
67
+ from pathlib import Path
68
+ from typing import Dict, List, Optional, Tuple
69
+
70
+ import k2
71
+ import sentencepiece as spm
72
+ import torch
73
+ import torch.nn as nn
74
+ from asr_datamodule import LibriSpeechAsrDataModule
75
+ from train import add_model_arguments, get_ctc_model, get_params
76
+
77
+ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
78
+ from icefall.checkpoint import (
79
+ average_checkpoints,
80
+ average_checkpoints_with_averaged_model,
81
+ find_checkpoints,
82
+ load_checkpoint,
83
+ )
84
+ from icefall.decode import (
85
+ get_lattice,
86
+ nbest_decoding,
87
+ nbest_oracle,
88
+ one_best_decoding,
89
+ rescore_with_n_best_list,
90
+ rescore_with_whole_lattice,
91
+ )
92
+ from icefall.lexicon import Lexicon
93
+ from icefall.utils import (
94
+ AttributeDict,
95
+ convert_timestamp,
96
+ get_texts,
97
+ make_pad_mask,
98
+ parse_bpe_start_end_pairs,
99
+ parse_fsa_timestamps_and_texts,
100
+ setup_logger,
101
+ store_transcripts_and_timestamps,
102
+ str2bool,
103
+ write_error_stats_with_timestamps,
104
+ )
105
+
106
+ LOG_EPS = math.log(1e-10)
107
+
108
+
109
+ def get_parser():
110
+ parser = argparse.ArgumentParser(
111
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
112
+ )
113
+
114
+ parser.add_argument(
115
+ "--epoch",
116
+ type=int,
117
+ default=30,
118
+ help="""It specifies the checkpoint to use for decoding.
119
+ Note: Epoch counts from 1.
120
+ You can specify --avg to use more checkpoints for model averaging.""",
121
+ )
122
+
123
+ parser.add_argument(
124
+ "--iter",
125
+ type=int,
126
+ default=0,
127
+ help="""If positive, --epoch is ignored and it
128
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
129
+ You can specify --avg to use more checkpoints for model averaging.
130
+ """,
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--avg",
135
+ type=int,
136
+ default=15,
137
+ help="Number of checkpoints to average. Automatically select "
138
+ "consecutive checkpoints before the checkpoint specified by "
139
+ "'--epoch' and '--iter'",
140
+ )
141
+
142
+ parser.add_argument(
143
+ "--use-averaged-model",
144
+ type=str2bool,
145
+ default=True,
146
+ help="Whether to load averaged model. Currently it only supports "
147
+ "using --epoch. If True, it would decode with the averaged model "
148
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
149
+ "Actually only the models with epoch number of `epoch-avg` and "
150
+ "`epoch` are loaded for averaging. ",
151
+ )
152
+
153
+ parser.add_argument(
154
+ "--exp-dir",
155
+ type=str,
156
+ default="pruned_transducer_stateless4/exp",
157
+ help="The experiment dir",
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--lang-dir",
162
+ type=Path,
163
+ default="data/lang_bpe_500",
164
+ help="The lang dir containing word table and LG graph",
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--decoding-method",
169
+ type=str,
170
+ default="ctc-decoding",
171
+ help="""Decoding method.
172
+ Supported values are:
173
+ - (0) ctc-greedy-search. It uses a sentence piece model,
174
+ i.e., lang_dir/bpe.model, to convert word pieces to words.
175
+ It needs neither a lexicon nor an n-gram LM.
176
+ - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
177
+ model, i.e., lang_dir/bpe.model, to convert word pieces to words.
178
+ It needs neither a lexicon nor an n-gram LM.
179
+ - (2) 1best. Extract the best path from the decoding lattice as the
180
+ decoding result.
181
+ - (3) nbest. Extract n paths from the decoding lattice; the path
182
+ with the highest score is the decoding result.
183
+ - (4) nbest-rescoring. Extract n paths from the decoding lattice,
184
+ rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
185
+ the highest score is the decoding result.
186
+ - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
187
+ n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
188
+ is the decoding result.
189
+ you have trained an RNN LM using ./rnn_lm/train.py
190
+ - (6) nbest-oracle. Its WER is the lower bound of any n-best
191
+ rescoring method can achieve. Useful for debugging n-best
192
+ rescoring method.
193
+ """,
194
+ )
195
+
196
+ parser.add_argument(
197
+ "--num-paths",
198
+ type=int,
199
+ default=100,
200
+ help="""Number of paths for n-best based decoding method.
201
+ Used only when "method" is one of the following values:
202
+ nbest, nbest-rescoring, and nbest-oracle
203
+ """,
204
+ )
205
+
206
+ parser.add_argument(
207
+ "--nbest-scale",
208
+ type=float,
209
+ default=0.5,
210
+ help="""The scale to be applied to `lattice.scores`.
211
+ It's needed if you use any kinds of n-best based rescoring.
212
+ Used only when "method" is one of the following values:
213
+ nbest, nbest-rescoring, and nbest-oracle
214
+ A smaller value results in more unique paths.
215
+ """,
216
+ )
217
+
218
+ parser.add_argument(
219
+ "--lm-dir",
220
+ type=str,
221
+ default="data/lm",
222
+ help="""The n-gram LM dir.
223
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt
224
+ """,
225
+ )
226
+
227
+ parser.add_argument(
228
+ "--simulate-streaming",
229
+ type=str2bool,
230
+ default=False,
231
+ help="""Whether to simulate streaming in decoding, this is a good way to
232
+ test a streaming model.
233
+ """,
234
+ )
235
+
236
+ parser.add_argument(
237
+ "--decode-chunk-size",
238
+ type=int,
239
+ default=16,
240
+ help="The chunk size for decoding (in frames after subsampling)",
241
+ )
242
+
243
+ parser.add_argument(
244
+ "--left-context",
245
+ type=int,
246
+ default=64,
247
+ help="left context can be seen during decoding (in frames after subsampling)",
248
+ )
249
+
250
+ parser.add_argument(
251
+ "--hlg-scale",
252
+ type=float,
253
+ default=0.8,
254
+ help="""The scale to be applied to `hlg.scores`.
255
+ """,
256
+ )
257
+
258
+ add_model_arguments(parser)
259
+
260
+ return parser
261
+
262
+
263
+ def get_decoding_params() -> AttributeDict:
264
+ """Parameters for decoding."""
265
+ params = AttributeDict(
266
+ {
267
+ "frame_shift_ms": 10,
268
+ "search_beam": 20,
269
+ "output_beam": 8,
270
+ "min_active_states": 30,
271
+ "max_active_states": 10000,
272
+ "use_double_scores": True,
273
+ }
274
+ )
275
+ return params
276
+
277
+
278
+ def ctc_greedy_search(
279
+ ctc_probs: torch.Tensor,
280
+ nnet_output_lens: torch.Tensor,
281
+ sp: spm.SentencePieceProcessor,
282
+ subsampling_factor: int = 4,
283
+ frame_shift_ms: float = 10,
284
+ ) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
285
+ """Apply CTC greedy search
286
+ Args:
287
+ ctc_probs (torch.Tensor):
288
+ (batch, max_len, feat_dim)
289
+ nnet_output_lens (torch.Tensor):
290
+ (batch, )
291
+ sp:
292
+ The BPE model.
293
+ subsampling_factor:
294
+ The subsampling factor of the model.
295
+ frame_shift_ms:
296
+ Frame shift in milliseconds between two contiguous frames.
297
+
298
+ Returns:
299
+ utt_time_pairs:
300
+ A list of pair list. utt_time_pairs[i] is a list of
301
+ (start-time, end-time) pairs for each word in
302
+ utterance-i.
303
+ utt_words:
304
+ A list of str list. utt_words[i] is a word list of utterence-i.
305
+ """
306
+ topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
307
+ topk_index = topk_index.squeeze(2) # (B, maxlen)
308
+ mask = make_pad_mask(nnet_output_lens)
309
+ topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen)
310
+ hyps = [hyp.tolist() for hyp in topk_index]
311
+
312
+ def get_first_tokens(tokens: List[int]) -> Tuple[List[int], List[bool]]:
313
+ is_first_token = []
314
+ first_tokens = []
315
+ for t in range(len(tokens)):
316
+ if tokens[t] != 0 and (t == 0 or tokens[t - 1] != tokens[t]):
317
+ is_first_token.append(True)
318
+ first_tokens.append(tokens[t])
319
+ else:
320
+ is_first_token.append(False)
321
+ return first_tokens, is_first_token
322
+
323
+ utt_time_pairs = []
324
+ utt_words = []
325
+ for utt in range(len(hyps)):
326
+ first_tokens, is_first_token = get_first_tokens(hyps[utt])
327
+ all_tokens = sp.id_to_piece(hyps[utt])
328
+ index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token)
329
+ words = sp.decode(first_tokens).split()
330
+ assert len(index_pairs) == len(words), (
331
+ len(index_pairs),
332
+ len(words),
333
+ all_tokens,
334
+ )
335
+ start = convert_timestamp(
336
+ frames=[i[0] for i in index_pairs],
337
+ subsampling_factor=subsampling_factor,
338
+ frame_shift_ms=frame_shift_ms,
339
+ )
340
+ end = convert_timestamp(
341
+ # The duration in frames is (end_frame_index - start_frame_index + 1)
342
+ frames=[i[1] + 1 for i in index_pairs],
343
+ subsampling_factor=subsampling_factor,
344
+ frame_shift_ms=frame_shift_ms,
345
+ )
346
+ utt_time_pairs.append(list(zip(start, end)))
347
+ utt_words.append(words)
348
+
349
+ return utt_time_pairs, utt_words
350
+
351
+
352
+ def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[Tuple[int, int]]]:
353
+ # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
354
+ new_hyp: List[int] = []
355
+ time: List[Tuple[int, int]] = []
356
+ cur = 0
357
+ start, end = -1, -1
358
+ while cur < len(hyp):
359
+ if hyp[cur] != 0:
360
+ new_hyp.append(hyp[cur])
361
+ start = cur
362
+ prev = cur
363
+ while cur < len(hyp) and hyp[cur] == hyp[prev]:
364
+ if start != -1:
365
+ end = cur
366
+ cur += 1
367
+ if start != -1 and end != -1:
368
+ time.append((start, end))
369
+ start, end = -1, -1
370
+ return new_hyp, time
371
+
372
+
373
+ def decode_one_batch(
374
+ params: AttributeDict,
375
+ model: nn.Module,
376
+ HLG: Optional[k2.Fsa],
377
+ H: Optional[k2.Fsa],
378
+ bpe_model: Optional[spm.SentencePieceProcessor],
379
+ batch: dict,
380
+ word_table: k2.SymbolTable,
381
+ sos_id: int,
382
+ eos_id: int,
383
+ G: Optional[k2.Fsa] = None,
384
+ ) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
385
+ """Decode one batch and return the result in a dict. The dict has the
386
+ following format:
387
+ - key: It indicates the setting used for decoding. For example,
388
+ if no rescoring is used, the key is the string `no_rescore`.
389
+ If LM rescoring is used, the key is the string `lm_scale_xxx`,
390
+ where `xxx` is the value of `lm_scale`. An example key is
391
+ `lm_scale_0.7`
392
+ - value: It contains the decoding result. `len(value)` equals to
393
+ batch size. `value[i]` is the decoding result for the i-th
394
+ utterance in the given batch.
395
+
396
+ Args:
397
+ params:
398
+ It's the return value of :func:`get_params`.
399
+
400
+ - params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
401
+ - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
402
+ - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
403
+ - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
404
+ rescoring.
405
+
406
+ model:
407
+ The neural model.
408
+ HLG:
409
+ The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
410
+ H:
411
+ The ctc topo. Used only when params.decoding_method is ctc-decoding.
412
+ bpe_model:
413
+ The BPE model. Used only when params.decoding_method is ctc-decoding.
414
+ batch:
415
+ It is the return value from iterating
416
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
417
+ for the format of the `batch`.
418
+ word_table:
419
+ The word symbol table.
420
+ sos_id:
421
+ The token ID of the SOS.
422
+ eos_id:
423
+ The token ID of the EOS.
424
+ G:
425
+ An LM. It is not None when params.decoding_method is "nbest-rescoring"
426
+ or "whole-lattice-rescoring". In general, the G in HLG
427
+ is a 3-gram LM, while this G is a 4-gram LM.
428
+ Returns:
429
+ Return the decoding result. See above description for the format of
430
+ the returned dict. Note: If it decodes to nothing, then return None.
431
+ """
432
+ if HLG is not None:
433
+ device = HLG.device
434
+ else:
435
+ device = H.device
436
+ feature = batch["inputs"]
437
+ assert feature.ndim == 3
438
+ feature = feature.to(device)
439
+ # at entry, feature is (N, T, C)
440
+
441
+ supervisions = batch["supervisions"]
442
+ feature_lens = supervisions["num_frames"].to(device)
443
+
444
+ if params.simulate_streaming:
445
+ feature_lens += params.left_context
446
+ feature = torch.nn.functional.pad(
447
+ feature,
448
+ pad=(0, 0, 0, params.left_context),
449
+ value=LOG_EPS,
450
+ )
451
+ encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
452
+ x=feature,
453
+ x_lens=feature_lens,
454
+ chunk_size=params.decode_chunk_size,
455
+ left_context=params.left_context,
456
+ simulate_streaming=True,
457
+ )
458
+ else:
459
+ encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
460
+
461
+ nnet_output = model.get_ctc_output(encoder_out)
462
+ # nnet_output is (N, T, C)
463
+
464
+ if params.decoding_method == "ctc-greedy-search":
465
+ timestamps, hyps = ctc_greedy_search(
466
+ ctc_probs=nnet_output,
467
+ nnet_output_lens=encoder_out_lens,
468
+ sp=bpe_model,
469
+ subsampling_factor=params.subsampling_factor,
470
+ frame_shift_ms=params.frame_shift_ms,
471
+ )
472
+ key = "ctc-greedy-search"
473
+ return {key: (hyps, timestamps)}
474
+
475
+ supervision_segments = torch.stack(
476
+ (
477
+ supervisions["sequence_idx"],
478
+ supervisions["start_frame"] // params.subsampling_factor,
479
+ encoder_out_lens.cpu(),
480
+ ),
481
+ 1,
482
+ ).to(torch.int32)
483
+
484
+ if H is None:
485
+ assert HLG is not None
486
+ decoding_graph = HLG
487
+ else:
488
+ assert HLG is None
489
+ assert bpe_model is not None
490
+ decoding_graph = H
491
+
492
+ lattice = get_lattice(
493
+ nnet_output=nnet_output,
494
+ decoding_graph=decoding_graph,
495
+ supervision_segments=supervision_segments,
496
+ search_beam=params.search_beam,
497
+ output_beam=params.output_beam,
498
+ min_active_states=params.min_active_states,
499
+ max_active_states=params.max_active_states,
500
+ subsampling_factor=params.subsampling_factor,
501
+ )
502
+
503
+ if params.decoding_method == "ctc-decoding":
504
+ best_path = one_best_decoding(
505
+ lattice=lattice, use_double_scores=params.use_double_scores
506
+ )
507
+ timestamps, hyps = parse_fsa_timestamps_and_texts(
508
+ best_paths=best_path,
509
+ sp=bpe_model,
510
+ subsampling_factor=params.subsampling_factor,
511
+ frame_shift_ms=params.frame_shift_ms,
512
+ )
513
+ key = "ctc-decoding"
514
+ return {key: (hyps, timestamps)}
515
+
516
+ if params.decoding_method == "nbest-oracle":
517
+ # Note: You can also pass rescored lattices to it.
518
+ # We choose the HLG decoded lattice for speed reasons
519
+ # as HLG decoding is faster and the oracle WER
520
+ # is only slightly worse than that of rescored lattices.
521
+ best_path = nbest_oracle(
522
+ lattice=lattice,
523
+ num_paths=params.num_paths,
524
+ ref_texts=supervisions["text"],
525
+ word_table=word_table,
526
+ nbest_scale=params.nbest_scale,
527
+ oov="<UNK>",
528
+ )
529
+ hyps = get_texts(best_path)
530
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
531
+ timestamps = [[] for _ in range(len(hyps))]
532
+ key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}" # noqa
533
+ return {key: (hyps, timestamps)}
534
+
535
+ if params.decoding_method in ["1best", "nbest"]:
536
+ if params.decoding_method == "1best":
537
+ best_path = one_best_decoding(
538
+ lattice=lattice, use_double_scores=params.use_double_scores
539
+ )
540
+ key = f"no_rescore_hlg_scale_{params.hlg_scale}"
541
+ timestamps, hyps = parse_fsa_timestamps_and_texts(
542
+ best_paths=best_path,
543
+ word_table=word_table,
544
+ subsampling_factor=params.subsampling_factor,
545
+ frame_shift_ms=params.frame_shift_ms,
546
+ )
547
+ else:
548
+ best_path = nbest_decoding(
549
+ lattice=lattice,
550
+ num_paths=params.num_paths,
551
+ use_double_scores=params.use_double_scores,
552
+ nbest_scale=params.nbest_scale,
553
+ )
554
+ key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}" # noqa
555
+ hyps = get_texts(best_path)
556
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
557
+ timestamps = [[] for _ in range(len(hyps))]
558
+ return {key: (hyps, timestamps)}
559
+
560
+ assert params.decoding_method in [
561
+ "nbest-rescoring",
562
+ "whole-lattice-rescoring",
563
+ ]
564
+
565
+ lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
566
+ lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
567
+ lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
568
+
569
+ if params.decoding_method == "nbest-rescoring":
570
+ best_path_dict = rescore_with_n_best_list(
571
+ lattice=lattice,
572
+ G=G,
573
+ num_paths=params.num_paths,
574
+ lm_scale_list=lm_scale_list,
575
+ nbest_scale=params.nbest_scale,
576
+ )
577
+ elif params.decoding_method == "whole-lattice-rescoring":
578
+ best_path_dict = rescore_with_whole_lattice(
579
+ lattice=lattice,
580
+ G_with_epsilon_loops=G,
581
+ lm_scale_list=lm_scale_list,
582
+ )
583
+ else:
584
+ assert False, f"Unsupported decoding method: {params.decoding_method}"
585
+
586
+ ans = dict()
587
+ if best_path_dict is not None:
588
+ for lm_scale_str, best_path in best_path_dict.items():
589
+ hyps = get_texts(best_path)
590
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
591
+ timestamps = [[] for _ in range(len(hyps))]
592
+ ans[lm_scale_str] = (hyps, timestamps)
593
+ else:
594
+ ans = None
595
+ return ans
596
+
597
+
598
+ def decode_dataset(
599
+ dl: torch.utils.data.DataLoader,
600
+ params: AttributeDict,
601
+ model: nn.Module,
602
+ HLG: Optional[k2.Fsa],
603
+ H: Optional[k2.Fsa],
604
+ bpe_model: Optional[spm.SentencePieceProcessor],
605
+ word_table: k2.SymbolTable,
606
+ sos_id: int,
607
+ eos_id: int,
608
+ G: Optional[k2.Fsa] = None,
609
+ ) -> Dict[
610
+ str,
611
+ List[
612
+ Tuple[
613
+ str,
614
+ List[str],
615
+ List[str],
616
+ List[Tuple[float, float]],
617
+ List[Tuple[float, float]],
618
+ ]
619
+ ],
620
+ ]:
621
+ """Decode dataset.
622
+
623
+ Args:
624
+ dl:
625
+ PyTorch's dataloader containing the dataset to decode.
626
+ params:
627
+ It is returned by :func:`get_params`.
628
+ model:
629
+ The neural model.
630
+ HLG:
631
+ The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
632
+ H:
633
+ The ctc topo. Used only when params.decoding_method is ctc-decoding.
634
+ bpe_model:
635
+ The BPE model. Used only when params.decoding_method is ctc-decoding.
636
+ word_table:
637
+ It is the word symbol table.
638
+ sos_id:
639
+ The token ID for SOS.
640
+ eos_id:
641
+ The token ID for EOS.
642
+ G:
643
+ An LM. It is not None when params.decoding_method is "nbest-rescoring"
644
+ or "whole-lattice-rescoring". In general, the G in HLG
645
+ is a 3-gram LM, while this G is a 4-gram LM.
646
+ Returns:
647
+ Return a dict, whose key may be "no-rescore" if no LM rescoring
648
+ is used, or it may be "lm_scale_0.7" if LM rescoring is used.
649
+ Its value is a list of tuples. Each tuple contains two elements:
650
+ The first is the reference transcript, and the second is the
651
+ predicted result.
652
+ """
653
+ num_cuts = 0
654
+
655
+ try:
656
+ num_batches = len(dl)
657
+ except TypeError:
658
+ num_batches = "?"
659
+
660
+ results = defaultdict(list)
661
+ for batch_idx, batch in enumerate(dl):
662
+ texts = batch["supervisions"]["text"]
663
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
664
+
665
+ timestamps_ref = []
666
+ for cut in batch["supervisions"]["cut"]:
667
+ for s in cut.supervisions:
668
+ time = []
669
+ if s.alignment is not None and "word" in s.alignment:
670
+ time = [
671
+ (aliword.start, aliword.end)
672
+ for aliword in s.alignment["word"]
673
+ if aliword.symbol != ""
674
+ ]
675
+ timestamps_ref.append(time)
676
+
677
+ hyps_dict = decode_one_batch(
678
+ params=params,
679
+ model=model,
680
+ HLG=HLG,
681
+ H=H,
682
+ bpe_model=bpe_model,
683
+ batch=batch,
684
+ word_table=word_table,
685
+ G=G,
686
+ sos_id=sos_id,
687
+ eos_id=eos_id,
688
+ )
689
+
690
+ for name, (hyps, timestamps_hyp) in hyps_dict.items():
691
+ this_batch = []
692
+ assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
693
+ timestamps_ref
694
+ )
695
+ for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
696
+ cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
697
+ ):
698
+ ref_words = ref_text.split()
699
+ this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))
700
+
701
+ results[name].extend(this_batch)
702
+
703
+ num_cuts += len(texts)
704
+
705
+ if batch_idx % 100 == 0:
706
+ batch_str = f"{batch_idx}/{num_batches}"
707
+
708
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
709
+ return results
710
+
711
+
712
+ def save_results(
713
+ params: AttributeDict,
714
+ test_set_name: str,
715
+ results_dict: Dict[
716
+ str,
717
+ List[
718
+ Tuple[
719
+ List[str],
720
+ List[str],
721
+ List[str],
722
+ List[Tuple[float, float]],
723
+ List[Tuple[float, float]],
724
+ ]
725
+ ],
726
+ ],
727
+ ):
728
+ test_set_wers = dict()
729
+ test_set_delays = dict()
730
+ for key, results in results_dict.items():
731
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
732
+ results = sorted(results)
733
+ store_transcripts_and_timestamps(filename=recog_path, texts=results)
734
+ logging.info(f"The transcripts are stored in {recog_path}")
735
+
736
+ # The following prints out WERs, per-word error statistics and aligned
737
+ # ref/hyp pairs.
738
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
739
+ with open(errs_filename, "w") as f:
740
+ wer, mean_delay, var_delay = write_error_stats_with_timestamps(
741
+ f,
742
+ f"{test_set_name}-{key}",
743
+ results,
744
+ enable_log=True,
745
+ with_end_time=True,
746
+ )
747
+ test_set_wers[key] = wer
748
+ test_set_delays[key] = (mean_delay, var_delay)
749
+
750
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
751
+
752
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
753
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
754
+ with open(errs_info, "w") as f:
755
+ print("settings\tWER", file=f)
756
+ for key, val in test_set_wers:
757
+ print("{}\t{}".format(key, val), file=f)
758
+
759
+ # sort according to the mean start symbol delay
760
+ test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
761
+ delays_info = (
762
+ params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
763
+ )
764
+ with open(delays_info, "w") as f:
765
+ print("settings\tsymbol-delay (s) (start, end) mean and variance", file=f)
766
+ for key, val in test_set_delays:
767
+ print(
768
+ "{}\tmean: {}, variance: {}".format(key, val[0], val[1]),
769
+ file=f,
770
+ )
771
+
772
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
773
+ note = "\tbest for {}".format(test_set_name)
774
+ for key, val in test_set_wers:
775
+ s += "{}\t{}{}\n".format(key, val, note)
776
+ note = ""
777
+ logging.info(s)
778
+
779
+ s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format(
780
+ test_set_name
781
+ )
782
+ note = "\tbest for {}".format(test_set_name)
783
+ for key, val in test_set_delays:
784
+ s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note)
785
+ note = ""
786
+ logging.info(s)
787
+
788
+
789
+ @torch.no_grad()
790
+ def main():
791
+ parser = get_parser()
792
+ LibriSpeechAsrDataModule.add_arguments(parser)
793
+ args = parser.parse_args()
794
+ args.exp_dir = Path(args.exp_dir)
795
+ args.lang_dir = Path(args.lang_dir)
796
+ args.lm_dir = Path(args.lm_dir)
797
+
798
+ params = get_params()
799
+ # add decoding params
800
+ params.update(get_decoding_params())
801
+ params.update(vars(args))
802
+
803
+ assert params.decoding_method in (
804
+ "ctc-greedy-search",
805
+ "ctc-decoding",
806
+ "1best",
807
+ "nbest",
808
+ "nbest-rescoring",
809
+ "whole-lattice-rescoring",
810
+ "nbest-oracle",
811
+ )
812
+ params.res_dir = params.exp_dir / params.decoding_method
813
+
814
+ if params.iter > 0:
815
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
816
+ else:
817
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
818
+
819
+ if params.simulate_streaming:
820
+ params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
821
+ params.suffix += f"-left-context-{params.left_context}"
822
+
823
+ if params.simulate_streaming:
824
+ assert (
825
+ params.causal_convolution
826
+ ), "Decoding in streaming requires causal convolution"
827
+
828
+ if params.use_averaged_model:
829
+ params.suffix += "-use-averaged-model"
830
+
831
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
832
+ logging.info("Decoding started")
833
+
834
+ device = torch.device("cpu")
835
+ if torch.cuda.is_available():
836
+ device = torch.device("cuda", 0)
837
+
838
+ logging.info(f"Device: {device}")
839
+ logging.info(params)
840
+
841
+ lexicon = Lexicon(params.lang_dir)
842
+ max_token_id = max(lexicon.tokens)
843
+ num_classes = max_token_id + 1 # +1 for the blank
844
+
845
+ graph_compiler = BpeCtcTrainingGraphCompiler(
846
+ params.lang_dir,
847
+ device=device,
848
+ sos_token="<sos/eos>",
849
+ eos_token="<sos/eos>",
850
+ )
851
+ sos_id = graph_compiler.sos_id
852
+ eos_id = graph_compiler.eos_id
853
+
854
+ params.vocab_size = num_classes
855
+ params.sos_id = sos_id
856
+ params.eos_id = eos_id
857
+
858
+ if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
859
+ HLG = None
860
+ H = k2.ctc_topo(
861
+ max_token=max_token_id,
862
+ modified=False,
863
+ device=device,
864
+ )
865
+ bpe_model = spm.SentencePieceProcessor()
866
+ bpe_model.load(str(params.lang_dir / "bpe.model"))
867
+ else:
868
+ H = None
869
+ bpe_model = None
870
+ HLG = k2.Fsa.from_dict(
871
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
872
+ )
873
+ assert HLG.requires_grad is False
874
+
875
+ HLG.scores *= params.hlg_scale
876
+ if not hasattr(HLG, "lm_scores"):
877
+ HLG.lm_scores = HLG.scores.clone()
878
+
879
+ if params.decoding_method in (
880
+ "nbest-rescoring",
881
+ "whole-lattice-rescoring",
882
+ ):
883
+ if not (params.lm_dir / "G_4_gram.pt").is_file():
884
+ logging.info("Loading G_4_gram.fst.txt")
885
+ logging.warning("It may take 8 minutes.")
886
+ with open(params.lm_dir / "G_4_gram.fst.txt") as f:
887
+ first_word_disambig_id = lexicon.word_table["#0"]
888
+
889
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
890
+ # G.aux_labels is not needed in later computations, so
891
+ # remove it here.
892
+ del G.aux_labels
893
+ # CAUTION: The following line is crucial.
894
+ # Arcs entering the back-off state have label equal to #0.
895
+ # We have to change it to 0 here.
896
+ G.labels[G.labels >= first_word_disambig_id] = 0
897
+ # See https://github.com/k2-fsa/k2/issues/874
898
+ # for why we need to set G.properties to None
899
+ G.__dict__["_properties"] = None
900
+ G = k2.Fsa.from_fsas([G]).to(device)
901
+ G = k2.arc_sort(G)
902
+ # Save a dummy value so that it can be loaded in C++.
903
+ # See https://github.com/pytorch/pytorch/issues/67902
904
+ # for why we need to do this.
905
+ G.dummy = 1
906
+
907
+ torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
908
+ else:
909
+ logging.info("Loading pre-compiled G_4_gram.pt")
910
+ d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
911
+ G = k2.Fsa.from_dict(d)
912
+
913
+ if params.decoding_method == "whole-lattice-rescoring":
914
+ # Add epsilon self-loops to G as we will compose
915
+ # it with the whole lattice later
916
+ G = k2.add_epsilon_self_loops(G)
917
+ G = k2.arc_sort(G)
918
+ G = G.to(device)
919
+
920
+ # G.lm_scores is used to replace HLG.lm_scores during
921
+ # LM rescoring.
922
+ G.lm_scores = G.scores.clone()
923
+ else:
924
+ G = None
925
+
926
+ logging.info("About to create model")
927
+ model = get_ctc_model(params)
928
+
929
+ if not params.use_averaged_model:
930
+ if params.iter > 0:
931
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
932
+ : params.avg
933
+ ]
934
+ if len(filenames) == 0:
935
+ raise ValueError(
936
+ f"No checkpoints found for"
937
+ f" --iter {params.iter}, --avg {params.avg}"
938
+ )
939
+ elif len(filenames) < params.avg:
940
+ raise ValueError(
941
+ f"Not enough checkpoints ({len(filenames)}) found for"
942
+ f" --iter {params.iter}, --avg {params.avg}"
943
+ )
944
+ logging.info(f"averaging {filenames}")
945
+ model.to(device)
946
+ model.load_state_dict(average_checkpoints(filenames, device=device))
947
+ elif params.avg == 1:
948
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
949
+ else:
950
+ start = params.epoch - params.avg + 1
951
+ filenames = []
952
+ for i in range(start, params.epoch + 1):
953
+ if i >= 1:
954
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
955
+ logging.info(f"averaging {filenames}")
956
+ model.to(device)
957
+ model.load_state_dict(average_checkpoints(filenames, device=device))
958
+ else:
959
+ if params.iter > 0:
960
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
961
+ : params.avg + 1
962
+ ]
963
+ if len(filenames) == 0:
964
+ raise ValueError(
965
+ f"No checkpoints found for"
966
+ f" --iter {params.iter}, --avg {params.avg}"
967
+ )
968
+ elif len(filenames) < params.avg + 1:
969
+ raise ValueError(
970
+ f"Not enough checkpoints ({len(filenames)}) found for"
971
+ f" --iter {params.iter}, --avg {params.avg}"
972
+ )
973
+ filename_start = filenames[-1]
974
+ filename_end = filenames[0]
975
+ logging.info(
976
+ "Calculating the averaged model over iteration checkpoints"
977
+ f" from {filename_start} (excluded) to {filename_end}"
978
+ )
979
+ model.to(device)
980
+ model.load_state_dict(
981
+ average_checkpoints_with_averaged_model(
982
+ filename_start=filename_start,
983
+ filename_end=filename_end,
984
+ device=device,
985
+ )
986
+ )
987
+ else:
988
+ assert params.avg > 0, params.avg
989
+ start = params.epoch - params.avg
990
+ assert start >= 1, start
991
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
992
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
993
+ logging.info(
994
+ f"Calculating the averaged model over epoch range from "
995
+ f"{start} (excluded) to {params.epoch}"
996
+ )
997
+ model.to(device)
998
+ model.load_state_dict(
999
+ average_checkpoints_with_averaged_model(
1000
+ filename_start=filename_start,
1001
+ filename_end=filename_end,
1002
+ device=device,
1003
+ )
1004
+ )
1005
+
1006
+ model.to(device)
1007
+ model.eval()
1008
+
1009
+ num_param = sum([p.numel() for p in model.parameters()])
1010
+ logging.info(f"Number of model parameters: {num_param}")
1011
+
1012
+ # we need cut ids to display recognition results.
1013
+ args.return_cuts = True
1014
+ librispeech = LibriSpeechAsrDataModule(args)
1015
+
1016
+ test_clean_cuts = librispeech.test_clean_cuts()
1017
+ #test_other_cuts = librispeech.test_other_cuts()
1018
+
1019
+ test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
1020
+ #test_other_dl = librispeech.test_dataloaders(test_other_cuts)
1021
+
1022
+ #test_sets = ["test-clean", "test-other"]
1023
+ #test_dl = [test_clean_dl, test_other_dl]
1024
+
1025
+ test_sets = ["test-clean"]
1026
+ test_dl = [test_clean_dl]
1027
+
1028
+ for test_set, test_dl in zip(test_sets, test_dl):
1029
+ results_dict = decode_dataset(
1030
+ dl=test_dl,
1031
+ params=params,
1032
+ model=model,
1033
+ HLG=HLG,
1034
+ H=H,
1035
+ bpe_model=bpe_model,
1036
+ word_table=lexicon.word_table,
1037
+ G=G,
1038
+ sos_id=sos_id,
1039
+ eos_id=eos_id,
1040
+ )
1041
+
1042
+ save_results(
1043
+ params=params,
1044
+ test_set_name=test_set,
1045
+ results_dict=results_dict,
1046
+ )
1047
+
1048
+ logging.info("Done!")
1049
+
1050
+
1051
+ if __name__ == "__main__":
1052
+ main()
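
A note on the structure returned by the decoding functions above: each decoding setting (e.g. `no_rescore_hlg_scale_0.6` or `lm_scale_0.7`) maps to a list of per-cut tuples, as described in the docstring of `decode_dataset`. The sketch below only illustrates how such a results dict can be consumed; the names and data are made up and the real WER is computed by `write_error_stats_with_timestamps` in `save_results`.

```python
# Minimal sketch of iterating the dict produced by decode_dataset().
from typing import Dict, List, Tuple

ResultTuple = Tuple[str, List[str], List[str],
                    List[Tuple[float, float]], List[Tuple[float, float]]]

results_dict: Dict[str, List[ResultTuple]] = {
    "lm_scale_0.7": [
        ("cut-0001", ["tere", "hommikust"], ["tere", "hommikust"], [], []),
    ],
}

for setting, results in results_dict.items():
    num_errs = 0
    num_ref_words = 0
    for cut_id, ref_words, hyp_words, time_ref, time_hyp in results:
        # Crude per-cut mismatch count, for illustration only.
        num_ref_words += len(ref_words)
        num_errs += sum(r != h for r, h in zip(ref_words, hyp_words))
        num_errs += abs(len(ref_words) - len(hyp_words))
    print(setting, num_errs, num_ref_words)
```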
err2020/conformer_ctc3/encoder_interface.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import Tuple
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+
23
+ class EncoderInterface(nn.Module):
24
+ def forward(
25
+ self, x: torch.Tensor, x_lens: torch.Tensor
26
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
27
+ """
28
+ Args:
29
+ x:
30
+ A tensor of shape (batch_size, input_seq_len, num_features)
31
+ containing the input features.
32
+ x_lens:
33
+ A tensor of shape (batch_size,) containing the number of frames
34
+ in `x` before padding.
35
+ Returns:
36
+ Return a tuple containing two tensors:
37
+ - encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
38
+ containing unnormalized probabilities, i.e., the output of a
39
+ linear layer.
40
+ - encoder_out_lens, a tensor of shape (batch_size,) containing
41
+ the number of frames in `encoder_out` before padding.
42
+ """
43
+ raise NotImplementedError("Please implement it in a subclass")
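
`EncoderInterface` only fixes the `forward()` contract; the actual encoder used by this recipe is the Conformer in `conformer.py`. A minimal toy subclass, purely illustrative and assuming it is run from inside `err2020/conformer_ctc3` so that `encoder_interface.py` is importable:

```python
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface  # local module in this directory

class TinyEncoder(EncoderInterface):
    """Hypothetical toy encoder that satisfies the interface; not part of the recipe."""

    def __init__(self, num_features: int = 80, output_dim: int = 512):
        super().__init__()
        self.proj = nn.Linear(num_features, output_dim)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        # No time subsampling in this toy example, so the lengths are unchanged.
        return self.proj(x), x_lens

enc = TinyEncoder()
x = torch.randn(2, 50, 80)       # (batch, frames, num_features)
x_lens = torch.tensor([50, 37])  # valid frames per utterance before padding
out, out_lens = enc(x, x_lens)
assert out.shape == (2, 50, 512) and torch.equal(out_lens, x_lens)
```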
err2020/conformer_ctc3/exp/jit_trace.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d364f175859dcfff11dcd5eea7032f568b77af7c9ffa15ff9f405c69983d58b
3
+ size 330828854
err2020/conformer_ctc3/export.py ADDED
@@ -0,0 +1,292 @@
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ # This script converts several saved checkpoints
20
+ # to a single one using model averaging.
21
+ """
22
+ Usage:
23
+
24
+ (1) Export to torchscript model using torch.jit.trace()
25
+
26
+ ./conformer_ctc3/export.py \
27
+ --exp-dir ./conformer_ctc3/exp \
28
+ --lang-dir data/lang_bpe_500 \
29
+ --epoch 20 \
30
+ --avg 10 \
31
+ --jit-trace 1
32
+
33
+ It will generate the file: `jit_trace.pt`.
34
+
35
+ (2) Export `model.state_dict()`
36
+
37
+ ./conformer_ctc3/export.py \
38
+ --exp-dir ./conformer_ctc3/exp \
39
+ --lang-dir data/lang_bpe_500 \
40
+ --epoch 20 \
41
+ --avg 10
42
+
43
+ It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
44
+ load it by `icefall.checkpoint.load_checkpoint()`.
45
+
46
+ To use the generated file with `conformer_ctc3/decode.py`,
47
+ you can do:
48
+
49
+ cd /path/to/exp_dir
50
+ ln -s pretrained.pt epoch-9999.pt
51
+
52
+ cd /path/to/egs/librispeech/ASR
53
+ ./conformer_ctc3/decode.py \
54
+ --exp-dir ./conformer_ctc3/exp \
55
+ --epoch 9999 \
56
+ --avg 1 \
57
+ --max-duration 100 \
58
+ --lang-dir data/lang_bpe_500
59
+ """
60
+
61
+ import argparse
62
+ import logging
63
+ from pathlib import Path
64
+
65
+ import torch
66
+ from scaling_converter import convert_scaled_to_non_scaled
67
+ from train import add_model_arguments, get_ctc_model, get_params
68
+
69
+ from icefall.checkpoint import (
70
+ average_checkpoints,
71
+ average_checkpoints_with_averaged_model,
72
+ find_checkpoints,
73
+ load_checkpoint,
74
+ )
75
+ from icefall.lexicon import Lexicon
76
+ from icefall.utils import str2bool
77
+
78
+
79
+ def get_parser():
80
+ parser = argparse.ArgumentParser(
81
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--epoch",
86
+ type=int,
87
+ default=28,
88
+ help="""It specifies the checkpoint to use for averaging.
89
+ Note: Epoch counts from 0.
90
+ You can specify --avg to use more checkpoints for model averaging.""",
91
+ )
92
+
93
+ parser.add_argument(
94
+ "--iter",
95
+ type=int,
96
+ default=0,
97
+ help="""If positive, --epoch is ignored and it
98
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
99
+ You can specify --avg to use more checkpoints for model averaging.
100
+ """,
101
+ )
102
+
103
+ parser.add_argument(
104
+ "--avg",
105
+ type=int,
106
+ default=15,
107
+ help="Number of checkpoints to average. Automatically select "
108
+ "consecutive checkpoints before the checkpoint specified by "
109
+ "'--epoch' and '--iter'",
110
+ )
111
+
112
+ parser.add_argument(
113
+ "--use-averaged-model",
114
+ type=str2bool,
115
+ default=True,
116
+ help="Whether to load averaged model. Currently it only supports "
117
+ "using --epoch. If True, it would decode with the averaged model "
118
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
119
+ "Actually only the models with epoch number of `epoch-avg` and "
120
+ "`epoch` are loaded for averaging. ",
121
+ )
122
+
123
+ parser.add_argument(
124
+ "--exp-dir",
125
+ type=str,
126
+ default="pruned_transducer_stateless4/exp",
127
+ help="""It specifies the directory where all training related
128
+ files, e.g., checkpoints, log, etc, are saved
129
+ """,
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--lang-dir",
134
+ type=Path,
135
+ default="data/lang_bpe_500",
136
+ help="The lang dir containing word table and LG graph",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--jit-trace",
141
+ type=str2bool,
142
+ default=False,
143
+ help="""True to save a model after applying torch.jit.trace().
144
+ """,
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--streaming-model",
149
+ type=str2bool,
150
+ default=False,
151
+ help="""Whether to export a streaming model, if the models in exp-dir
152
+ are streaming model, this should be True.
153
+ """,
154
+ )
155
+
156
+ add_model_arguments(parser)
157
+
158
+ return parser
159
+
160
+
161
+ def main():
162
+ args = get_parser().parse_args()
163
+ args.exp_dir = Path(args.exp_dir)
164
+
165
+ params = get_params()
166
+ params.update(vars(args))
167
+
168
+ device = torch.device("cpu")
169
+ if torch.cuda.is_available():
170
+ device = torch.device("cuda", 0)
171
+
172
+ logging.info(f"device: {device}")
173
+
174
+ lexicon = Lexicon(params.lang_dir)
175
+ max_token_id = max(lexicon.tokens)
176
+ num_classes = max_token_id + 1 # +1 for the blank
177
+ params.vocab_size = num_classes
178
+
179
+ if params.streaming_model:
180
+ assert params.causal_convolution
181
+
182
+ logging.info(params)
183
+
184
+ logging.info("About to create model")
185
+ model = get_ctc_model(params)
186
+
187
+ model.to(device)
188
+
189
+ if not params.use_averaged_model:
190
+ if params.iter > 0:
191
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
192
+ : params.avg
193
+ ]
194
+ if len(filenames) == 0:
195
+ raise ValueError(
196
+ f"No checkpoints found for"
197
+ f" --iter {params.iter}, --avg {params.avg}"
198
+ )
199
+ elif len(filenames) < params.avg:
200
+ raise ValueError(
201
+ f"Not enough checkpoints ({len(filenames)}) found for"
202
+ f" --iter {params.iter}, --avg {params.avg}"
203
+ )
204
+ logging.info(f"averaging {filenames}")
205
+ model.load_state_dict(average_checkpoints(filenames, device=device))
206
+ elif params.avg == 1:
207
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
208
+ else:
209
+ start = params.epoch - params.avg + 1
210
+ filenames = []
211
+ for i in range(start, params.epoch + 1):
212
+ if i >= 1:
213
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
214
+ logging.info(f"averaging {filenames}")
215
+ model.load_state_dict(average_checkpoints(filenames, device=device))
216
+ else:
217
+ if params.iter > 0:
218
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
219
+ : params.avg + 1
220
+ ]
221
+ if len(filenames) == 0:
222
+ raise ValueError(
223
+ f"No checkpoints found for"
224
+ f" --iter {params.iter}, --avg {params.avg}"
225
+ )
226
+ elif len(filenames) < params.avg + 1:
227
+ raise ValueError(
228
+ f"Not enough checkpoints ({len(filenames)}) found for"
229
+ f" --iter {params.iter}, --avg {params.avg}"
230
+ )
231
+ filename_start = filenames[-1]
232
+ filename_end = filenames[0]
233
+ logging.info(
234
+ "Calculating the averaged model over iteration checkpoints"
235
+ f" from {filename_start} (excluded) to {filename_end}"
236
+ )
237
+ model.load_state_dict(
238
+ average_checkpoints_with_averaged_model(
239
+ filename_start=filename_start,
240
+ filename_end=filename_end,
241
+ device=device,
242
+ )
243
+ )
244
+ else:
245
+ assert params.avg > 0, params.avg
246
+ start = params.epoch - params.avg
247
+ assert start >= 1, start
248
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
249
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
250
+ logging.info(
251
+ f"Calculating the averaged model over epoch range from "
252
+ f"{start} (excluded) to {params.epoch}"
253
+ )
254
+ model.load_state_dict(
255
+ average_checkpoints_with_averaged_model(
256
+ filename_start=filename_start,
257
+ filename_end=filename_end,
258
+ device=device,
259
+ )
260
+ )
261
+
262
+ model.to("cpu")
263
+ model.eval()
264
+
265
+ if params.jit_trace:
266
+ # TODO: will support streaming mode
267
+ assert not params.streaming_model
268
+ convert_scaled_to_non_scaled(model, inplace=True)
269
+
270
+ logging.info("Using torch.jit.trace()")
271
+
272
+ x = torch.zeros(1, 100, 80, dtype=torch.float32)
273
+ x_lens = torch.tensor([100], dtype=torch.int64)
274
+ traced_model = torch.jit.trace(model, (x, x_lens))
275
+
276
+ filename = params.exp_dir / "jit_trace.pt"
277
+ traced_model.save(str(filename))
278
+ logging.info(f"Saved to {filename}")
279
+ else:
280
+ logging.info("Not using torch.jit.trace()")
281
+ # Save it using a format so that it can be loaded
282
+ # by :func:`load_checkpoint`
283
+ filename = params.exp_dir / "pretrained.pt"
284
+ torch.save({"model": model.state_dict()}, str(filename))
285
+ logging.info(f"Saved to {filename}")
286
+
287
+
288
+ if __name__ == "__main__":
289
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
290
+
291
+ logging.basicConfig(format=formatter, level=logging.INFO)
292
+ main()
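
Once exported with `--jit-trace 1`, the resulting `jit_trace.pt` (shipped in this repo under `err2020/conformer_ctc3/exp/`) can be loaded without any icefall code. A minimal sketch, assuming 80-dim fbank features with the same shapes used in the trace call above:

```python
import torch

# Load the traced model produced by export.py.
model = torch.jit.load("err2020/conformer_ctc3/exp/jit_trace.pt", map_location="cpu")
model.eval()

# Dummy 80-dim fbank features: x is (batch, num_frames, feature_dim),
# x_lens is (batch,) with the number of valid frames.
x = torch.randn(1, 100, 80)
x_lens = torch.tensor([100], dtype=torch.int64)

with torch.no_grad():
    nnet_output, out_lens = model(x, x_lens)

# nnet_output holds per-frame log-probabilities over the BPE vocabulary
# (including the blank symbol), subsampled in time relative to the input.
print(nnet_output.shape, out_lens)
```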
err2020/conformer_ctc3/jit_pretrained.py ADDED
@@ -0,0 +1,413 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
3
+ # Mingshuang Luo,
4
+ # Zengwei Yao)
5
+ #
6
+ # See ../../../../LICENSE for clarification regarding multiple authors
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+
21
+ """
22
+ Usage (for non-streaming mode):
23
+
24
+ (1) ctc-decoding
25
+ ./conformer_ctc3/jit_pretrained.py \
26
+ --model-filename ./conformer_ctc3/exp/jit_trace.pt \
27
+ --bpe-model data/lang_bpe_500/bpe.model \
28
+ --method ctc-decoding \
29
+ --sample-rate 16000 \
30
+ /path/to/foo.wav \
31
+ /path/to/bar.wav
32
+
33
+ (2) 1best
34
+ ./conformer_ctc3/jit_pretrained.py \
35
+ --model-filename ./conformer_ctc3/exp/jit_trace.pt \
36
+ --HLG data/lang_bpe_500/HLG.pt \
37
+ --words-file data/lang_bpe_500/words.txt \
38
+ --method 1best \
39
+ --sample-rate 16000 \
40
+ /path/to/foo.wav \
41
+ /path/to/bar.wav
42
+
43
+ (3) nbest-rescoring
44
+ ./conformer_ctc3/jit_pretrained.py \
45
+ --model-filename ./conformer_ctc3/exp/jit_trace.pt \
46
+ --HLG data/lang_bpe_500/HLG.pt \
47
+ --words-file data/lang_bpe_500/words.txt \
48
+ --G data/lm/G_4_gram.pt \
49
+ --method nbest-rescoring \
50
+ --sample-rate 16000 \
51
+ /path/to/foo.wav \
52
+ /path/to/bar.wav
53
+
54
+ (4) whole-lattice-rescoring
55
+ ./conformer_ctc3/jit_pretrained.py \
56
+ --model-filename ./conformer_ctc3/exp/jit_trace.pt \
57
+ --HLG data/lang_bpe_500/HLG.pt \
58
+ --words-file data/lang_bpe_500/words.txt \
59
+ --G data/lm/G_4_gram.pt \
60
+ --method whole-lattice-rescoring \
61
+ --sample-rate 16000 \
62
+ /path/to/foo.wav \
63
+ /path/to/bar.wav
64
+ """
65
+
66
+
67
+ import argparse
68
+ import logging
69
+ import math
70
+ from typing import List
71
+
72
+ import k2
73
+ import kaldifeat
74
+ import sentencepiece as spm
75
+ import torch
76
+ import torchaudio
77
+ from decode import get_decoding_params
78
+ from torch.nn.utils.rnn import pad_sequence
79
+ from train import add_model_arguments, get_params
80
+
81
+ from icefall.decode import (
82
+ get_lattice,
83
+ one_best_decoding,
84
+ rescore_with_n_best_list,
85
+ rescore_with_whole_lattice,
86
+ )
87
+ from icefall.utils import get_texts
88
+
89
+
90
+ def get_parser():
91
+ parser = argparse.ArgumentParser(
92
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
93
+ )
94
+
95
+ parser.add_argument(
96
+ "--model-filename",
97
+ type=str,
98
+ required=True,
99
+ help="Path to the torchscript model.",
100
+ )
101
+
102
+ parser.add_argument(
103
+ "--words-file",
104
+ type=str,
105
+ help="""Path to words.txt.
106
+ Used only when method is not ctc-decoding.
107
+ """,
108
+ )
109
+
110
+ parser.add_argument(
111
+ "--HLG",
112
+ type=str,
113
+ help="""Path to HLG.pt.
114
+ Used only when method is not ctc-decoding.
115
+ """,
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--bpe-model",
120
+ type=str,
121
+ help="""Path to bpe.model.
122
+ Used only when method is ctc-decoding.
123
+ """,
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--method",
128
+ type=str,
129
+ default="1best",
130
+ help="""Decoding method.
131
+ Possible values are:
132
+ (0) ctc-decoding - Use CTC decoding. It uses a sentence
133
+ piece model, i.e., lang_dir/bpe.model, to convert
134
+ word pieces to words. It needs neither a lexicon
135
+ nor an n-gram LM.
136
+ (1) 1best - Use the best path as decoding output. Only
137
+ the transformer encoder output is used for decoding.
138
+ We call it HLG decoding.
139
+ (2) nbest-rescoring. Extract n paths from the decoding lattice,
140
+ rescore them with an LM, the path with
141
+ the highest score is the decoding result.
142
+ We call it HLG decoding + n-gram LM rescoring.
143
+ (3) whole-lattice-rescoring - Use an LM to rescore the
144
+ decoding lattice and then use 1best to decode the
145
+ rescored lattice.
146
+ We call it HLG decoding + n-gram LM rescoring.
147
+ """,
148
+ )
149
+
150
+ parser.add_argument(
151
+ "--G",
152
+ type=str,
153
+ help="""An LM for rescoring.
154
+ Used only when method is
155
+ whole-lattice-rescoring or nbest-rescoring.
156
+ It's usually a 4-gram LM.
157
+ """,
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--num-paths",
162
+ type=int,
163
+ default=100,
164
+ help="""
165
+ Used only when method is nbest-rescoring.
166
+ It specifies the size of n-best list.""",
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--ngram-lm-scale",
171
+ type=float,
172
+ default=1.3,
173
+ help="""
174
+ Used only when method is whole-lattice-rescoring and nbest-rescoring.
175
+ It specifies the scale for n-gram LM scores.
176
+ (Note: You need to tune it on a dataset.)
177
+ """,
178
+ )
179
+
180
+ parser.add_argument(
181
+ "--nbest-scale",
182
+ type=float,
183
+ default=0.5,
184
+ help="""
185
+ Used only when method is nbest-rescoring.
186
+ It specifies the scale for lattice.scores when
187
+ extracting n-best lists. A smaller value results in
188
+ more unique number of paths with the risk of missing
189
+ the best path.
190
+ """,
191
+ )
192
+
193
+ parser.add_argument(
194
+ "--num-classes",
195
+ type=int,
196
+ default=500,
197
+ help="""
198
+ Vocab size in the BPE model.
199
+ """,
200
+ )
201
+
202
+ parser.add_argument(
203
+ "--sample-rate",
204
+ type=int,
205
+ default=16000,
206
+ help="The sample rate of the input sound file",
207
+ )
208
+
209
+ parser.add_argument(
210
+ "sound_files",
211
+ type=str,
212
+ nargs="+",
213
+ help="The input sound file(s) to transcribe. "
214
+ "Supported formats are those supported by torchaudio.load(). "
215
+ "For example, wav and flac are supported. "
216
+ "The sample rate has to be 16kHz.",
217
+ )
218
+
219
+ add_model_arguments(parser)
220
+
221
+ return parser
222
+
223
+
224
+ def read_sound_files(
225
+ filenames: List[str], expected_sample_rate: float
226
+ ) -> List[torch.Tensor]:
227
+ """Read a list of sound files into a list 1-D float32 torch tensors.
228
+ Args:
229
+ filenames:
230
+ A list of sound filenames.
231
+ expected_sample_rate:
232
+ The expected sample rate of the sound files.
233
+ Returns:
234
+ Return a list of 1-D float32 torch tensors.
235
+ """
236
+ ans = []
237
+ for f in filenames:
238
+ wave, sample_rate = torchaudio.load(f)
239
+ assert sample_rate == expected_sample_rate, (
240
+ f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
241
+ )
242
+ # We use only the first channel
243
+ ans.append(wave[0])
244
+ return ans
245
+
246
+
247
+ def main():
248
+ parser = get_parser()
249
+ args = parser.parse_args()
250
+
251
+ params = get_params()
252
+ # add decoding params
253
+ params.update(get_decoding_params())
254
+ params.update(vars(args))
255
+ params.vocab_size = params.num_classes
256
+
257
+ logging.info(f"{params}")
258
+
259
+ device = torch.device("cpu")
260
+
261
+ logging.info(f"device: {device}")
262
+
263
+ model = torch.jit.load(args.model_filename)
264
+ model.to(device)
265
+ model.eval()
266
+
267
+ logging.info("Constructing Fbank computer")
268
+ opts = kaldifeat.FbankOptions()
269
+ opts.device = device
270
+ opts.frame_opts.dither = 0
271
+ opts.frame_opts.snip_edges = False
272
+ opts.frame_opts.samp_freq = params.sample_rate
273
+ opts.mel_opts.num_bins = params.feature_dim
274
+
275
+ fbank = kaldifeat.Fbank(opts)
276
+
277
+ logging.info(f"Reading sound files: {params.sound_files}")
278
+ waves = read_sound_files(
279
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
280
+ )
281
+ waves = [w.to(device) for w in waves]
282
+
283
+ logging.info("Decoding started")
284
+ features = fbank(waves)
285
+ feature_lengths = [f.size(0) for f in features]
286
+
287
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
288
+ feature_lengths = torch.tensor(feature_lengths, device=device)
289
+
290
+ nnet_output, _ = model(features, feature_lengths)
291
+
292
+ batch_size = nnet_output.shape[0]
293
+ supervision_segments = torch.tensor(
294
+ [
295
+ [i, 0, feature_lengths[i] // params.subsampling_factor]
296
+ for i in range(batch_size)
297
+ ],
298
+ dtype=torch.int32,
299
+ )
300
+
301
+ if params.method == "ctc-decoding":
302
+ logging.info("Use CTC decoding")
303
+ bpe_model = spm.SentencePieceProcessor()
304
+ bpe_model.load(params.bpe_model)
305
+ max_token_id = params.num_classes - 1
306
+
307
+ H = k2.ctc_topo(
308
+ max_token=max_token_id,
309
+ modified=False,
310
+ device=device,
311
+ )
312
+
313
+ lattice = get_lattice(
314
+ nnet_output=nnet_output,
315
+ decoding_graph=H,
316
+ supervision_segments=supervision_segments,
317
+ search_beam=params.search_beam,
318
+ output_beam=params.output_beam,
319
+ min_active_states=params.min_active_states,
320
+ max_active_states=params.max_active_states,
321
+ subsampling_factor=params.subsampling_factor,
322
+ )
323
+
324
+ best_path = one_best_decoding(
325
+ lattice=lattice, use_double_scores=params.use_double_scores
326
+ )
327
+ token_ids = get_texts(best_path)
328
+ hyps = bpe_model.decode(token_ids)
329
+ hyps = [s.split() for s in hyps]
330
+ elif params.method in [
331
+ "1best",
332
+ "nbest-rescoring",
333
+ "whole-lattice-rescoring",
334
+ ]:
335
+ logging.info(f"Loading HLG from {params.HLG}")
336
+ HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
337
+ HLG = HLG.to(device)
338
+ if not hasattr(HLG, "lm_scores"):
339
+ # For whole-lattice-rescoring and attention-decoder
340
+ HLG.lm_scores = HLG.scores.clone()
341
+
342
+ if params.method in [
343
+ "nbest-rescoring",
344
+ "whole-lattice-rescoring",
345
+ ]:
346
+ logging.info(f"Loading G from {params.G}")
347
+ G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
348
+ G = G.to(device)
349
+ if params.method == "whole-lattice-rescoring":
350
+ # Add epsilon self-loops to G as we will compose
351
+ # it with the whole lattice later
352
+ G = k2.add_epsilon_self_loops(G)
353
+ G = k2.arc_sort(G)
354
+
355
+ # G.lm_scores is used to replace HLG.lm_scores during
356
+ # LM rescoring.
357
+ G.lm_scores = G.scores.clone()
358
+
359
+ lattice = get_lattice(
360
+ nnet_output=nnet_output,
361
+ decoding_graph=HLG,
362
+ supervision_segments=supervision_segments,
363
+ search_beam=params.search_beam,
364
+ output_beam=params.output_beam,
365
+ min_active_states=params.min_active_states,
366
+ max_active_states=params.max_active_states,
367
+ subsampling_factor=params.subsampling_factor,
368
+ )
369
+
370
+ if params.method == "1best":
371
+ logging.info("Use HLG decoding")
372
+ best_path = one_best_decoding(
373
+ lattice=lattice, use_double_scores=params.use_double_scores
374
+ )
375
+ if params.method == "nbest-rescoring":
376
+ logging.info("Use HLG decoding + LM rescoring")
377
+ best_path_dict = rescore_with_n_best_list(
378
+ lattice=lattice,
379
+ G=G,
380
+ num_paths=params.num_paths,
381
+ lm_scale_list=[params.ngram_lm_scale],
382
+ nbest_scale=params.nbest_scale,
383
+ )
384
+ best_path = next(iter(best_path_dict.values()))
385
+ elif params.method == "whole-lattice-rescoring":
386
+ logging.info("Use HLG decoding + LM rescoring")
387
+ best_path_dict = rescore_with_whole_lattice(
388
+ lattice=lattice,
389
+ G_with_epsilon_loops=G,
390
+ lm_scale_list=[params.ngram_lm_scale],
391
+ )
392
+ best_path = next(iter(best_path_dict.values()))
393
+
394
+ hyps = get_texts(best_path)
395
+ word_sym_table = k2.SymbolTable.from_file(params.words_file)
396
+ hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
397
+ else:
398
+ raise ValueError(f"Unsupported decoding method: {params.method}")
399
+
400
+ s = "\n"
401
+ for filename, hyp in zip(params.sound_files, hyps):
402
+ words = " ".join(hyp)
403
+ s += f"{filename}:\n{words}\n\n"
404
+ logging.info(s)
405
+
406
+ logging.info("Decoding Done")
407
+
408
+
409
+ if __name__ == "__main__":
410
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
411
+
412
+ logging.basicConfig(format=formatter, level=logging.INFO)
413
+ main()
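
`jit_pretrained.py` decodes through k2 lattices (CTC topo or HLG). As a quick sanity check of the traced model one can instead do a plain CTC greedy decode (argmax, collapse repeats, drop blanks). The sketch below is not what the script does, only an illustrative alternative; the paths assume this repo's layout, and the torchaudio fbank here only approximates the kaldifeat setup used in the script.

```python
import torch
import torchaudio
import sentencepiece as spm

# Illustrative greedy CTC decode with the traced model.
model = torch.jit.load("err2020/conformer_ctc3/exp/jit_trace.pt")
model.eval()
sp = spm.SentencePieceProcessor()
sp.load("err2020/data/lang_bpe_500/bpe.model")

wave, sr = torchaudio.load("err2020/audio/emt16k.wav")  # 16 kHz mono expected
feats = torchaudio.compliance.kaldi.fbank(
    wave, num_mel_bins=80, sample_frequency=sr
)                                                       # (T, 80)
feats = feats.unsqueeze(0)                              # (1, T, 80)
feat_lens = torch.tensor([feats.size(1)], dtype=torch.int64)

with torch.no_grad():
    log_probs, _ = model(feats, feat_lens)              # (1, T', vocab)

ids = log_probs[0].argmax(dim=-1).tolist()
# Collapse repeated tokens, then drop the blank (id 0).
collapsed = [t for i, t in enumerate(ids) if i == 0 or t != ids[i - 1]]
tokens = [t for t in collapsed if t != 0]
print(sp.decode(tokens))
```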
err2020/conformer_ctc3/lstmp.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class LSTMP(nn.Module):
9
+ """LSTM with projection.
10
+
11
+ PyTorch does not support exporting LSTM with projection to ONNX.
12
+ This class reimplements LSTM with projection using basic matrix-matrix
13
+ and matrix-vector operations. It is not intended for training.
14
+ """
15
+
16
+ def __init__(self, lstm: nn.LSTM):
17
+ """
18
+ Args:
19
+ lstm:
20
+ LSTM with proj_size. We support only uni-directional,
21
+ 1-layer LSTM with projection at present.
22
+ """
23
+ super().__init__()
24
+ assert lstm.bidirectional is False, lstm.bidirectional
25
+ assert lstm.num_layers == 1, lstm.num_layers
26
+ assert 0 < lstm.proj_size < lstm.hidden_size, (
27
+ lstm.proj_size,
28
+ lstm.hidden_size,
29
+ )
30
+
31
+ assert lstm.batch_first is False, lstm.batch_first
32
+
33
+ state_dict = lstm.state_dict()
34
+
35
+ w_ih = state_dict["weight_ih_l0"]
36
+ w_hh = state_dict["weight_hh_l0"]
37
+
38
+ b_ih = state_dict["bias_ih_l0"]
39
+ b_hh = state_dict["bias_hh_l0"]
40
+
41
+ w_hr = state_dict["weight_hr_l0"]
42
+ self.input_size = lstm.input_size
43
+ self.proj_size = lstm.proj_size
44
+ self.hidden_size = lstm.hidden_size
45
+
46
+ self.w_ih = w_ih
47
+ self.w_hh = w_hh
48
+ self.b = b_ih + b_hh
49
+ self.w_hr = w_hr
50
+
51
+ def forward(
52
+ self,
53
+ input: torch.Tensor,
54
+ hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
55
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
56
+ """
57
+ Args:
58
+ input:
59
+ A tensor of shape [T, N, hidden_size]
60
+ hx:
61
+ A tuple containing:
62
+ - h0: a tensor of shape (1, N, proj_size)
63
+ - c0: a tensor of shape (1, N, hidden_size)
64
+ Returns:
65
+ Return a tuple containing:
66
+ - output: a tensor of shape (T, N, proj_size).
67
+ - A tuple containing:
68
+ - h: a tensor of shape (1, N, proj_size)
69
+ - c: a tensor of shape (1, N, hidden_size)
70
+
71
+ """
72
+ x_list = input.unbind(dim=0) # We use batch_first=False
73
+
74
+ if hx is not None:
75
+ h0, c0 = hx
76
+ else:
77
+ h0 = torch.zeros(1, input.size(1), self.proj_size)
78
+ c0 = torch.zeros(1, input.size(1), self.hidden_size)
79
+ h0 = h0.squeeze(0)
80
+ c0 = c0.squeeze(0)
81
+ y_list = []
82
+ for x in x_list:
83
+ gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh)
84
+ i, f, g, o = gates.chunk(4, dim=1)
85
+
86
+ i = i.sigmoid()
87
+ f = f.sigmoid()
88
+ g = g.tanh()
89
+ o = o.sigmoid()
90
+
91
+ c = f * c0 + i * g
92
+ h = o * c.tanh()
93
+
94
+ h = F.linear(h, self.w_hr)
95
+ y_list.append(h)
96
+
97
+ c0 = c
98
+ h0 = h
99
+
100
+ y = torch.stack(y_list, dim=0)
101
+
102
+ return y, (h0.unsqueeze(0), c0.unsqueeze(0))
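
Since `LSTMP` is meant to be a drop-in, export-friendly replacement for `nn.LSTM` with `proj_size`, one sanity check is to compare both modules on random input. A minimal sketch (not part of the recipe), assuming it runs from inside `err2020/conformer_ctc3`:

```python
import torch
import torch.nn as nn
from lstmp import LSTMP  # local module in this directory

torch.manual_seed(0)
lstm = nn.LSTM(input_size=64, hidden_size=128, proj_size=32,
               num_layers=1, batch_first=False)
lstmp = LSTMP(lstm)

x = torch.randn(10, 3, 64)  # (T, N, input_size), batch_first=False
y_ref, (h_ref, c_ref) = lstm(x)
y, (h, c) = lstmp(x)

# The re-implementation should match PyTorch's projected LSTM closely.
assert torch.allclose(y, y_ref, atol=1e-5)
assert torch.allclose(h, h_ref, atol=1e-5)
assert torch.allclose(c, c_ref, atol=1e-5)
```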
err2020/conformer_ctc3/model.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
2
+ # Wei Kang,
3
+ # Zengwei Yao)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ import math
21
+ from typing import Tuple
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from encoder_interface import EncoderInterface
26
+ from scaling import ScaledLinear
27
+
28
+
29
+ class CTCModel(nn.Module):
30
+ """It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
31
+ "Connectionist Temporal Classification: Labelling Unsegmented
32
+ Sequence Data with Recurrent Neural Networks"
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ encoder: EncoderInterface,
38
+ encoder_dim: int,
39
+ vocab_size: int,
40
+ ):
41
+ """
42
+ Args:
43
+ encoder:
44
+ It is the transcription network in the paper. It accepts
45
+ two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
46
+ It returns two tensors: `logits` of shape (N, T, encoder_dim) and
47
+ `logit_lens` of shape (N,).
48
+ encoder_dim:
49
+ The feature embedding dimension.
50
+ vocab_size:
51
+ The vocabulary size.
52
+ """
53
+ super().__init__()
54
+ assert isinstance(encoder, EncoderInterface), type(encoder)
55
+
56
+ self.encoder = encoder
57
+ self.ctc_output_module = nn.Sequential(
58
+ nn.Dropout(p=0.1),
59
+ ScaledLinear(encoder_dim, vocab_size),
60
+ )
61
+
62
+ def get_ctc_output(
63
+ self,
64
+ encoder_out: torch.Tensor,
65
+ delay_penalty: float = 0.0,
66
+ blank_threshold: float = 0.99,
67
+ ):
68
+ """Compute ctc log-prob and optionally (delay_penalty > 0) apply delay penalty.
69
+ We first split utterance into sub-utterances according to the
70
+ blank probs, and then add sawtooth-like "blank-bonus" values to
71
+ the blank probs.
72
+ See https://github.com/k2-fsa/icefall/pull/669 for details.
73
+
74
+ Args:
75
+ encoder_out:
76
+ A tensor with shape of (N, T, C).
77
+ delay_penalty:
78
+ A constant used to scale the delay penalty score.
79
+ blank_threshold:
80
+ The threshold used to split utterance into sub-utterances.
81
+ """
82
+ output = self.ctc_output_module(encoder_out)
83
+ log_prob = nn.functional.log_softmax(output, dim=-1)
84
+
85
+ if self.training and delay_penalty > 0:
86
+ T_arange = torch.arange(encoder_out.shape[1]).to(device=encoder_out.device)
87
+ # split into sub-utterances using the blank-id
88
+ mask = log_prob[:, :, 0] >= math.log(blank_threshold) # (B, T)
89
+ mask[:, 0] = True
90
+ cummax_out = (T_arange * mask).cummax(dim=-1)[0] # (B, T)
91
+ # the sawtooth "blank-bonus" value
92
+ penalty = T_arange - cummax_out # (B, T)
93
+ penalty_all = torch.zeros_like(log_prob)
94
+ penalty_all[:, :, 0] = delay_penalty * penalty
95
+ # apply latency penalty on probs
96
+ log_prob = log_prob + penalty_all
97
+
98
+ return log_prob
99
+
100
+ def forward(
101
+ self,
102
+ x: torch.Tensor,
103
+ x_lens: torch.Tensor,
104
+ warmup: float = 1.0,
105
+ delay_penalty: float = 0.0,
106
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
107
+ """
108
+ Args:
109
+ x:
110
+ A 3-D tensor of shape (N, T, C).
111
+ x_lens:
112
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
113
+ before padding.
114
+ warmup: a floating point value which increases throughout training;
115
+ values >= 1.0 are fully warmed up and have all modules present.
116
+ delay_penalty:
117
+ A constant used to scale the delay penalty score.
118
+ """
119
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
120
+ assert torch.all(encoder_out_lens > 0)
121
+ nnet_output = self.get_ctc_output(encoder_out, delay_penalty=delay_penalty)
122
+ return nnet_output, encoder_out_lens
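
The log-probabilities returned by `CTCModel.forward` are laid out as (N, T, C) with blank at index 0, which is the layout a CTC loss expects. The recipe itself builds the loss with k2 graphs in `train.py`; the snippet below is only a rough illustration with torch's built-in `ctc_loss` and made-up targets.

```python
import torch
import torch.nn.functional as F

# Illustrative only: shapes follow CTCModel.forward, data is random.
N, T, C = 2, 50, 500
nnet_output = torch.randn(N, T, C).log_softmax(dim=-1)  # (N, T, C) log-probs
encoder_out_lens = torch.tensor([50, 42])

targets = torch.randint(1, C, (N, 12))                  # made-up BPE token ids
target_lens = torch.tensor([12, 9])

loss = F.ctc_loss(
    nnet_output.permute(1, 0, 2),                       # ctc_loss wants (T, N, C)
    targets,
    encoder_out_lens,
    target_lens,
    blank=0,
    reduction="sum",
)
print(loss)
```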
err2020/conformer_ctc3/optim.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ import torch
21
+ from torch.optim import Optimizer
22
+
23
+
24
+ class Eve(Optimizer):
25
+ r"""
26
+ Implements Eve algorithm. This is a modified version of AdamW with a special
27
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
28
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
29
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
30
+ will be close to invariant to the absolute scale on the parameter matrix.
31
+
32
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
33
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
34
+ Eve is unpublished so far.
35
+
36
+ Arguments:
37
+ params (iterable): iterable of parameters to optimize or dicts defining
38
+ parameter groups
39
+ lr (float, optional): learning rate (default: 1e-3)
40
+ betas (Tuple[float, float], optional): coefficients used for computing
41
+ running averages of gradient and its square (default: (0.9, 0.999))
42
+ eps (float, optional): term added to the denominator to improve
43
+ numerical stability (default: 1e-8)
44
+ weight_decay (float, optional): weight decay coefficient (default: 3e-4;
45
+ this value means that the weight would decay significantly after
46
+ about 3k minibatches. Is not multiplied by learning rate, but
47
+ is conditional on RMS-value of parameter being > target_rms.
48
+ target_rms (float, optional): target root-mean-square value of
49
+ parameters, if they fall below this we will stop applying weight decay.
50
+
51
+
52
+ .. _Adam\: A Method for Stochastic Optimization:
53
+ https://arxiv.org/abs/1412.6980
54
+ .. _Decoupled Weight Decay Regularization:
55
+ https://arxiv.org/abs/1711.05101
56
+ .. _On the Convergence of Adam and Beyond:
57
+ https://openreview.net/forum?id=ryQu7f-RZ
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ params,
63
+ lr=1e-3,
64
+ betas=(0.9, 0.98),
65
+ eps=1e-8,
66
+ weight_decay=1e-3,
67
+ target_rms=0.1,
68
+ ):
69
+
70
+ if not 0.0 <= lr:
71
+ raise ValueError("Invalid learning rate: {}".format(lr))
72
+ if not 0.0 <= eps:
73
+ raise ValueError("Invalid epsilon value: {}".format(eps))
74
+ if not 0.0 <= betas[0] < 1.0:
75
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
76
+ if not 0.0 <= betas[1] < 1.0:
77
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
78
+ if not 0 <= weight_decay <= 0.1:
79
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
80
+ if not 0 < target_rms <= 10.0:
81
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
82
+ defaults = dict(
83
+ lr=lr,
84
+ betas=betas,
85
+ eps=eps,
86
+ weight_decay=weight_decay,
87
+ target_rms=target_rms,
88
+ )
89
+ super(Eve, self).__init__(params, defaults)
90
+
91
+ def __setstate__(self, state):
92
+ super(Eve, self).__setstate__(state)
93
+
94
+ @torch.no_grad()
95
+ def step(self, closure=None):
96
+ """Performs a single optimization step.
97
+
98
+ Arguments:
99
+ closure (callable, optional): A closure that reevaluates the model
100
+ and returns the loss.
101
+ """
102
+ loss = None
103
+ if closure is not None:
104
+ with torch.enable_grad():
105
+ loss = closure()
106
+
107
+ for group in self.param_groups:
108
+ for p in group["params"]:
109
+ if p.grad is None:
110
+ continue
111
+
112
+ # Perform optimization step
113
+ grad = p.grad
114
+ if grad.is_sparse:
115
+ raise RuntimeError("AdamW does not support sparse gradients")
116
+
117
+ state = self.state[p]
118
+
119
+ # State initialization
120
+ if len(state) == 0:
121
+ state["step"] = 0
122
+ # Exponential moving average of gradient values
123
+ state["exp_avg"] = torch.zeros_like(
124
+ p, memory_format=torch.preserve_format
125
+ )
126
+ # Exponential moving average of squared gradient values
127
+ state["exp_avg_sq"] = torch.zeros_like(
128
+ p, memory_format=torch.preserve_format
129
+ )
130
+
131
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
132
+
133
+ beta1, beta2 = group["betas"]
134
+
135
+ state["step"] += 1
136
+ bias_correction1 = 1 - beta1 ** state["step"]
137
+ bias_correction2 = 1 - beta2 ** state["step"]
138
+
139
+ # Decay the first and second moment running average coefficient
140
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
141
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
142
+ denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
143
+ group["eps"]
144
+ )
145
+
146
+ step_size = group["lr"] / bias_correction1
147
+ target_rms = group["target_rms"]
148
+ weight_decay = group["weight_decay"]
149
+
150
+ if p.numel() > 1:
151
+ # avoid applying this weight-decay on "scaling factors"
152
+ # (which are scalar).
153
+ is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
154
+ p.mul_(1 - (weight_decay * is_above_target_rms))
155
+ p.addcdiv_(exp_avg, denom, value=-step_size)
156
+
157
+ # Constrain the range of scalar weights
158
+ if p.numel() == 1:
159
+ p.clamp_(min=-10, max=2)
160
+
161
+ return loss
162
+
163
+
164
+ class LRScheduler(object):
165
+ """
166
+ Base-class for learning rate schedulers where the learning-rate depends on both the
167
+ batch and the epoch.
168
+ """
169
+
170
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
171
+ # Attach optimizer
172
+ if not isinstance(optimizer, Optimizer):
173
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
174
+ self.optimizer = optimizer
175
+ self.verbose = verbose
176
+
177
+ for group in optimizer.param_groups:
178
+ group.setdefault("initial_lr", group["lr"])
179
+
180
+ self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
181
+
182
+ self.epoch = 0
183
+ self.batch = 0
184
+
185
+ def state_dict(self):
186
+ """Returns the state of the scheduler as a :class:`dict`.
187
+
188
+ It contains an entry for every variable in self.__dict__ which
189
+ is not the optimizer.
190
+ """
191
+ return {
192
+ "base_lrs": self.base_lrs,
193
+ "epoch": self.epoch,
194
+ "batch": self.batch,
195
+ }
196
+
197
+ def load_state_dict(self, state_dict):
198
+ """Loads the schedulers state.
199
+
200
+ Args:
201
+ state_dict (dict): scheduler state. Should be an object returned
202
+ from a call to :meth:`state_dict`.
203
+ """
204
+ self.__dict__.update(state_dict)
205
+
206
+ def get_last_lr(self) -> List[float]:
207
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
208
+ return self._last_lr
209
+
210
+ def get_lr(self):
211
+ # Compute list of learning rates from self.epoch and self.batch and
212
+ # self.base_lrs; this must be overloaded by the user.
213
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
214
+ raise NotImplementedError
215
+
216
+ def step_batch(self, batch: Optional[int] = None) -> None:
217
+ # Step the batch index, or just set it. If `batch` is specified, it
218
+ # must be the batch index from the start of training, i.e. summed over
219
+ # all epochs.
220
+ # You can call this in any order; if you don't provide 'batch', it should
221
+ # of course be called once per batch.
222
+ if batch is not None:
223
+ self.batch = batch
224
+ else:
225
+ self.batch = self.batch + 1
226
+ self._set_lrs()
227
+
228
+ def step_epoch(self, epoch: Optional[int] = None):
229
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
230
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
231
+ # arg, you should call it at the end of the epoch.
232
+ if epoch is not None:
233
+ self.epoch = epoch
234
+ else:
235
+ self.epoch = self.epoch + 1
236
+ self._set_lrs()
237
+
238
+ def _set_lrs(self):
239
+ values = self.get_lr()
240
+ assert len(values) == len(self.optimizer.param_groups)
241
+
242
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
243
+ param_group, lr = data
244
+ param_group["lr"] = lr
245
+ self.print_lr(self.verbose, i, lr)
246
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
247
+
248
+ def print_lr(self, is_verbose, group, lr):
249
+ """Display the current learning rate."""
250
+ if is_verbose:
251
+ print(
252
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
253
+ f" of group {group} to {lr:.4e}."
254
+ )
255
+
256
+
257
+ class Eden(LRScheduler):
258
+ """
259
+ Eden scheduler.
260
+ lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
261
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
262
+
263
+ E.g. suggest initial-lr = 0.003 (passed to optimizer).
264
+
265
+ Args:
266
+ optimizer: the optimizer to change the learning rates on
267
+ lr_batches: the number of batches after which we start significantly
268
+ decreasing the learning rate, suggest 5000.
269
+ lr_epochs: the number of epochs after which we start significantly
270
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
271
+ 20 to 40 epochs, but may need smaller number if dataset is huge
272
+ and you will do few epochs.
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ optimizer: Optimizer,
278
+ lr_batches: Union[int, float],
279
+ lr_epochs: Union[int, float],
280
+ verbose: bool = False,
281
+ ):
282
+ super(Eden, self).__init__(optimizer, verbose)
283
+ self.lr_batches = lr_batches
284
+ self.lr_epochs = lr_epochs
285
+
286
+ def get_lr(self):
287
+ factor = (
288
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
289
+ ) ** -0.25 * (
290
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
291
+ )
292
+ return [x * factor for x in self.base_lrs]
293
+
294
+
295
+ def _test_eden():
296
+ m = torch.nn.Linear(100, 100)
297
+ optim = Eve(m.parameters(), lr=0.003)
298
+
299
+ scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)
300
+
301
+ for epoch in range(10):
302
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
303
+
304
+ for step in range(20):
305
+ x = torch.randn(200, 100).detach()
306
+ x.requires_grad = True
307
+ y = m(x)
308
+ dy = torch.randn(200, 100).detach()
309
+ f = (y * dy).sum()
310
+ f.backward()
311
+
312
+ optim.step()
313
+ scheduler.step_batch()
314
+ optim.zero_grad()
315
+ print("last lr = ", scheduler.get_last_lr())
316
+ print("state dict = ", scheduler.state_dict())
317
+
318
+
319
+ if __name__ == "__main__":
320
+ _test_eden()
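The Eden docstring above fully specifies the schedule, so its shape can be inspected outside of training. A minimal illustrative sketch follows (plain Python, no icefall imports; `lr_batches=5000`, `lr_epochs=6` and `initial-lr=0.003` are just the suggested values from the docstring, not necessarily the settings used to train this Estonian model):

```python
# Reproduce the Eden learning-rate factor quoted in the docstring above:
#   factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
#            * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
def eden_lr(initial_lr, batch, epoch, lr_batches=5000.0, lr_epochs=6.0):
    factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * (
        (epoch**2 + lr_epochs**2) / lr_epochs**2
    ) ** -0.25
    return initial_lr * factor

# The learning rate decays smoothly with both the batch index and the epoch index.
for epoch, batch in [(0, 0), (1, 5_000), (6, 30_000), (20, 100_000)]:
    print(f"epoch={epoch:2d} batch={batch:6d} lr={eden_lr(0.003, batch, epoch):.6f}")
```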
err2020/conformer_ctc3/pretrained.py ADDED
@@ -0,0 +1,461 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
3
+ # Mingshuang Luo,
4
+ # Zengwei Yao)
5
+ #
6
+ # See ../../../../LICENSE for clarification regarding multiple authors
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+
21
+ """
22
+ Usage (for non-streaming mode):
23
+
24
+ (1) ctc-decoding
25
+ ./conformer_ctc3/pretrained.py \
26
+ --checkpoint conformer_ctc3/exp/pretrained.pt \
27
+ --bpe-model data/lang_bpe_500/bpe.model \
28
+ --method ctc-decoding \
29
+ --sample-rate 16000 \
30
+ test_wavs/1089-134686-0001.wav
31
+
32
+ (2) 1best
33
+ ./conformer_ctc3/pretrained.py \
34
+ --checkpoint conformer_ctc3/exp/pretrained.pt \
35
+ --HLG data/lang_bpe_500/HLG.pt \
36
+ --words-file data/lang_bpe_500/words.txt \
37
+ --method 1best \
38
+ --sample-rate 16000 \
39
+ test_wavs/1089-134686-0001.wav
40
+
41
+ (3) nbest-rescoring
42
+ ./conformer_ctc3/pretrained.py \
43
+ --checkpoint conformer_ctc3/exp/pretrained.pt \
44
+ --HLG data/lang_bpe_500/HLG.pt \
45
+ --words-file data/lang_bpe_500/words.txt \
46
+ --G data/lm/G_4_gram.pt \
47
+ --method nbest-rescoring \
48
+ --sample-rate 16000 \
49
+ test_wavs/1089-134686-0001.wav
50
+
51
+ (4) whole-lattice-rescoring
52
+ ./conformer_ctc3/pretrained.py \
53
+ --checkpoint conformer_ctc3/exp/pretrained.pt \
54
+ --HLG data/lang_bpe_500/HLG.pt \
55
+ --words-file data/lang_bpe_500/words.txt \
56
+ --G data/lm/G_4_gram.pt \
57
+ --method whole-lattice-rescoring \
58
+ --sample-rate 16000 \
59
+ test_wavs/1089-134686-0001.wav
60
+ """
61
+
62
+
63
+ import argparse
64
+ import logging
65
+ import math
66
+ from typing import List
67
+
68
+ import k2
69
+ import kaldifeat
70
+ import sentencepiece as spm
71
+ import torch
72
+ import torchaudio
73
+ from decode import get_decoding_params
74
+ from torch.nn.utils.rnn import pad_sequence
75
+ from train import add_model_arguments, get_ctc_model, get_params
76
+
77
+ from icefall.decode import (
78
+ get_lattice,
79
+ one_best_decoding,
80
+ rescore_with_n_best_list,
81
+ rescore_with_whole_lattice,
82
+ )
83
+ from icefall.utils import get_texts, str2bool
84
+
85
+
86
+ def get_parser():
87
+ parser = argparse.ArgumentParser(
88
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
89
+ )
90
+
91
+ parser.add_argument(
92
+ "--checkpoint",
93
+ type=str,
94
+ required=True,
95
+ help="Path to the checkpoint. "
96
+ "The checkpoint is assumed to be saved by "
97
+ "icefall.checkpoint.save_checkpoint().",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--words-file",
102
+ type=str,
103
+ help="""Path to words.txt.
104
+ Used only when method is not ctc-decoding.
105
+ """,
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--HLG",
110
+ type=str,
111
+ help="""Path to HLG.pt.
112
+ Used only when method is not ctc-decoding.
113
+ """,
114
+ )
115
+
116
+ parser.add_argument(
117
+ "--bpe-model",
118
+ type=str,
119
+ help="""Path to bpe.model.
120
+ Used only when method is ctc-decoding.
121
+ """,
122
+ )
123
+
124
+ parser.add_argument(
125
+ "--method",
126
+ type=str,
127
+ default="1best",
128
+ help="""Decoding method.
129
+ Possible values are:
130
+ (0) ctc-decoding - Use CTC decoding. It uses a sentence
131
+ piece model, i.e., lang_dir/bpe.model, to convert
132
+ word pieces to words. It needs neither a lexicon
133
+ nor an n-gram LM.
134
+ (1) 1best - Use the best path as decoding output. Only
135
+ the transformer encoder output is used for decoding.
136
+ We call it HLG decoding.
137
+ (2) nbest-rescoring - Extract n paths from the decoding lattice,
138
+ rescore them with an LM; the path with
139
+ the highest score is the decoding result.
140
+ We call it HLG decoding + n-gram LM rescoring.
141
+ (3) whole-lattice-rescoring - Use an LM to rescore the
142
+ decoding lattice and then use 1best to decode the
143
+ rescored lattice.
144
+ We call it HLG decoding + n-gram LM rescoring.
145
+ """,
146
+ )
147
+
148
+ parser.add_argument(
149
+ "--G",
150
+ type=str,
151
+ help="""An LM for rescoring.
152
+ Used only when method is
153
+ whole-lattice-rescoring or nbest-rescoring.
154
+ It's usually a 4-gram LM.
155
+ """,
156
+ )
157
+
158
+ parser.add_argument(
159
+ "--num-paths",
160
+ type=int,
161
+ default=100,
162
+ help="""
163
+ Used only when method is nbest-rescoring.
164
+ It specifies the size of n-best list.""",
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--ngram-lm-scale",
169
+ type=float,
170
+ default=1.3,
171
+ help="""
172
+ Used only when method is whole-lattice-rescoring and nbest-rescoring.
173
+ It specifies the scale for n-gram LM scores.
174
+ (Note: You need to tune it on a dataset.)
175
+ """,
176
+ )
177
+
178
+ parser.add_argument(
179
+ "--nbest-scale",
180
+ type=float,
181
+ default=0.5,
182
+ help="""
183
+ Used only when method is nbest-rescoring.
184
+ It specifies the scale for lattice.scores when
185
+ extracting n-best lists. A smaller value results in
186
+ more unique number of paths with the risk of missing
187
+ the best path.
188
+ """,
189
+ )
190
+
191
+ parser.add_argument(
192
+ "--num-classes",
193
+ type=int,
194
+ default=500,
195
+ help="""
196
+ Vocab size in the BPE model.
197
+ """,
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--simulate-streaming",
202
+ type=str2bool,
203
+ default=False,
204
+ help="""Whether to simulate streaming in decoding, this is a good way to
205
+ test a streaming model.
206
+ """,
207
+ )
208
+
209
+ parser.add_argument(
210
+ "--decode-chunk-size",
211
+ type=int,
212
+ default=16,
213
+ help="The chunk size for decoding (in frames after subsampling)",
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--left-context",
218
+ type=int,
219
+ default=64,
220
+ help="left context can be seen during decoding (in frames after subsampling)",
221
+ )
222
+
223
+ parser.add_argument(
224
+ "--sample-rate",
225
+ type=int,
226
+ default=16000,
227
+ help="The sample rate of the input sound file",
228
+ )
229
+
230
+ parser.add_argument(
231
+ "sound_files",
232
+ type=str,
233
+ nargs="+",
234
+ help="The input sound file(s) to transcribe. "
235
+ "Supported formats are those supported by torchaudio.load(). "
236
+ "For example, wav and flac are supported. "
237
+ "The sample rate has to be 16kHz.",
238
+ )
239
+
240
+ add_model_arguments(parser)
241
+
242
+ return parser
243
+
244
+
245
+ def read_sound_files(
246
+ filenames: List[str], expected_sample_rate: float
247
+ ) -> List[torch.Tensor]:
248
+ """Read a list of sound files into a list 1-D float32 torch tensors.
249
+ Args:
250
+ filenames:
251
+ A list of sound filenames.
252
+ expected_sample_rate:
253
+ The expected sample rate of the sound files.
254
+ Returns:
255
+ Return a list of 1-D float32 torch tensors.
256
+ """
257
+ ans = []
258
+ for f in filenames:
259
+ wave, sample_rate = torchaudio.load(f)
260
+ assert sample_rate == expected_sample_rate, (
261
+ f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
262
+ )
263
+ # We use only the first channel
264
+ ans.append(wave[0])
265
+ return ans
266
+
267
+
268
+ def main():
269
+ parser = get_parser()
270
+ args = parser.parse_args()
271
+
272
+ params = get_params()
273
+ # add decoding params
274
+ params.update(get_decoding_params())
275
+ params.update(vars(args))
276
+ params.vocab_size = params.num_classes
277
+
278
+ if params.simulate_streaming:
279
+ assert (
280
+ params.causal_convolution
281
+ ), "Decoding in streaming requires causal convolution"
282
+
283
+ logging.info(f"{params}")
284
+
285
+ device = torch.device("cpu")
286
+ if torch.cuda.is_available():
287
+ device = torch.device("cuda", 0)
288
+
289
+ logging.info(f"device: {device}")
290
+
291
+ logging.info("About to create model")
292
+ model = get_ctc_model(params)
293
+
294
+ num_param = sum([p.numel() for p in model.parameters()])
295
+ logging.info(f"Number of model parameters: {num_param}")
296
+
297
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
298
+ model.load_state_dict(checkpoint["model"], strict=False)
299
+ model.to(device)
300
+ model.eval()
301
+
302
+ logging.info("Constructing Fbank computer")
303
+ opts = kaldifeat.FbankOptions()
304
+ opts.device = device
305
+ opts.frame_opts.dither = 0
306
+ opts.frame_opts.snip_edges = False
307
+ opts.frame_opts.samp_freq = params.sample_rate
308
+ opts.mel_opts.num_bins = params.feature_dim
309
+
310
+ fbank = kaldifeat.Fbank(opts)
311
+
312
+ logging.info(f"Reading sound files: {params.sound_files}")
313
+ waves = read_sound_files(
314
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
315
+ )
316
+ waves = [w.to(device) for w in waves]
317
+
318
+ logging.info("Decoding started")
319
+ features = fbank(waves)
320
+ feature_lengths = [f.size(0) for f in features]
321
+
322
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
323
+ feature_lengths = torch.tensor(feature_lengths, device=device)
324
+
325
+ # model forward
326
+ if params.simulate_streaming:
327
+ encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
328
+ x=features,
329
+ x_lens=feature_lengths,
330
+ chunk_size=params.decode_chunk_size,
331
+ left_context=params.left_context,
332
+ simulate_streaming=True,
333
+ )
334
+ else:
335
+ encoder_out, encoder_out_lens = model.encoder(
336
+ x=features, x_lens=feature_lengths
337
+ )
338
+ nnet_output = model.get_ctc_output(encoder_out)
339
+
340
+ batch_size = nnet_output.shape[0]
341
+ supervision_segments = torch.tensor(
342
+ [
343
+ [i, 0, feature_lengths[i] // params.subsampling_factor]
344
+ for i in range(batch_size)
345
+ ],
346
+ dtype=torch.int32,
347
+ )
348
+
349
+ if params.method == "ctc-decoding":
350
+ logging.info("Use CTC decoding")
351
+ bpe_model = spm.SentencePieceProcessor()
352
+ bpe_model.load(params.bpe_model)
353
+ max_token_id = params.num_classes - 1
354
+
355
+ H = k2.ctc_topo(
356
+ max_token=max_token_id,
357
+ modified=False,
358
+ device=device,
359
+ )
360
+
361
+ lattice = get_lattice(
362
+ nnet_output=nnet_output,
363
+ decoding_graph=H,
364
+ supervision_segments=supervision_segments,
365
+ search_beam=params.search_beam,
366
+ output_beam=params.output_beam,
367
+ min_active_states=params.min_active_states,
368
+ max_active_states=params.max_active_states,
369
+ subsampling_factor=params.subsampling_factor,
370
+ )
371
+
372
+ best_path = one_best_decoding(
373
+ lattice=lattice, use_double_scores=params.use_double_scores
374
+ )
375
+ token_ids = get_texts(best_path)
376
+ hyps = bpe_model.decode(token_ids)
377
+ hyps = [s.split() for s in hyps]
378
+ elif params.method in [
379
+ "1best",
380
+ "nbest-rescoring",
381
+ "whole-lattice-rescoring",
382
+ ]:
383
+ logging.info(f"Loading HLG from {params.HLG}")
384
+ HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
385
+ HLG = HLG.to(device)
386
+ if not hasattr(HLG, "lm_scores"):
387
+ # Needed for LM rescoring (nbest-rescoring and whole-lattice-rescoring)
388
+ HLG.lm_scores = HLG.scores.clone()
389
+
390
+ if params.method in [
391
+ "nbest-rescoring",
392
+ "whole-lattice-rescoring",
393
+ ]:
394
+ logging.info(f"Loading G from {params.G}")
395
+ G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
396
+ G = G.to(device)
397
+ if params.method == "whole-lattice-rescoring":
398
+ # Add epsilon self-loops to G as we will compose
399
+ # it with the whole lattice later
400
+ G = k2.add_epsilon_self_loops(G)
401
+ G = k2.arc_sort(G)
402
+
403
+ # G.lm_scores is used to replace HLG.lm_scores during
404
+ # LM rescoring.
405
+ G.lm_scores = G.scores.clone()
406
+
407
+ lattice = get_lattice(
408
+ nnet_output=nnet_output,
409
+ decoding_graph=HLG,
410
+ supervision_segments=supervision_segments,
411
+ search_beam=params.search_beam,
412
+ output_beam=params.output_beam,
413
+ min_active_states=params.min_active_states,
414
+ max_active_states=params.max_active_states,
415
+ subsampling_factor=params.subsampling_factor,
416
+ )
417
+
418
+ if params.method == "1best":
419
+ logging.info("Use HLG decoding")
420
+ best_path = one_best_decoding(
421
+ lattice=lattice, use_double_scores=params.use_double_scores
422
+ )
423
+ if params.method == "nbest-rescoring":
424
+ logging.info("Use HLG decoding + LM rescoring")
425
+ best_path_dict = rescore_with_n_best_list(
426
+ lattice=lattice,
427
+ G=G,
428
+ num_paths=params.num_paths,
429
+ lm_scale_list=[params.ngram_lm_scale],
430
+ nbest_scale=params.nbest_scale,
431
+ )
432
+ best_path = next(iter(best_path_dict.values()))
433
+ elif params.method == "whole-lattice-rescoring":
434
+ logging.info("Use HLG decoding + LM rescoring")
435
+ best_path_dict = rescore_with_whole_lattice(
436
+ lattice=lattice,
437
+ G_with_epsilon_loops=G,
438
+ lm_scale_list=[params.ngram_lm_scale],
439
+ )
440
+ best_path = next(iter(best_path_dict.values()))
441
+
442
+ hyps = get_texts(best_path)
443
+ word_sym_table = k2.SymbolTable.from_file(params.words_file)
444
+ hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
445
+ else:
446
+ raise ValueError(f"Unsupported decoding method: {params.method}")
447
+
448
+ s = "\n"
449
+ for filename, hyp in zip(params.sound_files, hyps):
450
+ words = " ".join(hyp)
451
+ s += f"{filename}:\n{words}\n\n"
452
+ logging.info(s)
453
+
454
+ logging.info("Decoding Done")
455
+
456
+
457
+ if __name__ == "__main__":
458
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
459
+
460
+ logging.basicConfig(format=formatter, level=logging.INFO)
461
+ main()
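For this particular repository, the `ctc-decoding` variant of the usage shown at the top of this file is the relevant one, since only a BPE model is shipped (no HLG or G). A hedged sketch of invoking it from Python follows; it assumes the icefall/k2/kaldifeat environment from docker/Dockerfile, and the checkpoint path is hypothetical (this commit ships `exp/jit_trace.pt`, presumably for `jit_pretrained.py`, rather than a `pretrained.pt`), while the BPE model and sample wav paths do exist in this repo:

```python
# Hypothetical invocation of pretrained.py in ctc-decoding mode from the repo root.
# The --checkpoint path is an assumption for illustration only; see
# err2020/conformer_ctc3_usage.ipynb in this commit for the intended usage.
import subprocess

cmd = [
    "python", "err2020/conformer_ctc3/pretrained.py",
    "--checkpoint", "err2020/conformer_ctc3/exp/pretrained.pt",  # assumed, not shipped here
    "--bpe-model", "err2020/data/lang_bpe_500/bpe.model",        # shipped in this repo
    "--method", "ctc-decoding",
    "--sample-rate", "16000",
    "err2020/audio/emt16k.wav",                                  # shipped in this repo
]
subprocess.run(cmd, check=True)
```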
err2020/conformer_ctc3/scaling.py ADDED
@@ -0,0 +1,1015 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import collections.abc
19
+ import random
20
+ from itertools import repeat
21
+ from typing import Optional, Tuple
22
+
23
+ import torch
24
+ import torch.backends.cudnn.rnn as rnn
25
+ import torch.nn as nn
26
+ from torch import _VF, Tensor
27
+
28
+ from icefall.utils import is_jit_tracing
29
+
30
+
31
+ def _ntuple(n):
32
+ def parse(x):
33
+ if isinstance(x, collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
34
+ return x
35
+ return tuple(repeat(x, n))
36
+
37
+ return parse
38
+
39
+
40
+ _single = _ntuple(1)
41
+ _pair = _ntuple(2)
42
+
43
+
44
+ class ActivationBalancerFunction(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(
47
+ ctx,
48
+ x: Tensor,
49
+ channel_dim: int,
50
+ min_positive: float, # e.g. 0.05
51
+ max_positive: float, # e.g. 0.95
52
+ max_factor: float, # e.g. 0.01
53
+ min_abs: float, # e.g. 0.2
54
+ max_abs: float, # e.g. 100.0
55
+ ) -> Tensor:
56
+ if x.requires_grad:
57
+ if channel_dim < 0:
58
+ channel_dim += x.ndim
59
+
60
+ # sum_dims = [d for d in range(x.ndim) if d != channel_dim]
61
+ # The above line is not torch scriptable for torch 1.6.0
62
+ # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa
63
+ sum_dims = []
64
+ for d in range(x.ndim):
65
+ if d != channel_dim:
66
+ sum_dims.append(d)
67
+
68
+ xgt0 = x > 0
69
+ proportion_positive = torch.mean(
70
+ xgt0.to(x.dtype), dim=sum_dims, keepdim=True
71
+ )
72
+ factor1 = (
73
+ (min_positive - proportion_positive).relu()
74
+ * (max_factor / min_positive)
75
+ if min_positive != 0.0
76
+ else 0.0
77
+ )
78
+ factor2 = (
79
+ (proportion_positive - max_positive).relu()
80
+ * (max_factor / (max_positive - 1.0))
81
+ if max_positive != 1.0
82
+ else 0.0
83
+ )
84
+ factor = factor1 + factor2
85
+ if isinstance(factor, float):
86
+ factor = torch.zeros_like(proportion_positive)
87
+
88
+ mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
89
+ below_threshold = mean_abs < min_abs
90
+ above_threshold = mean_abs > max_abs
91
+
92
+ ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
93
+ ctx.max_factor = max_factor
94
+ ctx.sum_dims = sum_dims
95
+ return x
96
+
97
+ @staticmethod
98
+ def backward(
99
+ ctx, x_grad: Tensor
100
+ ) -> Tuple[Tensor, None, None, None, None, None, None]:
101
+ factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
102
+ dtype = x_grad.dtype
103
+ scale_factor = (
104
+ (below_threshold.to(dtype) - above_threshold.to(dtype))
105
+ * (xgt0.to(dtype) - 0.5)
106
+ * (ctx.max_factor * 2.0)
107
+ )
108
+
109
+ neg_delta_grad = x_grad.abs() * (factor + scale_factor)
110
+ return x_grad - neg_delta_grad, None, None, None, None, None, None
111
+
112
+
113
+ class GradientFilterFunction(torch.autograd.Function):
114
+ @staticmethod
115
+ def forward(
116
+ ctx,
117
+ x: Tensor,
118
+ batch_dim: int, # e.g., 1
119
+ threshold: float, # e.g., 10.0
120
+ *params: Tensor, # module parameters
121
+ ) -> Tuple[Tensor, ...]:
122
+ if x.requires_grad:
123
+ if batch_dim < 0:
124
+ batch_dim += x.ndim
125
+ ctx.batch_dim = batch_dim
126
+ ctx.threshold = threshold
127
+ return (x,) + params
128
+
129
+ @staticmethod
130
+ def backward(
131
+ ctx,
132
+ x_grad: Tensor,
133
+ *param_grads: Tensor,
134
+ ) -> Tuple[Tensor, ...]:
135
+ eps = 1.0e-20
136
+ dim = ctx.batch_dim
137
+ norm_dims = [d for d in range(x_grad.ndim) if d != dim]
138
+ norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt()
139
+ median_norm = norm_of_batch.median()
140
+
141
+ cutoff = median_norm * ctx.threshold
142
+ inv_mask = (cutoff + norm_of_batch) / (cutoff + eps)
143
+ mask = 1.0 / (inv_mask + eps)
144
+ x_grad = x_grad * mask
145
+
146
+ avg_mask = 1.0 / (inv_mask.mean() + eps)
147
+ param_grads = [avg_mask * g for g in param_grads]
148
+
149
+ return (x_grad, None, None) + tuple(param_grads)
150
+
151
+
152
+ class GradientFilter(torch.nn.Module):
153
+ """This is used to filter out elements that have extremely large gradients
154
+ in batch and the module parameters with soft masks.
155
+
156
+ Args:
157
+ batch_dim (int):
158
+ The batch dimension.
159
+ threshold (float):
160
+ For each element in batch, its gradient will be
161
+ filtered out if the gradient norm is larger than
162
+ `grad_norm_threshold * median`, where `median` is the median
163
+ value of gradient norms of all elements in batch.
164
+ """
165
+
166
+ def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
167
+ super(GradientFilter, self).__init__()
168
+ self.batch_dim = batch_dim
169
+ self.threshold = threshold
170
+
171
+ def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]:
172
+ if torch.jit.is_scripting() or is_jit_tracing():
173
+ return (x,) + params
174
+ else:
175
+ return GradientFilterFunction.apply(
176
+ x,
177
+ self.batch_dim,
178
+ self.threshold,
179
+ *params,
180
+ )
181
+
182
+
183
+ class BasicNorm(torch.nn.Module):
184
+ """
185
+ This is intended to be a simpler, and hopefully cheaper, replacement for
186
+ LayerNorm. The observation this is based on is that Transformer-type
187
+ networks, especially with pre-norm, sometimes seem to set one of the
188
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
189
+ the LayerNorm because the output magnitude is then not strongly dependent
190
+ on the other (useful) features. Presumably the weight and bias of the
191
+ LayerNorm are required to allow it to do this.
192
+
193
+ So the idea is to introduce this large constant value as an explicit
194
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
195
+ doesn't have to do this trick. We make the "eps" learnable.
196
+
197
+ Args:
198
+ num_channels: the number of channels, e.g. 512.
199
+ channel_dim: the axis/dimension corresponding to the channel,
200
+ interpreted as an offset from the input's ndim if negative.
201
+ This is NOT the num_channels; it should typically be one of
202
+ {-2, -1, 0, 1, 2, 3}.
203
+ eps: the initial "epsilon" that we add as ballast in:
204
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
205
+ Note: our epsilon is actually large, but we keep the name
206
+ to indicate the connection with conventional LayerNorm.
207
+ learn_eps: if true, we learn epsilon; if false, we keep it
208
+ at the initial value.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ num_channels: int,
214
+ channel_dim: int = -1, # CAUTION: see documentation.
215
+ eps: float = 0.25,
216
+ learn_eps: bool = True,
217
+ ) -> None:
218
+ super(BasicNorm, self).__init__()
219
+ self.num_channels = num_channels
220
+ self.channel_dim = channel_dim
221
+ if learn_eps:
222
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
223
+ else:
224
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
225
+
226
+ def forward(self, x: Tensor) -> Tensor:
227
+ if not is_jit_tracing():
228
+ assert x.shape[self.channel_dim] == self.num_channels
229
+ scales = (
230
+ torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp()
231
+ ) ** -0.5
232
+ return x * scales
233
+
234
+
235
+ class ScaledLinear(nn.Linear):
236
+ """
237
+ A modified version of nn.Linear where the parameters are scaled before
238
+ use, via:
239
+ weight = self.weight * self.weight_scale.exp()
240
+ bias = self.bias * self.bias_scale.exp()
241
+
242
+ Args:
243
+ Accepts the standard args and kwargs that nn.Linear accepts
244
+ e.g. in_features, out_features, bias=False.
245
+
246
+ initial_scale: you can override this if you want to increase
247
+ or decrease the initial magnitude of the module's output
248
+ (affects the initialization of weight_scale and bias_scale).
249
+ Another option, if you want to do something like this, is
250
+ to re-initialize the parameters.
251
+ initial_speed: this affects how fast the parameter will
252
+ learn near the start of training; you can set it to a
253
+ value less than one if you suspect that a module
254
+ is contributing to instability near the start of training.
255
+ Note: regardless of the use of this option, it's best to
256
+ use schedulers like Noam that have a warm-up period.
257
+ Alternatively you can set it to more than 1 if you want it to
258
+ initially train faster. Must be greater than 0.
259
+ """
260
+
261
+ def __init__(
262
+ self,
263
+ *args,
264
+ initial_scale: float = 1.0,
265
+ initial_speed: float = 1.0,
266
+ **kwargs,
267
+ ):
268
+ super(ScaledLinear, self).__init__(*args, **kwargs)
269
+ initial_scale = torch.tensor(initial_scale).log()
270
+ self.weight_scale = nn.Parameter(initial_scale.clone().detach())
271
+ if self.bias is not None:
272
+ self.bias_scale = nn.Parameter(initial_scale.clone().detach())
273
+ else:
274
+ self.register_parameter("bias_scale", None)
275
+
276
+ self._reset_parameters(
277
+ initial_speed
278
+ ) # Overrides the reset_parameters in nn.Linear
279
+
280
+ def _reset_parameters(self, initial_speed: float):
281
+ std = 0.1 / initial_speed
282
+ a = (3**0.5) * std
283
+ nn.init.uniform_(self.weight, -a, a)
284
+ if self.bias is not None:
285
+ nn.init.constant_(self.bias, 0.0)
286
+ fan_in = self.weight.shape[1] * self.weight[0][0].numel()
287
+ scale = fan_in**-0.5 # 1/sqrt(fan_in)
288
+ with torch.no_grad():
289
+ self.weight_scale += torch.tensor(scale / std).log()
290
+
291
+ def get_weight(self):
292
+ return self.weight * self.weight_scale.exp()
293
+
294
+ def get_bias(self):
295
+ if self.bias is None or self.bias_scale is None:
296
+ return None
297
+ else:
298
+ return self.bias * self.bias_scale.exp()
299
+
300
+ def forward(self, input: Tensor) -> Tensor:
301
+ return torch.nn.functional.linear(input, self.get_weight(), self.get_bias())
302
+
303
+
304
+ class ScaledConv1d(nn.Conv1d):
305
+ # See docs for ScaledLinear
306
+ def __init__(
307
+ self,
308
+ *args,
309
+ initial_scale: float = 1.0,
310
+ initial_speed: float = 1.0,
311
+ **kwargs,
312
+ ):
313
+ super(ScaledConv1d, self).__init__(*args, **kwargs)
314
+ initial_scale = torch.tensor(initial_scale).log()
315
+
316
+ self.bias_scale: Optional[nn.Parameter] # for torchscript
317
+
318
+ self.weight_scale = nn.Parameter(initial_scale.clone().detach())
319
+ if self.bias is not None:
320
+ self.bias_scale = nn.Parameter(initial_scale.clone().detach())
321
+ else:
322
+ self.register_parameter("bias_scale", None)
323
+ self._reset_parameters(
324
+ initial_speed
325
+ ) # Overrides the reset_parameters in base class
326
+
327
+ def _reset_parameters(self, initial_speed: float):
328
+ std = 0.1 / initial_speed
329
+ a = (3**0.5) * std
330
+ nn.init.uniform_(self.weight, -a, a)
331
+ if self.bias is not None:
332
+ nn.init.constant_(self.bias, 0.0)
333
+ fan_in = self.weight.shape[1] * self.weight[0][0].numel()
334
+ scale = fan_in**-0.5 # 1/sqrt(fan_in)
335
+ with torch.no_grad():
336
+ self.weight_scale += torch.tensor(scale / std).log()
337
+
338
+ def get_weight(self):
339
+ return self.weight * self.weight_scale.exp()
340
+
341
+ def get_bias(self):
342
+ bias = self.bias
343
+ bias_scale = self.bias_scale
344
+ if bias is None or bias_scale is None:
345
+ return None
346
+ else:
347
+ return bias * bias_scale.exp()
348
+
349
+ def forward(self, input: Tensor) -> Tensor:
350
+ F = torch.nn.functional
351
+ if self.padding_mode != "zeros":
352
+ return F.conv1d(
353
+ F.pad(
354
+ input,
355
+ self._reversed_padding_repeated_twice,
356
+ mode=self.padding_mode,
357
+ ),
358
+ self.get_weight(),
359
+ self.get_bias(),
360
+ self.stride,
361
+ (0,),
362
+ self.dilation,
363
+ self.groups,
364
+ )
365
+ return F.conv1d(
366
+ input,
367
+ self.get_weight(),
368
+ self.get_bias(),
369
+ self.stride,
370
+ self.padding,
371
+ self.dilation,
372
+ self.groups,
373
+ )
374
+
375
+
376
+ class ScaledConv2d(nn.Conv2d):
377
+ # See docs for ScaledLinear
378
+ def __init__(
379
+ self,
380
+ *args,
381
+ initial_scale: float = 1.0,
382
+ initial_speed: float = 1.0,
383
+ **kwargs,
384
+ ):
385
+ super(ScaledConv2d, self).__init__(*args, **kwargs)
386
+ initial_scale = torch.tensor(initial_scale).log()
387
+ self.weight_scale = nn.Parameter(initial_scale.clone().detach())
388
+ if self.bias is not None:
389
+ self.bias_scale = nn.Parameter(initial_scale.clone().detach())
390
+ else:
391
+ self.register_parameter("bias_scale", None)
392
+ self._reset_parameters(
393
+ initial_speed
394
+ ) # Overrides the reset_parameters in base class
395
+
396
+ def _reset_parameters(self, initial_speed: float):
397
+ std = 0.1 / initial_speed
398
+ a = (3**0.5) * std
399
+ nn.init.uniform_(self.weight, -a, a)
400
+ if self.bias is not None:
401
+ nn.init.constant_(self.bias, 0.0)
402
+ fan_in = self.weight.shape[1] * self.weight[0][0].numel()
403
+ scale = fan_in**-0.5 # 1/sqrt(fan_in)
404
+ with torch.no_grad():
405
+ self.weight_scale += torch.tensor(scale / std).log()
406
+
407
+ def get_weight(self):
408
+ return self.weight * self.weight_scale.exp()
409
+
410
+ def get_bias(self):
411
+ # see https://github.com/pytorch/pytorch/issues/24135
412
+ bias = self.bias
413
+ bias_scale = self.bias_scale
414
+ if bias is None or bias_scale is None:
415
+ return None
416
+ else:
417
+ return bias * bias_scale.exp()
418
+
419
+ def _conv_forward(self, input, weight):
420
+ F = torch.nn.functional
421
+ if self.padding_mode != "zeros":
422
+ return F.conv2d(
423
+ F.pad(
424
+ input,
425
+ self._reversed_padding_repeated_twice,
426
+ mode=self.padding_mode,
427
+ ),
428
+ weight,
429
+ self.get_bias(),
430
+ self.stride,
431
+ (0, 0),
432
+ self.dilation,
433
+ self.groups,
434
+ )
435
+ return F.conv2d(
436
+ input,
437
+ weight,
438
+ self.get_bias(),
439
+ self.stride,
440
+ self.padding,
441
+ self.dilation,
442
+ self.groups,
443
+ )
444
+
445
+ def forward(self, input: Tensor) -> Tensor:
446
+ return self._conv_forward(input, self.get_weight())
447
+
448
+
449
+ class ScaledLSTM(nn.LSTM):
450
+ # See docs for ScaledLinear.
451
+ # This class implements LSTM with scaling mechanism, using `torch._VF.lstm`
452
+ # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
453
+ def __init__(
454
+ self,
455
+ *args,
456
+ initial_scale: float = 1.0,
457
+ initial_speed: float = 1.0,
458
+ grad_norm_threshold: float = 10.0,
459
+ **kwargs,
460
+ ):
461
+ if "bidirectional" in kwargs:
462
+ assert kwargs["bidirectional"] is False
463
+ super(ScaledLSTM, self).__init__(*args, **kwargs)
464
+ initial_scale = torch.tensor(initial_scale).log()
465
+ self._scales_names = []
466
+ self._scales = []
467
+ for name in self._flat_weights_names:
468
+ scale_name = name + "_scale"
469
+ self._scales_names.append(scale_name)
470
+ param = nn.Parameter(initial_scale.clone().detach())
471
+ setattr(self, scale_name, param)
472
+ self._scales.append(param)
473
+
474
+ self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold)
475
+
476
+ self._reset_parameters(
477
+ initial_speed
478
+ ) # Overrides the reset_parameters in base class
479
+
480
+ def _reset_parameters(self, initial_speed: float):
481
+ std = 0.1 / initial_speed
482
+ a = (3**0.5) * std
483
+ scale = self.hidden_size**-0.5
484
+ v = scale / std
485
+ for idx, name in enumerate(self._flat_weights_names):
486
+ if "weight" in name:
487
+ nn.init.uniform_(self._flat_weights[idx], -a, a)
488
+ with torch.no_grad():
489
+ self._scales[idx] += torch.tensor(v).log()
490
+ elif "bias" in name:
491
+ nn.init.constant_(self._flat_weights[idx], 0.0)
492
+
493
+ def _flatten_parameters(self, flat_weights) -> None:
494
+ """Resets parameter data pointer so that they can use faster code paths.
495
+
496
+ Right now, this works only if the module is on the GPU and cuDNN is enabled.
497
+ Otherwise, it's a no-op.
498
+
499
+ This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
500
+ """
501
+ # Short-circuits if _flat_weights is only partially instantiated
502
+ if len(flat_weights) != len(self._flat_weights_names):
503
+ return
504
+
505
+ for w in flat_weights:
506
+ if not isinstance(w, Tensor):
507
+ return
508
+ # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN
509
+ # or the tensors in flat_weights are of different dtypes
510
+
511
+ first_fw = flat_weights[0]
512
+ dtype = first_fw.dtype
513
+ for fw in flat_weights:
514
+ if (
515
+ not isinstance(fw.data, Tensor)
516
+ or not (fw.data.dtype == dtype)
517
+ or not fw.data.is_cuda
518
+ or not torch.backends.cudnn.is_acceptable(fw.data)
519
+ ):
520
+ return
521
+
522
+ # If any parameters alias, we fall back to the slower, copying code path. This is
523
+ # a sufficient check, because overlapping parameter buffers that don't completely
524
+ # alias would break the assumptions of the uniqueness check in
525
+ # Module.named_parameters().
526
+ unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
527
+ if len(unique_data_ptrs) != len(flat_weights):
528
+ return
529
+
530
+ with torch.cuda.device_of(first_fw):
531
+
532
+ # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
533
+ # an inplace operation on self._flat_weights
534
+ with torch.no_grad():
535
+ if torch._use_cudnn_rnn_flatten_weight():
536
+ num_weights = 4 if self.bias else 2
537
+ if self.proj_size > 0:
538
+ num_weights += 1
539
+ torch._cudnn_rnn_flatten_weight(
540
+ flat_weights,
541
+ num_weights,
542
+ self.input_size,
543
+ rnn.get_cudnn_mode(self.mode),
544
+ self.hidden_size,
545
+ self.proj_size,
546
+ self.num_layers,
547
+ self.batch_first,
548
+ bool(self.bidirectional),
549
+ )
550
+
551
+ def _get_flat_weights(self):
552
+ """Get scaled weights, and resets their data pointer."""
553
+ flat_weights = []
554
+ for idx in range(len(self._flat_weights_names)):
555
+ flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp())
556
+ self._flatten_parameters(flat_weights)
557
+ return flat_weights
558
+
559
+ def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None):
560
+ # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
561
+ # The change for calling `_VF.lstm()` is:
562
+ # self._flat_weights -> self._get_flat_weights()
563
+ if hx is None:
564
+ h_zeros = torch.zeros(
565
+ self.num_layers,
566
+ input.size(1),
567
+ self.proj_size if self.proj_size > 0 else self.hidden_size,
568
+ dtype=input.dtype,
569
+ device=input.device,
570
+ )
571
+ c_zeros = torch.zeros(
572
+ self.num_layers,
573
+ input.size(1),
574
+ self.hidden_size,
575
+ dtype=input.dtype,
576
+ device=input.device,
577
+ )
578
+ hx = (h_zeros, c_zeros)
579
+
580
+ self.check_forward_args(input, hx, None)
581
+
582
+ flat_weights = self._get_flat_weights()
583
+ input, *flat_weights = self.grad_filter(input, *flat_weights)
584
+
585
+ result = _VF.lstm(
586
+ input,
587
+ hx,
588
+ flat_weights,
589
+ self.bias,
590
+ self.num_layers,
591
+ self.dropout,
592
+ self.training,
593
+ self.bidirectional,
594
+ self.batch_first,
595
+ )
596
+
597
+ output = result[0]
598
+ hidden = result[1:]
599
+ return output, hidden
600
+
601
+
602
+ class ActivationBalancer(torch.nn.Module):
603
+ """
604
+ Modifies the backpropped derivatives of a function to try to encourage, for
605
+ each channel, that it is positive at least a proportion `threshold` of the
606
+ time. It does this by multiplying negative derivative values by up to
607
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
608
+ interpolated from 1 at the threshold to those extremal values when none
609
+ of the inputs are positive.
610
+
611
+
612
+ Args:
613
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
614
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
615
+ min_positive: the minimum, per channel, of the proportion of the time
616
+ that (x > 0), below which we start to modify the derivatives.
617
+ max_positive: the maximum, per channel, of the proportion of the time
618
+ that (x > 0), above which we start to modify the derivatives.
619
+ max_factor: the maximum factor by which we modify the derivatives for
620
+ either the sign constraint or the magnitude constraint;
621
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
622
+ values in the range [0.98..1.02].
623
+ min_abs: the minimum average-absolute-value per channel, which
624
+ we allow, before we start to modify the derivatives to prevent
625
+ this.
626
+ max_abs: the maximum average-absolute-value per channel, which
627
+ we allow, before we start to modify the derivatives to prevent
628
+ this.
629
+ balance_prob: the probability to apply the ActivationBalancer.
630
+ """
631
+
632
+ def __init__(
633
+ self,
634
+ channel_dim: int,
635
+ min_positive: float = 0.05,
636
+ max_positive: float = 0.95,
637
+ max_factor: float = 0.01,
638
+ min_abs: float = 0.2,
639
+ max_abs: float = 100.0,
640
+ balance_prob: float = 0.25,
641
+ ):
642
+ super(ActivationBalancer, self).__init__()
643
+ self.channel_dim = channel_dim
644
+ self.min_positive = min_positive
645
+ self.max_positive = max_positive
646
+ self.max_factor = max_factor
647
+ self.min_abs = min_abs
648
+ self.max_abs = max_abs
649
+ assert 0 < balance_prob <= 1, balance_prob
650
+ self.balance_prob = balance_prob
651
+
652
+ def forward(self, x: Tensor) -> Tensor:
653
+ if random.random() >= self.balance_prob:
654
+ return x
655
+
656
+ return ActivationBalancerFunction.apply(
657
+ x,
658
+ self.channel_dim,
659
+ self.min_positive,
660
+ self.max_positive,
661
+ self.max_factor / self.balance_prob,
662
+ self.min_abs,
663
+ self.max_abs,
664
+ )
665
+
666
+
667
+ class DoubleSwishFunction(torch.autograd.Function):
668
+ """
669
+ double_swish(x) = x * torch.sigmoid(x-1)
670
+ This is a definition, originally motivated by its close numerical
671
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
672
+
673
+ Memory-efficient derivative computation:
674
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
675
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
676
+ Now, s'(x) = s(x) * (1-s(x)).
677
+ double_swish'(x) = x * s'(x) + s(x).
678
+ = x * s(x) * (1-s(x)) + s(x).
679
+ = double_swish(x) * (1-s(x)) + s(x)
680
+ ... so we just need to remember s(x) but not x itself.
681
+ """
682
+
683
+ @staticmethod
684
+ def forward(ctx, x: Tensor) -> Tensor:
685
+ x = x.detach()
686
+ s = torch.sigmoid(x - 1.0)
687
+ y = x * s
688
+ ctx.save_for_backward(s, y)
689
+ return y
690
+
691
+ @staticmethod
692
+ def backward(ctx, y_grad: Tensor) -> Tensor:
693
+ s, y = ctx.saved_tensors
694
+ return (y * (1 - s) + s) * y_grad
695
+
696
+
697
+ class DoubleSwish(torch.nn.Module):
698
+ def forward(self, x: Tensor) -> Tensor:
699
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
700
+ that we approximate closely with x * sigmoid(x-1).
701
+ """
702
+ if torch.jit.is_scripting() or is_jit_tracing():
703
+ return x * torch.sigmoid(x - 1.0)
704
+ else:
705
+ return DoubleSwishFunction.apply(x)
706
+
707
+
708
+ class ScaledEmbedding(nn.Module):
709
+ r"""This is a modified version of nn.Embedding that introduces a learnable scale
710
+ on the parameters. Note: due to how we initialize it, it's best used with
711
+ schedulers like Noam that have a warmup period.
712
+
713
+ It is a simple lookup table that stores embeddings of a fixed dictionary and size.
714
+
715
+ This module is often used to store word embeddings and retrieve them using indices.
716
+ The input to the module is a list of indices, and the output is the corresponding
717
+ word embeddings.
718
+
719
+ Args:
720
+ num_embeddings (int): size of the dictionary of embeddings
721
+ embedding_dim (int): the size of each embedding vector
722
+ padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
723
+ (initialized to zeros) whenever it encounters the index.
724
+ scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
725
+ the words in the mini-batch. Default ``False``.
726
+ sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
727
+ See Notes for more details regarding sparse gradients.
728
+
729
+ initial_speed (float, optional): This affects how fast the parameter will
730
+ learn near the start of training; you can set it to a value less than
731
+ one if you suspect that a module is contributing to instability near
732
+ the start of training. Note: regardless of the use of this option,
733
+ it's best to use schedulers like Noam that have a warm-up period.
734
+ Alternatively you can set it to more than 1 if you want it to
735
+ initially train faster. Must be greater than 0.
736
+
737
+
738
+ Attributes:
739
+ weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
740
+ initialized from :math:`\mathcal{N}(0, 1)`
741
+
742
+ Shape:
743
+ - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
744
+ - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
745
+
746
+ .. note::
747
+ Keep in mind that only a limited number of optimizers support
748
+ sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
749
+ :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
750
+
751
+ .. note::
752
+ With :attr:`padding_idx` set, the embedding vector at
753
+ :attr:`padding_idx` is initialized to all zeros. However, note that this
754
+ vector can be modified afterwards, e.g., using a customized
755
+ initialization method, and thus changing the vector used to pad the
756
+ output. The gradient for this vector from :class:`~torch.nn.Embedding`
757
+ is always zero.
758
+
759
+ Examples::
760
+
761
+ >>> # an Embedding module containing 10 tensors of size 3
762
+ >>> embedding = nn.Embedding(10, 3)
763
+ >>> # a batch of 2 samples of 4 indices each
764
+ >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
765
+ >>> embedding(input)
766
+ tensor([[[-0.0251, -1.6902, 0.7172],
767
+ [-0.6431, 0.0748, 0.6969],
768
+ [ 1.4970, 1.3448, -0.9685],
769
+ [-0.3677, -2.7265, -0.1685]],
770
+
771
+ [[ 1.4970, 1.3448, -0.9685],
772
+ [ 0.4362, -0.4004, 0.9400],
773
+ [-0.6431, 0.0748, 0.6969],
774
+ [ 0.9124, -2.3616, 1.1151]]])
775
+
776
+
777
+ >>> # example with padding_idx
778
+ >>> embedding = nn.Embedding(10, 3, padding_idx=0)
779
+ >>> input = torch.LongTensor([[0,2,0,5]])
780
+ >>> embedding(input)
781
+ tensor([[[ 0.0000, 0.0000, 0.0000],
782
+ [ 0.1535, -2.0309, 0.9315],
783
+ [ 0.0000, 0.0000, 0.0000],
784
+ [-0.1655, 0.9897, 0.0635]]])
785
+
786
+ """
787
+ __constants__ = [
788
+ "num_embeddings",
789
+ "embedding_dim",
790
+ "padding_idx",
791
+ "scale_grad_by_freq",
792
+ "sparse",
793
+ ]
794
+
795
+ num_embeddings: int
796
+ embedding_dim: int
797
+ padding_idx: int
798
+ scale_grad_by_freq: bool
799
+ weight: Tensor
800
+ sparse: bool
801
+
802
+ def __init__(
803
+ self,
804
+ num_embeddings: int,
805
+ embedding_dim: int,
806
+ padding_idx: Optional[int] = None,
807
+ scale_grad_by_freq: bool = False,
808
+ sparse: bool = False,
809
+ initial_speed: float = 1.0,
810
+ ) -> None:
811
+ super(ScaledEmbedding, self).__init__()
812
+ self.num_embeddings = num_embeddings
813
+ self.embedding_dim = embedding_dim
814
+ if padding_idx is not None:
815
+ if padding_idx > 0:
816
+ assert (
817
+ padding_idx < self.num_embeddings
818
+ ), "Padding_idx must be within num_embeddings"
819
+ elif padding_idx < 0:
820
+ assert (
821
+ padding_idx >= -self.num_embeddings
822
+ ), "Padding_idx must be within num_embeddings"
823
+ padding_idx = self.num_embeddings + padding_idx
824
+ self.padding_idx = padding_idx
825
+ self.scale_grad_by_freq = scale_grad_by_freq
826
+
827
+ self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
828
+ self.sparse = sparse
829
+
830
+ self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
831
+ self.reset_parameters(initial_speed)
832
+
833
+ def reset_parameters(self, initial_speed: float = 1.0) -> None:
834
+ std = 0.1 / initial_speed
835
+ nn.init.normal_(self.weight, std=std)
836
+ nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
837
+
838
+ if self.padding_idx is not None:
839
+ with torch.no_grad():
840
+ self.weight[self.padding_idx].fill_(0)
841
+
842
+ def forward(self, input: Tensor) -> Tensor:
843
+ F = torch.nn.functional
844
+ scale = self.scale.exp()
845
+ if input.numel() < self.num_embeddings:
846
+ return (
847
+ F.embedding(
848
+ input,
849
+ self.weight,
850
+ self.padding_idx,
851
+ None,
852
+ 2.0, # None, 2.0 relate to normalization
853
+ self.scale_grad_by_freq,
854
+ self.sparse,
855
+ )
856
+ * scale
857
+ )
858
+ else:
859
+ return F.embedding(
860
+ input,
861
+ self.weight * scale,
862
+ self.padding_idx,
863
+ None,
864
+ 2.0, # None, 2.0 relates to normalization
865
+ self.scale_grad_by_freq,
866
+ self.sparse,
867
+ )
868
+
869
+ def extra_repr(self) -> str:
870
+ # s = "{num_embeddings}, {embedding_dim}, scale={scale}"
871
+ s = "{num_embeddings}, {embedding_dim}"
872
+ if self.padding_idx is not None:
873
+ s += ", padding_idx={padding_idx}"
874
+ if self.scale_grad_by_freq is not False:
875
+ s += ", scale_grad_by_freq={scale_grad_by_freq}"
876
+ if self.sparse is not False:
877
+ s += ", sparse=True"
878
+ return s.format(**self.__dict__)
879
+
880
+
881
+ def _test_activation_balancer_sign():
882
+ probs = torch.arange(0, 1, 0.01)
883
+ N = 1000
884
+ x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
885
+ x = x.detach()
886
+ x.requires_grad = True
887
+ m = ActivationBalancer(
888
+ channel_dim=0,
889
+ min_positive=0.05,
890
+ max_positive=0.95,
891
+ max_factor=0.2,
892
+ min_abs=0.0,
893
+ )
894
+
895
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
896
+
897
+ y = m(x)
898
+ y.backward(gradient=y_grad)
899
+ print("_test_activation_balancer_sign: x = ", x)
900
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
901
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
902
+
903
+
904
+ def _test_activation_balancer_magnitude():
905
+ magnitudes = torch.arange(0, 1, 0.01)
906
+ N = 1000
907
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
908
+ x = x.detach()
909
+ x.requires_grad = True
910
+ m = ActivationBalancer(
911
+ channel_dim=0,
912
+ min_positive=0.0,
913
+ max_positive=1.0,
914
+ max_factor=0.2,
915
+ min_abs=0.2,
916
+ max_abs=0.8,
917
+ )
918
+
919
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
920
+
921
+ y = m(x)
922
+ y.backward(gradient=y_grad)
923
+ print("_test_activation_balancer_magnitude: x = ", x)
924
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
925
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
926
+
927
+
928
+ def _test_basic_norm():
929
+ num_channels = 128
930
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
931
+
932
+ x = torch.randn(500, num_channels)
933
+
934
+ y = m(x)
935
+
936
+ assert y.shape == x.shape
937
+ x_rms = (x**2).mean().sqrt()
938
+ y_rms = (y**2).mean().sqrt()
939
+ print("x rms = ", x_rms)
940
+ print("y rms = ", y_rms)
941
+ assert y_rms < x_rms
942
+ assert y_rms > 0.5 * x_rms
943
+
944
+
945
+ def _test_double_swish_deriv():
946
+ x = torch.randn(10, 12, dtype=torch.double) * 0.5
947
+ x.requires_grad = True
948
+ m = DoubleSwish()
949
+ torch.autograd.gradcheck(m, x)
950
+
951
+
952
+ def _test_scaled_lstm():
953
+ N, L = 2, 30
954
+ dim_in, dim_hidden = 10, 20
955
+ m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True)
956
+ x = torch.randn(L, N, dim_in)
957
+ h0 = torch.randn(1, N, dim_hidden)
958
+ c0 = torch.randn(1, N, dim_hidden)
959
+ y, (h, c) = m(x, (h0, c0))
960
+ assert y.shape == (L, N, dim_hidden)
961
+ assert h.shape == (1, N, dim_hidden)
962
+ assert c.shape == (1, N, dim_hidden)
963
+
964
+
965
+ def _test_grad_filter():
966
+ threshold = 50.0
967
+ time, batch, channel = 200, 5, 128
968
+ grad_filter = GradientFilter(batch_dim=1, threshold=threshold)
969
+
970
+ for i in range(2):
971
+ x = torch.randn(time, batch, channel, requires_grad=True)
972
+ w = nn.Parameter(torch.ones(5))
973
+ b = nn.Parameter(torch.zeros(5))
974
+
975
+ x_out, w_out, b_out = grad_filter(x, w, b)
976
+
977
+ w_out_grad = torch.randn_like(w)
978
+ b_out_grad = torch.randn_like(b)
979
+ x_out_grad = torch.rand_like(x)
980
+ if i % 2 == 1:
981
+ # The gradient norm of the first element must be larger than
982
+ # `threshold * median`, where `median` is the median value
983
+ # of gradient norms of all elements in batch.
984
+ x_out_grad[:, 0, :] = torch.full((time, channel), threshold)
985
+
986
+ torch.autograd.backward(
987
+ [x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad]
988
+ )
989
+
990
+ print(
991
+ "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa
992
+ i % 2 == 1,
993
+ )
994
+
995
+ print(
996
+ "_test_grad_filter: x_out_grad norm = ",
997
+ (x_out_grad**2).mean(dim=(0, 2)).sqrt(),
998
+ )
999
+ print(
1000
+ "_test_grad_filter: x.grad norm = ",
1001
+ (x.grad**2).mean(dim=(0, 2)).sqrt(),
1002
+ )
1003
+ print("_test_grad_filter: w_out_grad = ", w_out_grad)
1004
+ print("_test_grad_filter: w.grad = ", w.grad)
1005
+ print("_test_grad_filter: b_out_grad = ", b_out_grad)
1006
+ print("_test_grad_filter: b.grad = ", b.grad)
1007
+
1008
+
1009
+ if __name__ == "__main__":
1010
+ _test_activation_balancer_sign()
1011
+ _test_activation_balancer_magnitude()
1012
+ _test_basic_norm()
1013
+ _test_double_swish_deriv()
1014
+ _test_scaled_lstm()
1015
+ _test_grad_filter()
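The derivative identity worked out in the DoubleSwishFunction docstring, and the BasicNorm scale formula, are easy to sanity-check independently of the classes above. A minimal sketch in plain PyTorch (double precision for the gradient check so autograd and the closed form agree tightly):

```python
import torch

# 1) DoubleSwishFunction docstring identity:
#      double_swish(x)  = x * s(x),  with s(x) = sigmoid(x - 1)
#      double_swish'(x) = double_swish(x) * (1 - s(x)) + s(x)
x = torch.randn(1000, dtype=torch.double, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s                                  # double_swish(x)
y.sum().backward()                         # autograd derivative at every element
closed_form = y.detach() * (1.0 - s.detach()) + s.detach()
print("double_swish' matches autograd:", torch.allclose(x.grad, closed_form))

# 2) BasicNorm docstring formula, with the default initial epsilon of 0.25:
#      scale = ((x**2).mean(dim=channel_dim, keepdim=True) + eps) ** -0.5
z = torch.randn(8, 128)
scales = (z.pow(2).mean(dim=-1, keepdim=True) + 0.25) ** -0.5
normed = z * scales
print("per-row rms after BasicNorm-style scaling:", normed.pow(2).mean(dim=-1).sqrt()[:3])
```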
err2020/conformer_ctc3/test_model.py ADDED
@@ -0,0 +1,82 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ """
20
+ To run this file, do:
21
+
22
+ cd icefall/egs/librispeech/ASR
23
+ python ./conformer_ctc3/test_model.py
24
+ """
25
+
26
+ import torch
27
+
28
+ from train import get_params, get_ctc_model
29
+
30
+
31
+ def test_model():
32
+ params = get_params()
33
+ params.vocab_size = 500
34
+ params.blank_id = 0
35
+ params.context_size = 2
36
+ params.unk_id = 2
37
+
38
+ params.dynamic_chunk_training = False
39
+ params.short_chunk_size = 25
40
+ params.num_left_chunks = 4
41
+ params.causal_convolution = False
42
+
43
+ model = get_ctc_model(params)
44
+
45
+ num_param = sum([p.numel() for p in model.parameters()])
46
+ print(f"Number of model parameters: {num_param}")
47
+
48
+ features = torch.randn(2, 100, 80)
49
+ feature_lengths = torch.full((2,), 100)
50
+ model(x=features, x_lens=feature_lengths)
51
+
52
+
53
+ def test_model_streaming():
54
+ params = get_params()
55
+ params.vocab_size = 500
56
+ params.blank_id = 0
57
+ params.context_size = 2
58
+ params.unk_id = 2
59
+
60
+ params.dynamic_chunk_training = True
61
+ params.short_chunk_size = 25
62
+ params.num_left_chunks = 4
63
+ params.causal_convolution = True
64
+
65
+ model = get_ctc_model(params)
66
+
67
+ num_param = sum([p.numel() for p in model.parameters()])
68
+ print(f"Number of model parameters: {num_param}")
69
+
70
+ features = torch.randn(2, 100, 80)
71
+ feature_lengths = torch.full((2,), 100)
72
+ encoder_out, _ = model.encoder(x=features, x_lens=feature_lengths)
73
+ model.get_ctc_output(encoder_out)
74
+
75
+
76
+ def main():
77
+ test_model()
78
+ test_model_streaming()
79
+
80
+
81
+ if __name__ == "__main__":
82
+ main()
err2020/conformer_ctc3/train.py ADDED
@@ -0,0 +1,1109 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
3
+ # Wei Kang,
4
+ # Mingshuang Luo,
5
+ # Zengwei Yao)
6
+ #
7
+ # See ../../../../LICENSE for clarification regarding multiple authors
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """
21
+ Usage:
22
+
23
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
24
+
25
+ ./conformer_ctc3/train.py \
26
+ --world-size 4 \
27
+ --num-epochs 30 \
28
+ --start-epoch 1 \
29
+ --exp-dir conformer_ctc3/exp \
30
+ --full-libri 1 \
31
+ --max-duration 300
32
+
33
+ # For mixed precision training:
34
+
35
+ ./conformer_ctc3/train.py \
36
+ --world-size 4 \
37
+ --num-epochs 30 \
38
+ --start-epoch 1 \
39
+ --use-fp16 1 \
40
+ --exp-dir conformer_ctc3/exp \
41
+ --full-libri 1 \
42
+ --max-duration 550
43
+
44
+ # train a streaming model
45
+ ./conformer_ctc3/train.py \
46
+ --world-size 4 \
47
+ --num-epochs 30 \
48
+ --start-epoch 1 \
49
+ --exp-dir conformer_ctc3/exp \
50
+ --full-libri 1 \
51
+ --dynamic-chunk-training 1 \
52
+ --causal-convolution 1 \
53
+ --short-chunk-size 25 \
54
+ --num-left-chunks 4 \
55
+ --max-duration 300 \
56
+ --delay-penalty 0.0
57
+ """
58
+
59
+ import argparse
60
+ import copy
61
+ import logging
62
+ from pathlib import Path
63
+ from shutil import copyfile
64
+ from typing import Any, Dict, Optional, Tuple, Union
65
+
66
+ import k2
67
+ import optim
68
+ import torch
69
+ import torch.multiprocessing as mp
70
+ import torch.nn as nn
71
+ from asr_datamodule import LibriSpeechAsrDataModule
72
+ from conformer import Conformer
73
+ from lhotse.cut import Cut
74
+ from lhotse.dataset.sampling.base import CutSampler
75
+ from lhotse.utils import fix_random_seed
76
+ from model import CTCModel
77
+ from optim import Eden, Eve
78
+ from torch import Tensor
79
+ from torch.cuda.amp import GradScaler
80
+ from torch.nn.parallel import DistributedDataParallel as DDP
81
+ from torch.utils.tensorboard import SummaryWriter
82
+
83
+ from icefall import diagnostics
84
+ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
85
+ from icefall.checkpoint import load_checkpoint, remove_checkpoints
86
+ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
87
+ from icefall.checkpoint import (
88
+ save_checkpoint_with_global_batch_idx,
89
+ update_averaged_model,
90
+ )
91
+ from icefall.dist import cleanup_dist, setup_dist
92
+ from icefall.env import get_env_info
93
+ from icefall.graph_compiler import CtcTrainingGraphCompiler
94
+ from icefall.lexicon import Lexicon
95
+ from icefall.utils import (
96
+ AttributeDict,
97
+ MetricsTracker,
98
+ encode_supervisions,
99
+ setup_logger,
100
+ str2bool,
101
+ )
102
+
103
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
104
+
105
+
106
+ def add_model_arguments(parser: argparse.ArgumentParser):
107
+ parser.add_argument(
108
+ "--dynamic-chunk-training",
109
+ type=str2bool,
110
+ default=False,
111
+ help="""Whether to use dynamic_chunk_training. This must be True
112
+ if you want a streaming model.
113
+ """,
114
+ )
115
+
116
+ parser.add_argument(
117
+ "--causal-convolution",
118
+ type=str2bool,
119
+ default=False,
120
+ help="""Whether to use causal convolution. This must be True when
121
+ using dynamic_chunk_training.
122
+ """,
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--short-chunk-size",
127
+ type=int,
128
+ default=25,
129
+ help="""Chunk length for dynamic-chunk training. The chunk size is either the
130
+ max sequence length of the current batch or uniformly sampled from (1, short_chunk_size).
131
+ """,
132
+ )
133
+
134
+ parser.add_argument(
135
+ "--num-left-chunks",
136
+ type=int,
137
+ default=4,
138
+ help="How many left-context chunks can be seen when calculating attention.",
139
+ )
140
+
141
+
142
+ def get_parser():
143
+ parser = argparse.ArgumentParser(
144
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--world-size",
149
+ type=int,
150
+ default=1,
151
+ help="Number of GPUs for DDP training.",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--master-port",
156
+ type=int,
157
+ default=12354,
158
+ help="Master port to use for DDP training.",
159
+ )
160
+
161
+ parser.add_argument(
162
+ "--tensorboard",
163
+ type=str2bool,
164
+ default=True,
165
+ help="Should various information be logged in tensorboard.",
166
+ )
167
+
168
+ parser.add_argument(
169
+ "--num-epochs",
170
+ type=int,
171
+ default=30,
172
+ help="Number of epochs to train.",
173
+ )
174
+
175
+ parser.add_argument(
176
+ "--start-epoch",
177
+ type=int,
178
+ default=1,
179
+ help="""Resume training from this epoch. It should be positive.
180
+ If larger than 1, it will load checkpoint from
181
+ exp-dir/epoch-{start_epoch-1}.pt
182
+ """,
183
+ )
184
+
185
+ parser.add_argument(
186
+ "--start-batch",
187
+ type=int,
188
+ default=0,
189
+ help="""If positive, --start-epoch is ignored and
190
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
191
+ """,
192
+ )
193
+
194
+ parser.add_argument(
195
+ "--exp-dir",
196
+ type=str,
197
+ default="conformer_ctc3/exp",
198
+ help="""The experiment dir.
199
+ It specifies the directory where all training related
200
+ files, e.g., checkpoints, log, etc, are saved
201
+ """,
202
+ )
203
+
204
+ parser.add_argument(
205
+ "--lang-dir",
206
+ type=str,
207
+ default="data/lang_bpe_500",
208
+ help="""The lang dir.
209
+ It contains language-related input files such as
210
+ "lexicon.txt"
211
+ """,
212
+ )
213
+
214
+ parser.add_argument(
215
+ "--initial-lr",
216
+ type=float,
217
+ default=0.003,
218
+ help="""The initial learning rate. This value should not need to be
219
+ changed.""",
220
+ )
221
+
222
+ parser.add_argument(
223
+ "--lr-batches",
224
+ type=float,
225
+ default=5000,
226
+ help="""Number of steps that affects how rapidly the learning rate decreases.
227
+ We suggest not to change this.""",
228
+ )
229
+
230
+ parser.add_argument(
231
+ "--lr-epochs",
232
+ type=float,
233
+ default=6,
234
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
235
+ """,
236
+ )
237
+
238
+ parser.add_argument(
239
+ "--seed",
240
+ type=int,
241
+ default=42,
242
+ help="The seed for random generators intended for reproducibility",
243
+ )
244
+
245
+ parser.add_argument(
246
+ "--print-diagnostics",
247
+ type=str2bool,
248
+ default=False,
249
+ help="Accumulate stats on activations, print them and exit.",
250
+ )
251
+
252
+ parser.add_argument(
253
+ "--save-every-n",
254
+ type=int,
255
+ default=8000,
256
+ help="""Save a checkpoint after processing this number of batches
257
+ periodically. We save checkpoint to exp-dir/ whenever
258
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
259
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
260
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
261
+ end of each epoch where `xxx` is the epoch number counting from 1.
262
+ """,
263
+ )
264
+
265
+ parser.add_argument(
266
+ "--keep-last-k",
267
+ type=int,
268
+ default=20,
269
+ help="""Only keep this number of checkpoints on disk.
270
+ For instance, if it is 3, there are only 3 checkpoints
271
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
272
+ It does not affect checkpoints with name `epoch-xxx.pt`.
273
+ """,
274
+ )
275
+
276
+ parser.add_argument(
277
+ "--average-period",
278
+ type=int,
279
+ default=100,
280
+ help="""Update the averaged model, namely `model_avg`, after processing
281
+ this number of batches. `model_avg` is a separate version of model,
282
+ in which each floating-point parameter is the average of all the
283
+ parameters from the start of training. Each time we take the average,
284
+ we do: `model_avg = model * (average_period / batch_idx_train) +
285
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
286
+ """,
287
+ )
288
+
289
+ parser.add_argument(
290
+ "--use-fp16",
291
+ type=str2bool,
292
+ default=False,
293
+ help="Whether to use half precision training.",
294
+ )
295
+
296
+ parser.add_argument(
297
+ "--delay-penalty",
298
+ type=float,
299
+ default=0.0,
300
+ help="""A constant used to scale the symbol delay penalty,
301
+ to encourage symbols to be emitted earlier for streaming models.
302
+ It is almost the same as the `delay_penalty` in our `rnnt_loss`, See
303
+ https://github.com/k2-fsa/k2/issues/955 and
304
+ https://arxiv.org/pdf/2211.00490.pdf for more details.""",
305
+ )
306
+
307
+ parser.add_argument(
308
+ "--nnet-delay-penalty",
309
+ type=float,
310
+ default=0.0,
311
+ help="""A constant to penalize symbol delay, which is applied on
312
+ the nnet_output after log-softmax.
313
+ We recommend using --delay-penalty instead.
314
+ See https://github.com/k2-fsa/icefall/pull/669 for details.""",
315
+ )
316
+
317
+ add_model_arguments(parser)
318
+
319
+ return parser
320
+
321
+
322
+ def get_params() -> AttributeDict:
323
+ """Return a dict containing training parameters.
324
+
325
+ All training related parameters that are not passed from the commandline
326
+ are saved in the variable `params`.
327
+
328
+ Commandline options are merged into `params` after they are parsed, so
329
+ you can also access them via `params`.
330
+
331
+ Explanation of options saved in `params`:
332
+
333
+ - best_train_loss: Best training loss so far. It is used to select
334
+ the model that has the lowest training loss. It is
335
+ updated during the training.
336
+
337
+ - best_valid_loss: Best validation loss so far. It is used to select
338
+ the model that has the lowest validation loss. It is
339
+ updated during the training.
340
+
341
+ - best_train_epoch: It is the epoch that has the best training loss.
342
+
343
+ - best_valid_epoch: It is the epoch that has the best validation loss.
344
+
345
+ - batch_idx_train: Used for writing statistics to tensorboard. It
346
+ contains number of batches trained so far across
347
+ epochs.
348
+
349
+ - log_interval: Print training loss if batch_idx % log_interval is 0
350
+
351
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
352
+
353
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
354
+
355
+ - feature_dim: The model input dim. It has to match the one used
356
+ in computing features.
357
+
358
+ - subsampling_factor: The subsampling factor for the model.
359
+
360
+ - encoder_dim: Hidden dim for multi-head attention model.
361
+
362
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
363
+
364
+ - warm_step: The warm_step for Noam optimizer.
365
+ """
366
+ params = AttributeDict(
367
+ {
368
+ "best_train_loss": float("inf"),
369
+ "best_valid_loss": float("inf"),
370
+ "best_train_epoch": -1,
371
+ "best_valid_epoch": -1,
372
+ "batch_idx_train": 0,
373
+ "log_interval": 50,
374
+ "reset_interval": 200,
375
+ "valid_interval": 3000, # For the 100h subset, use 800
376
+ # parameters for conformer
377
+ "feature_dim": 80,
378
+ "subsampling_factor": 4,
379
+ "encoder_dim": 512,
380
+ "nhead": 8,
381
+ "dim_feedforward": 2048,
382
+ "num_encoder_layers": 12,
383
+ # parameters for loss
384
+ "beam_size": 10,
385
+ "reduction": "none",
386
+ "use_double_scores": True,
387
+ # parameters for Noam
388
+ "model_warm_step": 3000, # arg given to model, not for lrate
389
+ "env_info": get_env_info(),
390
+ }
391
+ )
392
+
393
+ return params
394
+
395
+
396
+ def get_encoder_model(params: AttributeDict) -> nn.Module:
397
+ # TODO: We can add an option to switch between Conformer and Transformer
398
+ encoder = Conformer(
399
+ num_features=params.feature_dim,
400
+ subsampling_factor=params.subsampling_factor,
401
+ d_model=params.encoder_dim,
402
+ nhead=params.nhead,
403
+ dim_feedforward=params.dim_feedforward,
404
+ num_encoder_layers=params.num_encoder_layers,
405
+ dynamic_chunk_training=params.dynamic_chunk_training,
406
+ short_chunk_size=params.short_chunk_size,
407
+ num_left_chunks=params.num_left_chunks,
408
+ causal=params.causal_convolution,
409
+ )
410
+ return encoder
411
+
412
+
413
+ def get_ctc_model(params: AttributeDict) -> nn.Module:
414
+ encoder = get_encoder_model(params)
415
+ model = CTCModel(
416
+ encoder=encoder,
417
+ encoder_dim=params.encoder_dim,
418
+ vocab_size=params.vocab_size,
419
+ )
420
+ return model
421
+
422
+
423
+ def load_checkpoint_if_available(
424
+ params: AttributeDict,
425
+ model: nn.Module,
426
+ model_avg: nn.Module = None,
427
+ optimizer: Optional[torch.optim.Optimizer] = None,
428
+ scheduler: Optional[LRSchedulerType] = None,
429
+ ) -> Optional[Dict[str, Any]]:
430
+ """Load checkpoint from file.
431
+
432
+ If params.start_batch is positive, it will load the checkpoint from
433
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
434
+ params.start_epoch is larger than 1, it will load the checkpoint from
435
+ `params.start_epoch - 1`.
436
+
437
+ Apart from loading state dict for `model` and `optimizer` it also updates
438
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
439
+ and `best_valid_loss` in `params`.
440
+
441
+ Args:
442
+ params:
443
+ The return value of :func:`get_params`.
444
+ model:
445
+ The training model.
446
+ model_avg:
447
+ The stored model averaged from the start of training.
448
+ optimizer:
449
+ The optimizer that we are using.
450
+ scheduler:
451
+ The scheduler that we are using.
452
+ Returns:
453
+ Return a dict containing previously saved training info.
454
+ """
455
+ if params.start_batch > 0:
456
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
457
+ elif params.start_epoch > 1:
458
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
459
+ else:
460
+ return None
461
+
462
+ assert filename.is_file(), f"{filename} does not exist!"
463
+
464
+ saved_params = load_checkpoint(
465
+ filename,
466
+ model=model,
467
+ model_avg=model_avg,
468
+ optimizer=optimizer,
469
+ scheduler=scheduler,
470
+ )
471
+
472
+ keys = [
473
+ "best_train_epoch",
474
+ "best_valid_epoch",
475
+ "batch_idx_train",
476
+ "best_train_loss",
477
+ "best_valid_loss",
478
+ ]
479
+ for k in keys:
480
+ params[k] = saved_params[k]
481
+
482
+ if params.start_batch > 0:
483
+ if "cur_epoch" in saved_params:
484
+ params["start_epoch"] = saved_params["cur_epoch"]
485
+
486
+ return saved_params
487
+
488
+
489
+ def save_checkpoint(
490
+ params: AttributeDict,
491
+ model: Union[nn.Module, DDP],
492
+ model_avg: Optional[nn.Module] = None,
493
+ optimizer: Optional[torch.optim.Optimizer] = None,
494
+ scheduler: Optional[LRSchedulerType] = None,
495
+ sampler: Optional[CutSampler] = None,
496
+ scaler: Optional[GradScaler] = None,
497
+ rank: int = 0,
498
+ ) -> None:
499
+ """Save model, optimizer, scheduler and training stats to file.
500
+
501
+ Args:
502
+ params:
503
+ It is returned by :func:`get_params`.
504
+ model:
505
+ The training model.
506
+ model_avg:
507
+ The stored model averaged from the start of training.
508
+ optimizer:
509
+ The optimizer used in the training.
510
+ sampler:
511
+ The sampler for the training dataset.
512
+ scaler:
513
+ The scaler used for mix precision training.
514
+ """
515
+ if rank != 0:
516
+ return
517
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
518
+ save_checkpoint_impl(
519
+ filename=filename,
520
+ model=model,
521
+ model_avg=model_avg,
522
+ params=params,
523
+ optimizer=optimizer,
524
+ scheduler=scheduler,
525
+ sampler=sampler,
526
+ scaler=scaler,
527
+ rank=rank,
528
+ )
529
+
530
+ if params.best_train_epoch == params.cur_epoch:
531
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
532
+ copyfile(src=filename, dst=best_train_filename)
533
+
534
+ if params.best_valid_epoch == params.cur_epoch:
535
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
536
+ copyfile(src=filename, dst=best_valid_filename)
537
+
538
+
539
+ def compute_loss(
540
+ params: AttributeDict,
541
+ model: Union[nn.Module, DDP],
542
+ graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
543
+ batch: dict,
544
+ is_training: bool,
545
+ warmup: float = 1.0,
546
+ ) -> Tuple[Tensor, MetricsTracker]:
547
+ """
548
+ Compute CTC loss given the model and its inputs.
549
+
550
+ Args:
551
+ params:
552
+ Parameters for training. See :func:`get_params`.
553
+ model:
554
+ The model for training. It is an instance of Conformer in our case.
555
+ graph_compiler:
556
+ It is used to build a decoding graph from a ctc topo and training
557
+ transcript. The training transcript is contained in the given `batch`,
558
+ while the ctc topo is built when this compiler is instantiated.
559
+ batch:
560
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
561
+ for the content in it.
562
+ is_training:
563
+ True for training. False for validation. When it is True, this
564
+ function enables autograd during computation; when it is False, it
565
+ disables autograd.
566
+ warmup: a floating point value which increases throughout training;
567
+ values >= 1.0 are fully warmed up and have all modules present.
568
+ """
569
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
570
+ feature = batch["inputs"]
571
+ # at entry, feature is (N, T, C)
572
+ assert feature.ndim == 3
573
+ feature = feature.to(device)
574
+
575
+ supervisions = batch["supervisions"]
576
+ feature_lens = supervisions["num_frames"].to(device)
577
+
578
+ with torch.set_grad_enabled(is_training):
579
+ nnet_output, encoder_out_lens = model(
580
+ feature,
581
+ feature_lens,
582
+ warmup=warmup,
583
+ delay_penalty=params.nnet_delay_penalty if warmup >= 1.0 else 0,
584
+ )
585
+ assert torch.all(encoder_out_lens > 0)
586
+
587
+ # NOTE: We need `encode_supervisions` to sort sequences with
588
+ # different duration in decreasing order, required by
589
+ # `k2.intersect_dense` called in `k2.ctc_loss`
590
+ supervision_segments, texts = encode_supervisions(
591
+ supervisions, subsampling_factor=params.subsampling_factor
592
+ )
593
+
594
+ if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
595
+ # Works with a BPE model
596
+ token_ids = graph_compiler.texts_to_ids(texts)
597
+ decoding_graph = graph_compiler.compile(token_ids)
598
+ elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
599
+ # Works with a phone lexicon
600
+ decoding_graph = graph_compiler.compile(texts)
601
+ else:
602
+ raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
603
+
604
+ dense_fsa_vec = k2.DenseFsaVec(
605
+ nnet_output,
606
+ supervision_segments,
607
+ allow_truncate=params.subsampling_factor - 1,
608
+ )
609
+
610
+ ctc_loss = k2.ctc_loss(
611
+ decoding_graph=decoding_graph,
612
+ dense_fsa_vec=dense_fsa_vec,
613
+ output_beam=params.beam_size,
614
+ delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
615
+ reduction=params.reduction,
616
+ use_double_scores=params.use_double_scores,
617
+ )
618
+ ctc_loss_is_finite = torch.isfinite(ctc_loss)
619
+ if not torch.all(ctc_loss_is_finite):
620
+ logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
621
+ ctc_loss = ctc_loss[ctc_loss_is_finite]
622
+
623
+ # If the ctc_loss is inf or nan for all utterances,
624
+ # we stop the training process by raising an exception
625
+ if torch.all(~ctc_loss_is_finite):
626
+ raise ValueError(
627
+ "There are too many utterances in this batch "
628
+ "leading to inf or nan losses."
629
+ )
630
+ loss = ctc_loss.sum()
631
+
632
+ assert loss.requires_grad == is_training
633
+
634
+ info = MetricsTracker()
635
+ # info["frames"] is an approximate number for two reasons:
636
+ # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
637
+ # (2) If some utterances in the batch lead to inf/nan loss, they
638
+ # are filtered out.
639
+ info["frames"] = supervision_segments[:, 2].sum().item()
640
+ # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
641
+ info["utterances"] = feature.size(0)
642
+ # averaged input duration in frames over utterances
643
+ info["utt_duration"] = feature_lens.sum().item()
644
+ # averaged padding proportion over utterances
645
+ info["utt_pad_proportion"] = (
646
+ ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
647
+ )
648
+
649
+ # Note: We use reduction=sum while computing the loss.
650
+ info["loss"] = loss.detach().cpu().item()
651
+
652
+ return loss, info
653
+
654
+
655
+ def compute_validation_loss(
656
+ params: AttributeDict,
657
+ model: Union[nn.Module, DDP],
658
+ graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
659
+ valid_dl: torch.utils.data.DataLoader,
660
+ world_size: int = 1,
661
+ ) -> MetricsTracker:
662
+ """Run the validation process."""
663
+ model.eval()
664
+
665
+ tot_loss = MetricsTracker()
666
+
667
+ for batch_idx, batch in enumerate(valid_dl):
668
+ loss, loss_info = compute_loss(
669
+ params=params,
670
+ model=model,
671
+ graph_compiler=graph_compiler,
672
+ batch=batch,
673
+ is_training=False,
674
+ )
675
+ assert loss.requires_grad is False
676
+ tot_loss = tot_loss + loss_info
677
+
678
+ if world_size > 1:
679
+ tot_loss.reduce(loss.device)
680
+
681
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
682
+ if loss_value < params.best_valid_loss:
683
+ params.best_valid_epoch = params.cur_epoch
684
+ params.best_valid_loss = loss_value
685
+
686
+ return tot_loss
687
+
688
+
689
+ def train_one_epoch(
690
+ params: AttributeDict,
691
+ model: Union[nn.Module, DDP],
692
+ optimizer: torch.optim.Optimizer,
693
+ scheduler: LRSchedulerType,
694
+ graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
695
+ train_dl: torch.utils.data.DataLoader,
696
+ valid_dl: torch.utils.data.DataLoader,
697
+ scaler: GradScaler,
698
+ model_avg: Optional[nn.Module] = None,
699
+ tb_writer: Optional[SummaryWriter] = None,
700
+ world_size: int = 1,
701
+ rank: int = 0,
702
+ ) -> None:
703
+ """Train the model for one epoch.
704
+
705
+ The training loss from the mean of all frames is saved in
706
+ `params.train_loss`. It runs the validation process every
707
+ `params.valid_interval` batches.
708
+
709
+ Args:
710
+ params:
711
+ It is returned by :func:`get_params`.
712
+ model:
713
+ The model for training.
714
+ optimizer:
715
+ The optimizer we are using.
716
+ scheduler:
717
+ The learning rate scheduler, we call step() every step.
718
+ graph_compiler:
719
+ It is used to build a decoding graph from a ctc topo and training
720
+ transcript. The training transcript is contained in the given `batch`,
721
+ while the ctc topo is built when this compiler is instantiated.
722
+ train_dl:
723
+ Dataloader for the training dataset.
724
+ valid_dl:
725
+ Dataloader for the validation dataset.
726
+ scaler:
727
+ The scaler used for mix precision training.
728
+ model_avg:
729
+ The stored model averaged from the start of training.
730
+ tb_writer:
731
+ Writer to write log messages to tensorboard.
732
+ world_size:
733
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
734
+ rank:
735
+ The rank of the node in DDP training. If no DDP is used, it should
736
+ be set to 0.
737
+ """
738
+ model.train()
739
+
740
+ tot_loss = MetricsTracker()
741
+
742
+ for batch_idx, batch in enumerate(train_dl):
743
+ params.batch_idx_train += 1
744
+ batch_size = len(batch["supervisions"]["text"])
745
+
746
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
747
+ loss, loss_info = compute_loss(
748
+ params=params,
749
+ model=model,
750
+ graph_compiler=graph_compiler,
751
+ batch=batch,
752
+ is_training=True,
753
+ warmup=(params.batch_idx_train / params.model_warm_step),
754
+ )
755
+ # summary stats
756
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
757
+
758
+ # NOTE: The loss is summed over the utterances in the batch;
759
+ # no normalization has been applied to it so far.
760
+ scaler.scale(loss).backward()
761
+ scheduler.step_batch(params.batch_idx_train)
762
+ scaler.step(optimizer)
763
+ scaler.update()
764
+ optimizer.zero_grad()
765
+
766
+ if params.print_diagnostics and batch_idx == 30:
767
+ return
768
+
769
+ if (
770
+ rank == 0
771
+ and params.batch_idx_train > 0
772
+ and params.batch_idx_train % params.average_period == 0
773
+ ):
774
+ update_averaged_model(
775
+ params=params,
776
+ model_cur=model,
777
+ model_avg=model_avg,
778
+ )
779
+
780
+ if (
781
+ params.batch_idx_train > 0
782
+ and params.batch_idx_train % params.save_every_n == 0
783
+ ):
784
+ save_checkpoint_with_global_batch_idx(
785
+ out_dir=params.exp_dir,
786
+ global_batch_idx=params.batch_idx_train,
787
+ model=model,
788
+ model_avg=model_avg,
789
+ params=params,
790
+ optimizer=optimizer,
791
+ scheduler=scheduler,
792
+ sampler=train_dl.sampler,
793
+ scaler=scaler,
794
+ rank=rank,
795
+ )
796
+ remove_checkpoints(
797
+ out_dir=params.exp_dir,
798
+ topk=params.keep_last_k,
799
+ rank=rank,
800
+ )
801
+
802
+ if batch_idx % params.log_interval == 0:
803
+ cur_lr = scheduler.get_last_lr()[0]
804
+ logging.info(
805
+ f"Epoch {params.cur_epoch}, "
806
+ f"batch {batch_idx}, loss[{loss_info}], "
807
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
808
+ f"lr: {cur_lr:.2e}"
809
+ )
810
+
811
+ if tb_writer is not None:
812
+ tb_writer.add_scalar(
813
+ "train/learning_rate", cur_lr, params.batch_idx_train
814
+ )
815
+
816
+ loss_info.write_summary(
817
+ tb_writer, "train/current_", params.batch_idx_train
818
+ )
819
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
820
+
821
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
822
+ logging.info("Computing validation loss")
823
+ valid_info = compute_validation_loss(
824
+ params=params,
825
+ model=model,
826
+ graph_compiler=graph_compiler,
827
+ valid_dl=valid_dl,
828
+ world_size=world_size,
829
+ )
830
+ model.train()
831
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
832
+ if tb_writer is not None:
833
+ valid_info.write_summary(
834
+ tb_writer, "train/valid_", params.batch_idx_train
835
+ )
836
+
837
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
838
+ params.train_loss = loss_value
839
+ if params.train_loss < params.best_train_loss:
840
+ params.best_train_epoch = params.cur_epoch
841
+ params.best_train_loss = params.train_loss
842
+
843
+
844
+ def run(rank, world_size, args):
845
+ """
846
+ Args:
847
+ rank:
848
+ It is a value between 0 and `world_size-1`, which is
849
+ passed automatically by `mp.spawn()` in :func:`main`.
850
+ The node with rank 0 is responsible for saving checkpoint.
851
+ world_size:
852
+ Number of GPUs for DDP training.
853
+ args:
854
+ The return value of get_parser().parse_args()
855
+ """
856
+ params = get_params()
857
+ params.update(vars(args))
858
+ if params.full_libri is False:
859
+ params.valid_interval = 1600
860
+
861
+ fix_random_seed(params.seed)
862
+ if world_size > 1:
863
+ setup_dist(rank, world_size, params.master_port)
864
+
865
+ setup_logger(f"{params.exp_dir}/log/log-train")
866
+ logging.info("Training started")
867
+
868
+ if args.tensorboard and rank == 0:
869
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
870
+ else:
871
+ tb_writer = None
872
+
873
+ lexicon = Lexicon(params.lang_dir)
874
+ max_token_id = max(lexicon.tokens)
875
+ params.vocab_size = max_token_id + 1 # +1 for the blank
876
+
877
+ device = torch.device("cpu")
878
+ if torch.cuda.is_available():
879
+ device = torch.device("cuda", rank)
880
+ logging.info(f"Device: {device}")
881
+
882
+ if "lang_bpe" in str(params.lang_dir):
883
+ graph_compiler = BpeCtcTrainingGraphCompiler(
884
+ params.lang_dir,
885
+ device=device,
886
+ sos_token="<sos/eos>",
887
+ eos_token="<sos/eos>",
888
+ )
889
+ elif "lang_phone" in str(params.lang_dir):
890
+ graph_compiler = CtcTrainingGraphCompiler(
891
+ lexicon,
892
+ device=device,
893
+ need_repeat_flag=params.delay_penalty > 0,
894
+ )
895
+ # Manually add the sos/eos ID with their default values
896
+ # from the BPE recipe which we're adapting here.
897
+ graph_compiler.sos_id = 1
898
+ graph_compiler.eos_id = 1
899
+ else:
900
+ raise ValueError(
901
+ f"Unsupported type of lang dir (we expected it to have "
902
+ f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
903
+ )
904
+
905
+ if params.dynamic_chunk_training:
906
+ assert (
907
+ params.causal_convolution
908
+ ), "dynamic_chunk_training requires causal convolution"
909
+
910
+ logging.info(params)
911
+
912
+ logging.info("About to create model")
913
+ model = get_ctc_model(params)
914
+
915
+ num_param = sum([p.numel() for p in model.parameters()])
916
+ logging.info(f"Number of model parameters: {num_param}")
917
+
918
+ assert params.save_every_n >= params.average_period
919
+ model_avg: Optional[nn.Module] = None
920
+ if rank == 0:
921
+ # model_avg is only used with rank 0
922
+ model_avg = copy.deepcopy(model)
923
+
924
+ assert params.start_epoch > 0, params.start_epoch
925
+ checkpoints = load_checkpoint_if_available(
926
+ params=params, model=model, model_avg=model_avg
927
+ )
928
+
929
+ model.to(device)
930
+ if world_size > 1:
931
+ logging.info("Using DDP")
932
+ model = DDP(model, device_ids=[rank])
933
+
934
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
935
+
936
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
937
+
938
+ if checkpoints and "optimizer" in checkpoints:
939
+ logging.info("Loading optimizer state dict")
940
+ optimizer.load_state_dict(checkpoints["optimizer"])
941
+
942
+ if (
943
+ checkpoints
944
+ and "scheduler" in checkpoints
945
+ and checkpoints["scheduler"] is not None
946
+ ):
947
+ logging.info("Loading scheduler state dict")
948
+ scheduler.load_state_dict(checkpoints["scheduler"])
949
+
950
+ if params.print_diagnostics:
951
+ diagnostic = diagnostics.attach_diagnostics(model)
952
+
953
+ librispeech = LibriSpeechAsrDataModule(args)
954
+
955
+ train_cuts = librispeech.train_clean_100_cuts()
956
+ # if params.full_libri:
957
+ # train_cuts += librispeech.train_clean_360_cuts()
958
+ # train_cuts += librispeech.train_other_500_cuts()
959
+
960
+ def remove_short_and_long_utt(c: Cut):
961
+ # Keep only utterances with duration between 1 second and 20 seconds
962
+ #
963
+ # Caution: There is a reason to select 20.0 here. Please see
964
+ # ../local/display_manifest_statistics.py
965
+ #
966
+ # You should use ../local/display_manifest_statistics.py to get
967
+ # an utterance duration distribution for your dataset to select
968
+ # the threshold
969
+ return 1.0 <= c.duration <= 20.0
970
+
971
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
972
+
973
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
974
+ # We only load the sampler's state dict when it loads a checkpoint
975
+ # saved in the middle of an epoch
976
+ sampler_state_dict = checkpoints["sampler"]
977
+ else:
978
+ sampler_state_dict = None
979
+
980
+ train_dl = librispeech.train_dataloaders(
981
+ train_cuts, sampler_state_dict=sampler_state_dict
982
+ )
983
+
984
+ valid_cuts = librispeech.dev_clean_cuts()
985
+ #valid_cuts += librispeech.dev_other_cuts()
986
+ valid_dl = librispeech.valid_dataloaders(valid_cuts)
987
+
988
+ if params.start_batch <= 0 and not params.print_diagnostics:
989
+ scan_pessimistic_batches_for_oom(
990
+ model=model,
991
+ train_dl=train_dl,
992
+ optimizer=optimizer,
993
+ graph_compiler=graph_compiler,
994
+ params=params,
995
+ warmup=0.0 if params.start_epoch == 1 else 1.0,
996
+ )
997
+
998
+ scaler = GradScaler(enabled=params.use_fp16)
999
+ if checkpoints and "grad_scaler" in checkpoints:
1000
+ logging.info("Loading grad scaler state dict")
1001
+ scaler.load_state_dict(checkpoints["grad_scaler"])
1002
+
1003
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
1004
+ scheduler.step_epoch(epoch - 1)
1005
+ fix_random_seed(params.seed + epoch - 1)
1006
+ train_dl.sampler.set_epoch(epoch - 1)
1007
+
1008
+ if tb_writer is not None:
1009
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
1010
+
1011
+ params.cur_epoch = epoch
1012
+
1013
+ train_one_epoch(
1014
+ params=params,
1015
+ model=model,
1016
+ model_avg=model_avg,
1017
+ optimizer=optimizer,
1018
+ scheduler=scheduler,
1019
+ graph_compiler=graph_compiler,
1020
+ train_dl=train_dl,
1021
+ valid_dl=valid_dl,
1022
+ scaler=scaler,
1023
+ tb_writer=tb_writer,
1024
+ world_size=world_size,
1025
+ rank=rank,
1026
+ )
1027
+
1028
+ if params.print_diagnostics:
1029
+ diagnostic.print_diagnostics()
1030
+ break
1031
+
1032
+ save_checkpoint(
1033
+ params=params,
1034
+ model=model,
1035
+ model_avg=model_avg,
1036
+ optimizer=optimizer,
1037
+ scheduler=scheduler,
1038
+ sampler=train_dl.sampler,
1039
+ scaler=scaler,
1040
+ rank=rank,
1041
+ )
1042
+
1043
+ logging.info("Done!")
1044
+
1045
+ if world_size > 1:
1046
+ torch.distributed.barrier()
1047
+ cleanup_dist()
1048
+
1049
+
1050
+ def scan_pessimistic_batches_for_oom(
1051
+ model: Union[nn.Module, DDP],
1052
+ train_dl: torch.utils.data.DataLoader,
1053
+ optimizer: torch.optim.Optimizer,
1054
+ graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
1055
+ params: AttributeDict,
1056
+ warmup: float,
1057
+ ):
1058
+ from lhotse.dataset import find_pessimistic_batches
1059
+
1060
+ logging.info(
1061
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
1062
+ )
1063
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
1064
+ for criterion, cuts in batches.items():
1065
+ batch = train_dl.dataset[cuts]
1066
+ try:
1067
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
1068
+ loss, _ = compute_loss(
1069
+ params=params,
1070
+ model=model,
1071
+ graph_compiler=graph_compiler,
1072
+ batch=batch,
1073
+ is_training=True,
1074
+ warmup=warmup,
1075
+ )
1076
+ loss.backward()
1077
+ optimizer.step()
1078
+ optimizer.zero_grad()
1079
+ except RuntimeError as e:
1080
+ if "CUDA out of memory" in str(e):
1081
+ logging.error(
1082
+ "Your GPU ran out of memory with the current "
1083
+ "max_duration setting. We recommend decreasing "
1084
+ "max_duration and trying again.\n"
1085
+ f"Failing criterion: {criterion} "
1086
+ f"(={crit_values[criterion]}) ..."
1087
+ )
1088
+ raise
1089
+
1090
+
1091
+ def main():
1092
+ parser = get_parser()
1093
+ LibriSpeechAsrDataModule.add_arguments(parser)
1094
+ args = parser.parse_args()
1095
+ args.exp_dir = Path(args.exp_dir)
1096
+
1097
+ world_size = args.world_size
1098
+ assert world_size >= 1
1099
+ if world_size > 1:
1100
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
1101
+ else:
1102
+ run(rank=0, world_size=1, args=args)
1103
+
1104
+
1105
+ torch.set_num_threads(1)
1106
+ torch.set_num_interop_threads(1)
1107
+
1108
+ if __name__ == "__main__":
1109
+ main()
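A note on the warm-up handling in train.py above: train_one_epoch() passes warmup = batch_idx_train / model_warm_step into compute_loss(), and both --delay-penalty and --nnet-delay-penalty are zeroed out until that ratio reaches 1.0. The sketch below illustrates just that gating in plain Python (no icefall imports needed); model_warm_step = 3000 is the default from get_params(), while the penalty value is a made-up example, not a recommended setting.

    # Sketch of the warm-up gating used in compute_loss(); values are illustrative.
    MODEL_WARM_STEP = 3000   # default "model_warm_step" in get_params()
    DELAY_PENALTY = 0.1      # hypothetical --delay-penalty value

    def effective_delay_penalty(batch_idx_train: int) -> float:
        warmup = batch_idx_train / MODEL_WARM_STEP
        # The penalty is only applied once the model is fully warmed up.
        return DELAY_PENALTY if warmup >= 1.0 else 0.0

    for step in (100, 1500, 3000, 10000):
        print(step, effective_delay_penalty(step))
    # prints 0.0 at steps 100 and 1500, then 0.1 from step 3000 onwards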
err2020/conformer_ctc3_usage.ipynb ADDED
@@ -0,0 +1,500 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "b6b6ded1-0a58-43cb-9065-4f4fae02a01b",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import argparse\n",
11
+ "import logging\n",
12
+ "import math\n",
13
+ "import re\n",
14
+ "from typing import List\n",
15
+ "import sys\n",
16
+ "sys.path.append('/opt/notebooks/err2020/conformer_ctc3/')\n",
17
+ "import k2\n",
18
+ "import kaldifeat\n",
19
+ "import sentencepiece as spm\n",
20
+ "import torch\n",
21
+ "import torchaudio\n",
22
+ "from decode import get_decoding_params\n",
23
+ "from torch.nn.utils.rnn import pad_sequence\n",
24
+ "from train import add_model_arguments, get_params\n",
25
+ "\n",
26
+ "from icefall.decode import (\n",
27
+ " get_lattice,\n",
28
+ " one_best_decoding,\n",
29
+ " rescore_with_n_best_list,\n",
30
+ " rescore_with_whole_lattice\n",
31
+ ")\n",
32
+ "from icefall.utils import get_texts, parse_fsa_timestamps_and_texts"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "52514f2f-1195-4e4f-8174-d21aa7462476",
38
+ "metadata": {},
39
+ "source": [
40
+ "## Helpers"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "8ec024bf-7f91-47a9-9293-822fe2765c4b",
46
+ "metadata": {},
47
+ "source": [
48
+ "#### Load args helpers"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 4,
54
+ "id": "3d69d771-b421-417f-a6ff-e1d1c64ba934",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "class Args:\n",
59
+ " model_filename='conformer_ctc3/exp/jit_trace.pt'\n",
60
+ " bpe_model_filename=\"data/lang_bpe_500/bpe.model\"\n",
61
+ " method=\"ctc-decoding\"\n",
62
+ " sample_rate=16000\n",
63
+ " num_classes=500 #bpe model size\n",
64
+ " frame_shift_ms=10\n",
65
+ " dither=0\n",
66
+ " snip_edges=False\n",
67
+ " num_bins=80\n",
68
+ " device='cpu'\n",
69
+ " \n",
70
+ " def args_from_dict(self, dct):\n",
71
+ " for key in dct:\n",
72
+ " setattr(self, key, dct[key])\n",
73
+ " \n",
74
+ " def __repr__(self):\n",
75
+ " text=''\n",
76
+ " for k, v in self.__dict__.items():\n",
77
+ " text+=f'{k} = {v}\\n'\n",
78
+ " return text"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "id": "57a3cd62-3037-4c99-9094-dd63429e660e",
84
+ "metadata": {},
85
+ "source": [
86
+ "#### Decoder helper"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 5,
92
+ "id": "48306369-fb68-4abe-be62-0806d00059f8",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "class ConformerCtc3Decoder:\n",
97
+ " def __init__(self, params_dct=None):\n",
98
+ " logging.info('loading args')\n",
99
+ " self.args=Args()\n",
100
+ " if params_dct is not None:\n",
101
+ " self.args.args_from_dict(params_dct)\n",
102
+ " logging.info('loading model')\n",
103
+ " self.load_model()\n",
104
+ " logging.info('loading fbank')\n",
105
+ " self.get_fbank()\n",
106
+ " \n",
107
+ " def update_args(self, dct):\n",
108
+ " self.args.args_from_dict(dct)\n",
109
+ " \n",
110
+ " def load_model_(self, model_filename, device):\n",
111
+ " device = torch.device(\"cpu\")\n",
112
+ " model = torch.jit.load(model_filename)\n",
113
+ " model.to(device)\n",
114
+ " model=model.eval()\n",
115
+ " self.model=model\n",
116
+ " \n",
117
+ " def load_model(self, model_filename=None, device=None):\n",
118
+ " if model_filename is not None:\n",
119
+ " self.args.model_filename=model_filename\n",
120
+ " if device is not None:\n",
121
+ " self.args.device=device\n",
122
+ " self.load_model_(self.args.model_filename, self.args.device)\n",
123
+ " \n",
124
+ " def get_fbank_(self, device='cpu'):\n",
125
+ " opts = kaldifeat.FbankOptions()\n",
126
+ " opts.device = device\n",
127
+ " opts.frame_opts.dither = self.args.dither\n",
128
+ " opts.frame_opts.snip_edges = self.args.snip_edges\n",
129
+ " #opts.frame_opts.samp_freq = sample_rate\n",
130
+ " opts.mel_opts.num_bins = self.args.num_bins\n",
131
+ "\n",
132
+ " fbank = kaldifeat.Fbank(opts)\n",
133
+ " return fbank\n",
134
+ " \n",
135
+ " def get_fbank(self):\n",
136
+ " self.fbank=self.get_fbank_(self.args.device)\n",
137
+ " \n",
138
+ " def read_sound_file_(self, filename: str, expected_sample_rate: float ) -> List[torch.Tensor]:\n",
139
+ " \"\"\"Read a sound file into a 1-D float32 torch tensor.\n",
140
+ " Args:\n",
141
+ " filenames:\n",
142
+ " A list of sound filenames.\n",
143
+ " expected_sample_rate:\n",
144
+ " The expected sample rate of the sound files.\n",
145
+ " Returns:\n",
146
+ " Return a 1-D float32 torch tensor.\n",
147
+ " \"\"\"\n",
148
+ " wave, sample_rate = torchaudio.load(filename)\n",
149
+ " assert sample_rate == expected_sample_rate, (\n",
150
+ " f\"expected sample rate: {expected_sample_rate}. \" f\"Given: {sample_rate}\"\n",
151
+ " )\n",
152
+ " # We use only the first channel\n",
153
+ " return wave[0]\n",
154
+ " \n",
155
+ " def format_trs(self, hyp, timestamps):\n",
156
+ " if len(hyp)!=len(timestamps):\n",
157
+ " print(f'len of hyp and timestamps is not the same len hyp {len(hyp)} and len of timestamps {len(timestamps)}')\n",
158
+ " return None\n",
159
+ " trs ={'text': ' '.join(hyp),\n",
160
+ " 'words': [{'word': w, 'start':timestamps[i][0], 'end': timestamps[i][1]} for i, w in enumerate(hyp)]\n",
161
+ " }\n",
162
+ " return trs\n",
163
+ " \n",
164
+ " def decode_(self, wave, fbank, model, device, method, bpe_model_filename, num_classes, \n",
165
+ " min_active_states, max_active_states, subsampling_factor, use_double_scores, \n",
166
+ " frame_shift_ms, search_beam, output_beam):\n",
167
+ " \n",
168
+ " wave = [wave.to(device)]\n",
169
+ " logging.info(\"Decoding started\")\n",
170
+ " features = fbank(wave)\n",
171
+ " feature_lengths = [f.size(0) for f in features]\n",
172
+ "\n",
173
+ " features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))\n",
174
+ " feature_lengths = torch.tensor(feature_lengths, device=device)\n",
175
+ "\n",
176
+ " nnet_output, _ = model(features, feature_lengths)\n",
177
+ "\n",
178
+ " batch_size = nnet_output.shape[0]\n",
179
+ " supervision_segments = torch.tensor(\n",
180
+ " [\n",
181
+ " [i, 0, feature_lengths[i] // subsampling_factor]\n",
182
+ " for i in range(batch_size)\n",
183
+ " ],\n",
184
+ " dtype=torch.int32,\n",
185
+ " )\n",
186
+ "\n",
187
+ " if method == \"ctc-decoding\":\n",
188
+ " logging.info(\"Use CTC decoding\")\n",
189
+ " bpe_model = spm.SentencePieceProcessor()\n",
190
+ " bpe_model.load(bpe_model_filename)\n",
191
+ " max_token_id = num_classes - 1\n",
192
+ "\n",
193
+ " H = k2.ctc_topo(\n",
194
+ " max_token=max_token_id,\n",
195
+ " modified=False,\n",
196
+ " device=device,\n",
197
+ " )\n",
198
+ "\n",
199
+ " lattice = get_lattice(\n",
200
+ " nnet_output=nnet_output,\n",
201
+ " decoding_graph=H,\n",
202
+ " supervision_segments=supervision_segments,\n",
203
+ " search_beam=search_beam,\n",
204
+ " output_beam=output_beam,\n",
205
+ " min_active_states=min_active_states,\n",
206
+ " max_active_states=max_active_states,\n",
207
+ " subsampling_factor=subsampling_factor,\n",
208
+ " )\n",
209
+ "\n",
210
+ " best_path = one_best_decoding(\n",
211
+ " lattice=lattice, use_double_scores=use_double_scores\n",
212
+ " )\n",
213
+ "\n",
214
+ " confidence=best_path.get_tot_scores(use_double_scores=False, log_semiring=False).detach()[0]\n",
215
+ "\n",
216
+ " timestamps, hyps = parse_fsa_timestamps_and_texts(\n",
217
+ " best_paths=best_path,\n",
218
+ " sp=bpe_model,\n",
219
+ " subsampling_factor=subsampling_factor,\n",
220
+ " frame_shift_ms=frame_shift_ms,\n",
221
+ " )\n",
222
+ " logging.info(f'confidence {confidence}')\n",
223
+ " logging.info(timestamps)\n",
224
+ " token_ids = get_texts(best_path)\n",
225
+ " return self.format_trs(hyps[0], timestamps[0])\n",
226
+ " \n",
227
+ " def transcribe_file(self, audio_filename):\n",
228
+ " wave=self.read_sound_file_(audio_filename, expected_sample_rate=self.args.sample_rate)\n",
229
+ " \n",
230
+ " trs=self.decode_(wave, self.fbank, self.model, self.args.device, self.args.method, \n",
231
+ " self.args.bpe_model_filename, self.args.num_classes,\n",
232
+ " self.args.min_active_states, self.args.max_active_states, \n",
233
+ " self.args.subsampling_factor, self.args.use_double_scores, \n",
234
+ " self.args.frame_shift_ms, self.args.search_beam, self.args.output_beam)\n",
235
+ " return trs"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "markdown",
240
+ "id": "b1464957-05b6-40f8-a1aa-c58edbed440c",
241
+ "metadata": {},
242
+ "source": [
243
+ "## Example usage"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": 6,
249
+ "id": "50ab7c8e-39b6-4783-8342-e79e91d2417e",
250
+ "metadata": {},
251
+ "outputs": [
252
+ {
253
+ "name": "stderr",
254
+ "output_type": "stream",
255
+ "text": [
256
+ "fatal: not a git repository (or any parent up to mount point /opt)\n",
257
+ "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n",
258
+ "fatal: not a git repository (or any parent up to mount point /opt)\n",
259
+ "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n",
260
+ "fatal: not a git repository (or any parent up to mount point /opt)\n",
261
+ "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "#create transcriber/decoder object\n",
267
+ "#if you want to change parameters (for example model filename) you could create a dict (see class Args attribute names)\n",
268
+ "#and pass it as an argument when initializing the decoder:\n",
269
+ "#ConformerCtc3Decoder(get_params() | get_decoding_params() | {'model_filename':'my new model filename'})\n",
270
+ "transcriber=ConformerCtc3Decoder(get_params() | get_decoding_params())"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 7,
276
+ "id": "8020f371-7584-4f6c-990b-f2c023e24060",
277
+ "metadata": {},
278
+ "outputs": [
279
+ {
280
+ "name": "stdout",
281
+ "output_type": "stream",
282
+ "text": [
283
+ "CPU times: user 4.86 s, sys: 435 ms, total: 5.29 s\n",
284
+ "Wall time: 4.45 s\n"
285
+ ]
286
+ },
287
+ {
288
+ "data": {
289
+ "text/plain": [
290
+ "{'text': 'mina tahaksin homme täna ja homme kui saan all kolm krantsumadiseid veiki panna',\n",
291
+ " 'words': [{'word': 'mina', 'start': 0.8, 'end': 0.84},\n",
292
+ " {'word': 'tahaksin', 'start': 1.0, 'end': 1.32},\n",
293
+ " {'word': 'homme', 'start': 1.48, 'end': 1.76},\n",
294
+ " {'word': 'täna', 'start': 2.08, 'end': 2.12},\n",
295
+ " {'word': 'ja', 'start': 3.72, 'end': 3.76},\n",
296
+ " {'word': 'homme', 'start': 4.16, 'end': 4.44},\n",
297
+ " {'word': 'kui', 'start': 5.96, 'end': 6.0},\n",
298
+ " {'word': 'saan', 'start': 6.52, 'end': 6.84},\n",
299
+ " {'word': 'all', 'start': 7.36, 'end': 7.4},\n",
300
+ " {'word': 'kolm', 'start': 8.32, 'end': 8.36},\n",
301
+ " {'word': 'krantsumadiseid', 'start': 8.68, 'end': 9.72},\n",
302
+ " {'word': 'veiki', 'start': 9.76, 'end': 10.04},\n",
303
+ " {'word': 'panna', 'start': 10.16, 'end': 10.4}]}"
304
+ ]
305
+ },
306
+ "execution_count": 7,
307
+ "metadata": {},
308
+ "output_type": "execute_result"
309
+ }
310
+ ],
311
+ "source": [
312
+ "#transcribe an audio file (NB! the model assumes a sample rate of 16000)\n",
313
+ "%time transcriber.transcribe_file('audio/emt16k.wav')"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": 10,
319
+ "id": "4d2a480d-f0aa-4474-bfdb-ad298a629ce5",
320
+ "metadata": {},
321
+ "outputs": [
322
+ {
323
+ "name": "stdout",
324
+ "output_type": "stream",
325
+ "text": [
326
+ "CPU times: user 16.2 s, sys: 1.8 s, total: 18 s\n",
327
+ "Wall time: 15.1 s\n"
328
+ ]
329
+ }
330
+ ],
331
+ "source": [
332
+ "%time trs=transcriber.transcribe_file('audio/oden_kypsis16k.wav')"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": 11,
338
+ "id": "d3827548-bca0-4409-95bc-9aa8ba377135",
339
+ "metadata": {},
340
+ "outputs": [
341
+ {
342
+ "data": {
343
+ "text/plain": [
344
+ "{'text': 'enamus ajast nagu klikkid neid allserva tekivad need luba küpsiseid mis on nagu ilusti kohati tõlgitud eesti keelde see idee arusaadavamaks ma tean et see on kukis inglise kees ma ei saa sellest ka aru nagu mis asi on kukis on ju ma saan aru et ta vaid minee eest ära luba küpsises tava ei anna noh anna minna ma luban küpssi juhmaoloog okei on ju ma ei tea mis ta teeb lihtsalt selle eestikeelseks tõlk või eesti keelde tõlkimine kui teinud seda nagu arusaadavamaks küpsised kuule kuule veebisaid küsib sinu käest tahad tähendab on okei kui me neid kugiseid kasutame sa mingi ja mida iga mul täiesti savi või noh et et jah',\n",
345
+ " 'words': [{'word': 'enamus', 'start': 3.56, 'end': 3.8},\n",
346
+ " {'word': 'ajast', 'start': 3.8, 'end': 4.04},\n",
347
+ " {'word': 'nagu', 'start': 4.2, 'end': 4.24},\n",
348
+ " {'word': 'klikkid', 'start': 4.72, 'end': 5.12},\n",
349
+ " {'word': 'neid', 'start': 5.16, 'end': 5.2},\n",
350
+ " {'word': 'allserva', 'start': 5.72, 'end': 6.2},\n",
351
+ " {'word': 'tekivad', 'start': 6.32, 'end': 6.64},\n",
352
+ " {'word': 'need', 'start': 7.4, 'end': 7.44},\n",
353
+ " {'word': 'luba', 'start': 7.72, 'end': 8.0},\n",
354
+ " {'word': 'küpsiseid', 'start': 8.08, 'end': 8.64},\n",
355
+ " {'word': 'mis', 'start': 9.68, 'end': 9.72},\n",
356
+ " {'word': 'on', 'start': 9.76, 'end': 9.8},\n",
357
+ " {'word': 'nagu', 'start': 9.92, 'end': 9.96},\n",
358
+ " {'word': 'ilusti', 'start': 10.04, 'end': 10.36},\n",
359
+ " {'word': 'kohati', 'start': 10.4, 'end': 10.68},\n",
360
+ " {'word': 'tõlgitud', 'start': 11.08, 'end': 11.4},\n",
361
+ " {'word': 'eesti', 'start': 11.6, 'end': 11.64},\n",
362
+ " {'word': 'keelde', 'start': 11.8, 'end': 12.08},\n",
363
+ " {'word': 'see', 'start': 12.68, 'end': 12.72},\n",
364
+ " {'word': 'idee', 'start': 12.8, 'end': 13.04},\n",
365
+ " {'word': 'arusaadavamaks', 'start': 13.2, 'end': 13.8},\n",
366
+ " {'word': 'ma', 'start': 13.92, 'end': 13.96},\n",
367
+ " {'word': 'tean', 'start': 14.04, 'end': 14.24},\n",
368
+ " {'word': 'et', 'start': 14.28, 'end': 14.36},\n",
369
+ " {'word': 'see', 'start': 14.4, 'end': 14.44},\n",
370
+ " {'word': 'on', 'start': 14.44, 'end': 14.52},\n",
371
+ " {'word': 'kukis', 'start': 14.56, 'end': 14.92},\n",
372
+ " {'word': 'inglise', 'start': 14.92, 'end': 15.2},\n",
373
+ " {'word': 'kees', 'start': 15.2, 'end': 15.44},\n",
374
+ " {'word': 'ma', 'start': 15.84, 'end': 15.88},\n",
375
+ " {'word': 'ei', 'start': 15.92, 'end': 16.0},\n",
376
+ " {'word': 'saa', 'start': 16.04, 'end': 16.08},\n",
377
+ " {'word': 'sellest', 'start': 16.24, 'end': 16.28},\n",
378
+ " {'word': 'ka', 'start': 16.56, 'end': 16.6},\n",
379
+ " {'word': 'aru', 'start': 16.76, 'end': 16.8},\n",
380
+ " {'word': 'nagu', 'start': 16.96, 'end': 17.0},\n",
381
+ " {'word': 'mis', 'start': 17.12, 'end': 17.16},\n",
382
+ " {'word': 'asi', 'start': 17.28, 'end': 17.32},\n",
383
+ " {'word': 'on', 'start': 17.36, 'end': 17.4},\n",
384
+ " {'word': 'kukis', 'start': 17.48, 'end': 17.8},\n",
385
+ " {'word': 'on', 'start': 17.88, 'end': 17.92},\n",
386
+ " {'word': 'ju', 'start': 17.96, 'end': 18.0},\n",
387
+ " {'word': 'ma', 'start': 18.28, 'end': 18.32},\n",
388
+ " {'word': 'saan', 'start': 18.36, 'end': 18.48},\n",
389
+ " {'word': 'aru', 'start': 18.52, 'end': 18.56},\n",
390
+ " {'word': 'et', 'start': 18.72, 'end': 18.76},\n",
391
+ " {'word': 'ta', 'start': 19.2, 'end': 19.24},\n",
392
+ " {'word': 'vaid', 'start': 19.32, 'end': 19.44},\n",
393
+ " {'word': 'minee', 'start': 19.48, 'end': 19.68},\n",
394
+ " {'word': 'eest', 'start': 19.76, 'end': 19.96},\n",
395
+ " {'word': 'ära', 'start': 20.12, 'end': 20.16},\n",
396
+ " {'word': 'luba', 'start': 21.56, 'end': 21.88},\n",
397
+ " {'word': 'küpsises', 'start': 21.96, 'end': 22.44},\n",
398
+ " {'word': 'tava', 'start': 22.6, 'end': 22.76},\n",
399
+ " {'word': 'ei', 'start': 22.84, 'end': 22.88},\n",
400
+ " {'word': 'anna', 'start': 23.0, 'end': 23.16},\n",
401
+ " {'word': 'noh', 'start': 23.4, 'end': 23.44},\n",
402
+ " {'word': 'anna', 'start': 23.64, 'end': 23.76},\n",
403
+ " {'word': 'minna', 'start': 24.0, 'end': 24.04},\n",
404
+ " {'word': 'ma', 'start': 24.16, 'end': 24.2},\n",
405
+ " {'word': 'luban', 'start': 24.24, 'end': 24.56},\n",
406
+ " {'word': 'küpssi', 'start': 24.64, 'end': 24.92},\n",
407
+ " {'word': 'juhmaoloog', 'start': 25.0, 'end': 25.28},\n",
408
+ " {'word': 'okei', 'start': 25.28, 'end': 25.56},\n",
409
+ " {'word': 'on', 'start': 25.64, 'end': 25.72},\n",
410
+ " {'word': 'ju', 'start': 25.72, 'end': 25.76},\n",
411
+ " {'word': 'ma', 'start': 25.84, 'end': 25.88},\n",
412
+ " {'word': 'ei', 'start': 25.92, 'end': 25.96},\n",
413
+ " {'word': 'tea', 'start': 26.0, 'end': 26.04},\n",
414
+ " {'word': 'mis', 'start': 26.28, 'end': 26.32},\n",
415
+ " {'word': 'ta', 'start': 26.36, 'end': 26.4},\n",
416
+ " {'word': 'teeb', 'start': 26.56, 'end': 26.8},\n",
417
+ " {'word': 'lihtsalt', 'start': 27.04, 'end': 27.08},\n",
418
+ " {'word': 'selle', 'start': 27.24, 'end': 27.28},\n",
419
+ " {'word': 'eestikeelseks', 'start': 28.04, 'end': 28.68},\n",
420
+ " {'word': 'tõlk', 'start': 28.8, 'end': 29.08},\n",
421
+ " {'word': 'või', 'start': 29.16, 'end': 29.2},\n",
422
+ " {'word': 'eesti', 'start': 29.48, 'end': 29.52},\n",
423
+ " {'word': 'keelde', 'start': 29.68, 'end': 30.04},\n",
424
+ " {'word': 'tõlkimine', 'start': 30.2, 'end': 30.68},\n",
425
+ " {'word': 'kui', 'start': 30.8, 'end': 30.84},\n",
426
+ " {'word': 'teinud', 'start': 30.96, 'end': 31.16},\n",
427
+ " {'word': 'seda', 'start': 31.2, 'end': 31.24},\n",
428
+ " {'word': 'nagu', 'start': 31.72, 'end': 31.76},\n",
429
+ " {'word': 'arusaadavamaks', 'start': 31.88, 'end': 32.6},\n",
430
+ " {'word': 'küpsised', 'start': 33.52, 'end': 33.88},\n",
431
+ " {'word': 'kuule', 'start': 36.96, 'end': 37.08},\n",
432
+ " {'word': 'kuule', 'start': 37.32, 'end': 37.44},\n",
433
+ " {'word': 'veebisaid', 'start': 37.8, 'end': 38.28},\n",
434
+ " {'word': 'küsib', 'start': 38.44, 'end': 38.56},\n",
435
+ " {'word': 'sinu', 'start': 38.6, 'end': 38.72},\n",
436
+ " {'word': 'käest', 'start': 38.76, 'end': 39.0},\n",
437
+ " {'word': 'tahad', 'start': 39.52, 'end': 39.72},\n",
438
+ " {'word': 'tähendab', 'start': 40.32, 'end': 40.36},\n",
439
+ " {'word': 'on', 'start': 40.8, 'end': 40.88},\n",
440
+ " {'word': 'okei', 'start': 40.88, 'end': 41.2},\n",
441
+ " {'word': 'kui', 'start': 41.24, 'end': 41.28},\n",
442
+ " {'word': 'me', 'start': 41.36, 'end': 41.4},\n",
443
+ " {'word': 'neid', 'start': 41.6, 'end': 41.64},\n",
444
+ " {'word': 'kugiseid', 'start': 42.2, 'end': 42.64},\n",
445
+ " {'word': 'kasutame', 'start': 42.8, 'end': 43.08},\n",
446
+ " {'word': 'sa', 'start': 43.56, 'end': 43.6},\n",
447
+ " {'word': 'mingi', 'start': 43.8, 'end': 43.84},\n",
448
+ " {'word': 'ja', 'start': 44.04, 'end': 44.08},\n",
449
+ " {'word': 'mida', 'start': 44.28, 'end': 44.32},\n",
450
+ " {'word': 'iga', 'start': 44.44, 'end': 44.48},\n",
451
+ " {'word': 'mul', 'start': 44.56, 'end': 44.6},\n",
452
+ " {'word': 'täiesti', 'start': 44.92, 'end': 44.96},\n",
453
+ " {'word': 'savi', 'start': 45.08, 'end': 45.28},\n",
454
+ " {'word': 'või', 'start': 45.36, 'end': 45.4},\n",
455
+ " {'word': 'noh', 'start': 45.44, 'end': 45.48},\n",
456
+ " {'word': 'et', 'start': 45.6, 'end': 45.64},\n",
457
+ " {'word': 'et', 'start': 47.36, 'end': 47.4},\n",
458
+ " {'word': 'jah', 'start': 47.56, 'end': 47.68}]}"
459
+ ]
460
+ },
461
+ "execution_count": 11,
462
+ "metadata": {},
463
+ "output_type": "execute_result"
464
+ }
465
+ ],
466
+ "source": [
467
+ "trs"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": null,
473
+ "id": "ea3b25b7-a1f9-4b21-911d-35159c5f3009",
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": []
477
+ }
478
+ ],
479
+ "metadata": {
480
+ "kernelspec": {
481
+ "display_name": "Python 3 (ipykernel)",
482
+ "language": "python",
483
+ "name": "python3"
484
+ },
485
+ "language_info": {
486
+ "codemirror_mode": {
487
+ "name": "ipython",
488
+ "version": 3
489
+ },
490
+ "file_extension": ".py",
491
+ "mimetype": "text/x-python",
492
+ "name": "python",
493
+ "nbconvert_exporter": "python",
494
+ "pygments_lexer": "ipython3",
495
+ "version": "3.9.16"
496
+ }
497
+ },
498
+ "nbformat": 4,
499
+ "nbformat_minor": 5
500
+ }
err2020/data/lang_bpe_500/bpe.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14afda3d7b1a9b2d07ca4f55bdf2d9d7424bb795068cac61107bc2b58a26b7fd
3
+ size 245129
requirements.txt ADDED
@@ -0,0 +1 @@
1
+ lhotse
run.bat ADDED
@@ -0,0 +1,9 @@
1
+ set APP_PATH=%cd%
2
+
3
+ docker stop icefall_run
4
+ docker rm icefall_run
5
+ docker run -it --rm ^
6
+ -p 8888:8888 ^
7
+ -v %APP_PATH%:/opt/notebooks ^
8
+ --name icefall_run ^
9
+ icefall
run.sh ADDED
@@ -0,0 +1,11 @@
1
+ #!/bin/bash
2
+
3
+ APP_PATH=$(pwd)
4
+
5
+ docker stop icefall_run
6
+ docker rm icefall_run
7
+ docker run -it --rm \
8
+ -p 8888:8888 \
9
+ -v "$APP_PATH":/opt/notebooks \
10
+ --name icefall_run \
11
+ icefall
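For completeness, the decoding helper from err2020/conformer_ctc3_usage.ipynb can also be driven from a plain Python script inside the container. A minimal sketch, assuming the working directory is err2020/, the same file layout as in the notebook, and that the ConformerCtc3Decoder class has been copied out of the notebook into a local module (the module name transcribe.py used below is hypothetical):

    import sys
    sys.path.append('conformer_ctc3')  # make train.py / decode.py importable

    from decode import get_decoding_params
    from train import get_params
    # Hypothetical module holding the ConformerCtc3Decoder class from the notebook.
    from transcribe import ConformerCtc3Decoder

    # Same parameter merge as used in the notebook.
    transcriber = ConformerCtc3Decoder(get_params() | get_decoding_params())

    # The model expects 16 kHz mono audio.
    trs = transcriber.transcribe_file('audio/emt16k.wav')
    print(trs['text'])
    for w in trs['words']:
        print(f"{w['start']:6.2f} {w['end']:6.2f}  {w['word']}")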