Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +160 -0
- app.py +180 -0
- config.yaml +266 -0
- espnet/__init__.py +8 -0
- espnet/asr/__init__.py +1 -0
- espnet/asr/asr_mix_utils.py +187 -0
- espnet/asr/asr_utils.py +1024 -0
- espnet/asr/chainer_backend/__init__.py +1 -0
- espnet/asr/chainer_backend/asr.py +575 -0
- espnet/asr/pytorch_backend/__init__.py +1 -0
- espnet/asr/pytorch_backend/asr.py +1500 -0
- espnet/asr/pytorch_backend/asr_init.py +282 -0
- espnet/asr/pytorch_backend/asr_mix.py +654 -0
- espnet/asr/pytorch_backend/recog.py +152 -0
- espnet/bin/__init__.py +1 -0
- espnet/bin/asr_align.py +348 -0
- espnet/bin/asr_enhance.py +191 -0
- espnet/bin/asr_recog.py +363 -0
- espnet/bin/asr_train.py +644 -0
- espnet/bin/lm_train.py +288 -0
- espnet/bin/mt_train.py +480 -0
- espnet/bin/mt_trans.py +186 -0
- espnet/bin/st_train.py +550 -0
- espnet/bin/st_trans.py +183 -0
- espnet/bin/tts_decode.py +180 -0
- espnet/bin/tts_train.py +359 -0
- espnet/bin/vc_decode.py +174 -0
- espnet/bin/vc_train.py +368 -0
- espnet/lm/__init__.py +1 -0
- espnet/lm/chainer_backend/__init__.py +1 -0
- espnet/lm/chainer_backend/extlm.py +199 -0
- espnet/lm/chainer_backend/lm.py +484 -0
- espnet/lm/lm_utils.py +293 -0
- espnet/lm/pytorch_backend/__init__.py +1 -0
- espnet/lm/pytorch_backend/extlm.py +218 -0
- espnet/lm/pytorch_backend/lm.py +410 -0
- espnet/mt/__init__.py +1 -0
- espnet/mt/mt_utils.py +83 -0
- espnet/mt/pytorch_backend/__init__.py +1 -0
- espnet/mt/pytorch_backend/mt.py +600 -0
- espnet/nets/__init__.py +1 -0
- espnet/nets/asr_interface.py +172 -0
- espnet/nets/batch_beam_search.py +348 -0
- espnet/nets/batch_beam_search_online_sim.py +270 -0
- espnet/nets/beam_search.py +512 -0
- espnet/nets/beam_search_transducer.py +629 -0
- espnet/nets/chainer_backend/__init__.py +1 -0
- espnet/nets/chainer_backend/asr_interface.py +29 -0
- espnet/nets/chainer_backend/ctc.py +184 -0
- espnet/nets/chainer_backend/deterministic_embed_id.py +253 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
app.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from espnet2.bin.tts_inference import Text2Speech
|
2 |
+
import torch
|
3 |
+
from parallel_wavegan.utils import download_pretrained_model, load_model
|
4 |
+
from phonemizer import phonemize
|
5 |
+
from phonemizer.separator import Separator
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
s = Separator(word=None, phone=" ")
|
9 |
+
config_path = "config.yaml"
|
10 |
+
model_path = "model.pth"
|
11 |
+
|
12 |
+
vocoder_tag = "ljspeech_parallel_wavegan.v3"
|
13 |
+
|
14 |
+
vocoder = load_model(download_pretrained_model(vocoder_tag)).to("cpu").eval()
|
15 |
+
vocoder.remove_weight_norm()
|
16 |
+
|
17 |
+
global_styles = {
|
18 |
+
"Style 1": torch.load("style1.pt"),
|
19 |
+
"Style 2": torch.load("style2.pt"),
|
20 |
+
"Style 3": torch.load("style3.pt"),
|
21 |
+
"Style 4": torch.load("style4.pt"),
|
22 |
+
"Style 5": torch.load("style5.pt"),
|
23 |
+
"Style 6": torch.load("style6.pt"),
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def inference(text, global_style, alpha, prev_fg_inds, input_fg_inds):
|
28 |
+
with torch.no_grad():
|
29 |
+
text2speech = Text2Speech(
|
30 |
+
config_path,
|
31 |
+
model_path,
|
32 |
+
device="cpu",
|
33 |
+
# Only for Tacotron 2
|
34 |
+
threshold=0.5,
|
35 |
+
minlenratio=0.0,
|
36 |
+
maxlenratio=10.0,
|
37 |
+
use_att_constraint=False,
|
38 |
+
backward_window=1,
|
39 |
+
forward_window=3,
|
40 |
+
# Only for FastSpeech & FastSpeech2
|
41 |
+
speed_control_alpha=alpha,
|
42 |
+
)
|
43 |
+
text2speech.spc2wav = None # Disable griffin-lim
|
44 |
+
|
45 |
+
style_emb = torch.flatten(global_styles[global_style])
|
46 |
+
|
47 |
+
phoneme_string = phonemize(
|
48 |
+
text, language="mb-us1", backend="espeak-mbrola", separator=s
|
49 |
+
)
|
50 |
+
phonemes = phoneme_string.split(" ")
|
51 |
+
|
52 |
+
max_edit_index = -1
|
53 |
+
for i in range(len(input_fg_inds) - 1, -1, -1):
|
54 |
+
if input_fg_inds[i] != "":
|
55 |
+
max_edit_index = i
|
56 |
+
break
|
57 |
+
|
58 |
+
if max_edit_index == -1:
|
59 |
+
_, c, _, _, _, _, _, output_fg_inds = text2speech(
|
60 |
+
phoneme_string, ref_embs=style_emb
|
61 |
+
)
|
62 |
+
|
63 |
+
else:
|
64 |
+
input_fg_inds_int_list = []
|
65 |
+
for i in range(max_edit_index + 1):
|
66 |
+
if input_fg_inds[i] != "":
|
67 |
+
input_fg_inds_int_list.append(int(input_fg_inds[i]))
|
68 |
+
else:
|
69 |
+
input_fg_inds_int_list.append(prev_fg_inds[i][1])
|
70 |
+
input_fg_inds = input_fg_inds_int_list
|
71 |
+
|
72 |
+
prev_fg_inds_list = [[[row[1], row[2], row[3]] for row in prev_fg_inds]]
|
73 |
+
prev_fg_inds = torch.tensor(prev_fg_inds_list, dtype=torch.int64)
|
74 |
+
|
75 |
+
fg_inds = torch.tensor(input_fg_inds_int_list).unsqueeze(0)
|
76 |
+
_, c, _, _, _, _, _, part_output_fg_inds = text2speech(
|
77 |
+
phoneme_string, ref_embs=style_emb, fg_inds=fg_inds
|
78 |
+
)
|
79 |
+
|
80 |
+
prev_fg_inds[0, max_edit_index + 1 :, :] = part_output_fg_inds[0]
|
81 |
+
output_fg_inds = prev_fg_inds
|
82 |
+
|
83 |
+
output_fg_inds_list = output_fg_inds.tolist()[0]
|
84 |
+
padded_phonemes = ["", *phonemes]
|
85 |
+
dataframe_values = [
|
86 |
+
[phoneme, *fgs]
|
87 |
+
for phoneme, fgs in zip(padded_phonemes, output_fg_inds_list)
|
88 |
+
]
|
89 |
+
selected_inds = [
|
90 |
+
[input_fg_inds[i]] if i < len(input_fg_inds) else [""]
|
91 |
+
for i in range(len(padded_phonemes))
|
92 |
+
]
|
93 |
+
wav = vocoder.inference(c)
|
94 |
+
|
95 |
+
return [
|
96 |
+
(22050, wav.view(-1).cpu().numpy()),
|
97 |
+
dataframe_values,
|
98 |
+
selected_inds,
|
99 |
+
]
|
100 |
+
|
101 |
+
|
102 |
+
demo = gr.Blocks()
|
103 |
+
|
104 |
+
with demo:
|
105 |
+
gr.Markdown(
|
106 |
+
"""
|
107 |
+
|
108 |
+
# ConEx Demo
|
109 |
+
|
110 |
+
This demo shows the capabilities of ConEx, a model for **Con**trollable **Ex**pressive speech synthesis.
|
111 |
+
ConEx allows you to generate speech in a certain speaking style, and gives you the ability to edit the prosody* of the generated speech at a fine level.
|
112 |
+
We proposed ConEx in our paper titled ["Interactive Multi-Level Prosody Control for Expressive Speech Synthesis"](https://jessa.github.io/assets/pdf/cornille2022icassp.pdf), published in proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP) 2022.
|
113 |
+
|
114 |
+
To convert text to speech: input some text, choose the desired speaking style, set the duration factor (higher = slower speech), and press "Generate speech".
|
115 |
+
|
116 |
+
**prosody refers to speech characteristics such as intonation, stress, rhythm*
|
117 |
+
"""
|
118 |
+
)
|
119 |
+
|
120 |
+
with gr.Row():
|
121 |
+
text_input = gr.Textbox(
|
122 |
+
label="Input text",
|
123 |
+
lines=4,
|
124 |
+
placeholder="E.g. I didn't say he stole the money",
|
125 |
+
)
|
126 |
+
|
127 |
+
with gr.Column():
|
128 |
+
global_style_dropdown = gr.Dropdown(
|
129 |
+
["Style 1", "Style 2", "Style 3", "Style 4", "Style 5", "Style 6"],
|
130 |
+
value="Style 1",
|
131 |
+
label="Global speaking style",
|
132 |
+
)
|
133 |
+
alpha_slider = gr.Slider(
|
134 |
+
0.1, 2, value=1, step=0.1, label="Alpha (duration factor)"
|
135 |
+
)
|
136 |
+
|
137 |
+
audio = gr.Audio()
|
138 |
+
with gr.Row():
|
139 |
+
button = gr.Button("Generate Speech")
|
140 |
+
|
141 |
+
gr.Markdown(
|
142 |
+
"""
|
143 |
+
|
144 |
+
### Fine-grained prosody editor
|
145 |
+
Once you've generated some speech, the following table will show the id of the prosody embedding used for each phoneme.
|
146 |
+
A prosody embedding determines the prosody of the phoneme.
|
147 |
+
The table not only shows the prosody embeddings that are used by default (the top predictions), but also two more likely prosody embeddings.
|
148 |
+
|
149 |
+
In order to change the prosody of a phoneme, write a new prosody embedding id in the "Chosen prosody embeddings" column and press "Generate speech" again.
|
150 |
+
You can use any number from 0-31, but the 2nd and 3rd predictions are more likely to give a fitting prosody.
|
151 |
+
Based on your edit, new prosody embeddings will be generated for the phonemes after the edit.
|
152 |
+
Thus, you can iteratively change the prosody by starting from the beginning of the utterance and working your through the utterance, making edits as you see fit.
|
153 |
+
The prosody embeddings before your edit will remain the same as before, and will be copied to the "Chosen prosody embeddings" column.
|
154 |
+
"""
|
155 |
+
)
|
156 |
+
|
157 |
+
with gr.Row():
|
158 |
+
phoneme_preds_df = gr.Dataframe(
|
159 |
+
headers=["Phoneme", "🥇 Top pred", "🥈 2nd pred", "🥉 3rd pred"],
|
160 |
+
type="array",
|
161 |
+
col_count=(4, "static"),
|
162 |
+
)
|
163 |
+
phoneme_edits_df = gr.Dataframe(
|
164 |
+
headers=["Chosen prosody embeddings"], type="array", col_count=(1, "static")
|
165 |
+
)
|
166 |
+
|
167 |
+
button.click(
|
168 |
+
inference,
|
169 |
+
inputs=[
|
170 |
+
text_input,
|
171 |
+
global_style_dropdown,
|
172 |
+
alpha_slider,
|
173 |
+
phoneme_preds_df,
|
174 |
+
phoneme_edits_df,
|
175 |
+
],
|
176 |
+
outputs=[audio, phoneme_preds_df, phoneme_edits_df],
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
+
demo.launch()
|
config.yaml
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
config: conf/ar_prior_train.yaml
|
2 |
+
print_config: false
|
3 |
+
log_level: INFO
|
4 |
+
dry_run: false
|
5 |
+
iterator_type: sequence
|
6 |
+
output_dir: exp/tts_finetune_ar_prior
|
7 |
+
ngpu: 1
|
8 |
+
seed: 0
|
9 |
+
num_workers: 1
|
10 |
+
num_att_plot: 3
|
11 |
+
dist_backend: nccl
|
12 |
+
dist_init_method: env://
|
13 |
+
dist_world_size: null
|
14 |
+
dist_rank: null
|
15 |
+
local_rank: 0
|
16 |
+
dist_master_addr: null
|
17 |
+
dist_master_port: null
|
18 |
+
dist_launcher: null
|
19 |
+
multiprocessing_distributed: false
|
20 |
+
unused_parameters: false
|
21 |
+
sharded_ddp: false
|
22 |
+
cudnn_enabled: true
|
23 |
+
cudnn_benchmark: false
|
24 |
+
cudnn_deterministic: true
|
25 |
+
collect_stats: false
|
26 |
+
write_collected_feats: false
|
27 |
+
max_epoch: 500
|
28 |
+
patience: null
|
29 |
+
val_scheduler_criterion:
|
30 |
+
- valid
|
31 |
+
- loss
|
32 |
+
early_stopping_criterion:
|
33 |
+
- valid
|
34 |
+
- loss
|
35 |
+
- min
|
36 |
+
best_model_criterion:
|
37 |
+
- - valid
|
38 |
+
- loss
|
39 |
+
- min
|
40 |
+
- - train
|
41 |
+
- loss
|
42 |
+
- min
|
43 |
+
keep_nbest_models: 5
|
44 |
+
grad_clip: 1.0
|
45 |
+
grad_clip_type: 2.0
|
46 |
+
grad_noise: false
|
47 |
+
accum_grad: 8
|
48 |
+
no_forward_run: false
|
49 |
+
resume: true
|
50 |
+
train_dtype: float32
|
51 |
+
use_amp: false
|
52 |
+
log_interval: null
|
53 |
+
use_tensorboard: true
|
54 |
+
use_wandb: false
|
55 |
+
wandb_project: null
|
56 |
+
wandb_id: null
|
57 |
+
detect_anomaly: false
|
58 |
+
pretrain_path: null
|
59 |
+
init_param:
|
60 |
+
- /data/leuven/339/vsc33942/espnet-mirror/egs2/acapela_blizzard/tts1/exp/tts_train_raw_phn_none/valid.loss.best.pth:::tts.prosody_encoder.ar_prior
|
61 |
+
freeze_param:
|
62 |
+
- encoder.,prosody_encoder.ref_encoder.,prosody_encoder.fg_encoder.,prosody_encoder.global_encoder.,prosody_encoder.global_projection.,prosody_encoder.vq_layer.,prosody_encoder.qfg_projection,duration_predictor.,length_regulator,decoder.,feat_out,postnet
|
63 |
+
num_iters_per_epoch: 50
|
64 |
+
batch_size: 20
|
65 |
+
valid_batch_size: null
|
66 |
+
batch_bins: 3000000
|
67 |
+
valid_batch_bins: null
|
68 |
+
train_shape_file:
|
69 |
+
- exp/tts_stats_raw_phn_none/train/text_shape.phn
|
70 |
+
- exp/tts_stats_raw_phn_none/train/speech_shape
|
71 |
+
valid_shape_file:
|
72 |
+
- exp/tts_stats_raw_phn_none/valid/text_shape.phn
|
73 |
+
- exp/tts_stats_raw_phn_none/valid/speech_shape
|
74 |
+
batch_type: numel
|
75 |
+
valid_batch_type: null
|
76 |
+
fold_length:
|
77 |
+
- 150
|
78 |
+
- 204800
|
79 |
+
sort_in_batch: descending
|
80 |
+
sort_batch: descending
|
81 |
+
multiple_iterator: false
|
82 |
+
chunk_length: 500
|
83 |
+
chunk_shift_ratio: 0.5
|
84 |
+
num_cache_chunks: 1024
|
85 |
+
train_data_path_and_name_and_type:
|
86 |
+
- - dump/raw/tr_no_dev/text
|
87 |
+
- text
|
88 |
+
- text
|
89 |
+
- - data/durations/tr_no_dev/durations
|
90 |
+
- durations
|
91 |
+
- text_int
|
92 |
+
- - dump/raw/tr_no_dev/wav.scp
|
93 |
+
- speech
|
94 |
+
- sound
|
95 |
+
valid_data_path_and_name_and_type:
|
96 |
+
- - dump/raw/dev/text
|
97 |
+
- text
|
98 |
+
- text
|
99 |
+
- - data/durations/dev/durations
|
100 |
+
- durations
|
101 |
+
- text_int
|
102 |
+
- - dump/raw/dev/wav.scp
|
103 |
+
- speech
|
104 |
+
- sound
|
105 |
+
allow_variable_data_keys: false
|
106 |
+
max_cache_size: 0.0
|
107 |
+
max_cache_fd: 32
|
108 |
+
valid_max_cache_size: null
|
109 |
+
optim: adam
|
110 |
+
optim_conf:
|
111 |
+
lr: 1.0
|
112 |
+
scheduler: noamlr
|
113 |
+
scheduler_conf:
|
114 |
+
model_size: 384
|
115 |
+
warmup_steps: 4000
|
116 |
+
token_list:
|
117 |
+
- <blank>
|
118 |
+
- <unk>
|
119 |
+
- n
|
120 |
+
- '@'
|
121 |
+
- t
|
122 |
+
- _
|
123 |
+
- s
|
124 |
+
- I
|
125 |
+
- r
|
126 |
+
- d
|
127 |
+
- l
|
128 |
+
- m
|
129 |
+
- i
|
130 |
+
- '{'
|
131 |
+
- z
|
132 |
+
- D
|
133 |
+
- w
|
134 |
+
- r=
|
135 |
+
- f
|
136 |
+
- v
|
137 |
+
- E1
|
138 |
+
- b
|
139 |
+
- t_h
|
140 |
+
- h
|
141 |
+
- V
|
142 |
+
- u
|
143 |
+
- k
|
144 |
+
- I1
|
145 |
+
- '{1'
|
146 |
+
- k_h
|
147 |
+
- N
|
148 |
+
- EI1
|
149 |
+
- V1
|
150 |
+
- O1
|
151 |
+
- AI
|
152 |
+
- H
|
153 |
+
- S
|
154 |
+
- p_h
|
155 |
+
- '@U1'
|
156 |
+
- i1
|
157 |
+
- g
|
158 |
+
- AI1
|
159 |
+
- j
|
160 |
+
- O
|
161 |
+
- p
|
162 |
+
- u1
|
163 |
+
- r=1
|
164 |
+
- tS
|
165 |
+
- Or
|
166 |
+
- '4'
|
167 |
+
- A
|
168 |
+
- Or1
|
169 |
+
- E
|
170 |
+
- dZ
|
171 |
+
- T
|
172 |
+
- aU1
|
173 |
+
- U
|
174 |
+
- Er1
|
175 |
+
- '@U'
|
176 |
+
- U1
|
177 |
+
- Ar1
|
178 |
+
- Er
|
179 |
+
- aU
|
180 |
+
- EI
|
181 |
+
- ir1
|
182 |
+
- l=
|
183 |
+
- OI1
|
184 |
+
- Ar
|
185 |
+
- Ur1
|
186 |
+
- n=
|
187 |
+
- A1
|
188 |
+
- Z
|
189 |
+
- '?'
|
190 |
+
- ir
|
191 |
+
- Ur
|
192 |
+
- OI
|
193 |
+
- <sos/eos>
|
194 |
+
odim: null
|
195 |
+
model_conf: {}
|
196 |
+
use_preprocessor: true
|
197 |
+
token_type: phn
|
198 |
+
bpemodel: null
|
199 |
+
non_linguistic_symbols: null
|
200 |
+
cleaner: null
|
201 |
+
g2p: null
|
202 |
+
feats_extract: fbank
|
203 |
+
feats_extract_conf:
|
204 |
+
fs: 22050
|
205 |
+
fmin: 80
|
206 |
+
fmax: 7600
|
207 |
+
n_mels: 80
|
208 |
+
hop_length: 256
|
209 |
+
n_fft: 1024
|
210 |
+
win_length: null
|
211 |
+
normalize: global_mvn
|
212 |
+
normalize_conf:
|
213 |
+
stats_file: feats_stats.npz
|
214 |
+
tts: fastespeech
|
215 |
+
tts_conf:
|
216 |
+
adim: 128
|
217 |
+
aheads: 2
|
218 |
+
elayers: 4
|
219 |
+
eunits: 1536
|
220 |
+
dlayers: 4
|
221 |
+
dunits: 1536
|
222 |
+
positionwise_layer_type: conv1d
|
223 |
+
positionwise_conv_kernel_size: 3
|
224 |
+
duration_predictor_layers: 2
|
225 |
+
duration_predictor_chans: 128
|
226 |
+
duration_predictor_kernel_size: 3
|
227 |
+
duration_predictor_dropout_rate: 0.2
|
228 |
+
postnet_layers: 5
|
229 |
+
postnet_filts: 5
|
230 |
+
postnet_chans: 256
|
231 |
+
use_masking: true
|
232 |
+
use_scaled_pos_enc: true
|
233 |
+
encoder_normalize_before: true
|
234 |
+
decoder_normalize_before: true
|
235 |
+
reduction_factor: 1
|
236 |
+
init_type: xavier_uniform
|
237 |
+
init_enc_alpha: 1.0
|
238 |
+
init_dec_alpha: 1.0
|
239 |
+
transformer_enc_dropout_rate: 0.2
|
240 |
+
transformer_enc_positional_dropout_rate: 0.2
|
241 |
+
transformer_enc_attn_dropout_rate: 0.2
|
242 |
+
transformer_dec_dropout_rate: 0.2
|
243 |
+
transformer_dec_positional_dropout_rate: 0.2
|
244 |
+
transformer_dec_attn_dropout_rate: 0.2
|
245 |
+
ref_enc_conv_layers: 2
|
246 |
+
ref_enc_conv_kernel_size: 3
|
247 |
+
ref_enc_conv_stride: 2
|
248 |
+
ref_enc_gru_layers: 1
|
249 |
+
ref_enc_gru_units: 32
|
250 |
+
ref_emb_integration_type: add
|
251 |
+
prosody_num_embs: 32
|
252 |
+
prosody_hidden_dim: 3
|
253 |
+
prosody_emb_integration_type: add
|
254 |
+
pitch_extract: null
|
255 |
+
pitch_extract_conf: {}
|
256 |
+
pitch_normalize: null
|
257 |
+
pitch_normalize_conf: {}
|
258 |
+
energy_extract: null
|
259 |
+
energy_extract_conf: {}
|
260 |
+
energy_normalize: null
|
261 |
+
energy_normalize_conf: {}
|
262 |
+
required:
|
263 |
+
- output_dir
|
264 |
+
- token_list
|
265 |
+
version: 0.9.9
|
266 |
+
distributed: false
|
espnet/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Initialize espnet package."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
dirname = os.path.dirname(__file__)
|
6 |
+
version_file = os.path.join(dirname, "version.txt")
|
7 |
+
with open(version_file, "r") as f:
|
8 |
+
__version__ = f.read().strip()
|
espnet/asr/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
"""Initialize sub package."""
|
espnet/asr/asr_mix_utils.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
This script is used to provide utility functions designed for multi-speaker ASR.
|
5 |
+
|
6 |
+
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
7 |
+
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
8 |
+
|
9 |
+
Most functions can be directly used as in asr_utils.py:
|
10 |
+
CompareValueTrigger, restore_snapshot, adadelta_eps_decay, chainer_load,
|
11 |
+
torch_snapshot, torch_save, torch_resume, AttributeDict, get_model_conf.
|
12 |
+
|
13 |
+
"""
|
14 |
+
|
15 |
+
import copy
|
16 |
+
import logging
|
17 |
+
import os
|
18 |
+
|
19 |
+
from chainer.training import extension
|
20 |
+
|
21 |
+
import matplotlib
|
22 |
+
|
23 |
+
from espnet.asr.asr_utils import parse_hypothesis
|
24 |
+
|
25 |
+
|
26 |
+
matplotlib.use("Agg")
|
27 |
+
|
28 |
+
|
29 |
+
# * -------------------- chainer extension related -------------------- *
|
30 |
+
class PlotAttentionReport(extension.Extension):
|
31 |
+
"""Plot attention reporter.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
att_vis_fn (espnet.nets.*_backend.e2e_asr.calculate_all_attentions):
|
35 |
+
Function of attention visualization.
|
36 |
+
data (list[tuple(str, dict[str, dict[str, Any]])]): List json utt key items.
|
37 |
+
outdir (str): Directory to save figures.
|
38 |
+
converter (espnet.asr.*_backend.asr.CustomConverter):
|
39 |
+
CustomConverter object. Function to convert data.
|
40 |
+
device (torch.device): The destination device to send tensor.
|
41 |
+
reverse (bool): If True, input and output length are reversed.
|
42 |
+
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(self, att_vis_fn, data, outdir, converter, device, reverse=False):
|
46 |
+
"""Initialize PlotAttentionReport."""
|
47 |
+
self.att_vis_fn = att_vis_fn
|
48 |
+
self.data = copy.deepcopy(data)
|
49 |
+
self.outdir = outdir
|
50 |
+
self.converter = converter
|
51 |
+
self.device = device
|
52 |
+
self.reverse = reverse
|
53 |
+
if not os.path.exists(self.outdir):
|
54 |
+
os.makedirs(self.outdir)
|
55 |
+
|
56 |
+
def __call__(self, trainer):
|
57 |
+
"""Plot and save imaged matrix of att_ws."""
|
58 |
+
att_ws_sd = self.get_attention_weights()
|
59 |
+
for ns, att_ws in enumerate(att_ws_sd):
|
60 |
+
for idx, att_w in enumerate(att_ws):
|
61 |
+
filename = "%s/%s.ep.{.updater.epoch}.output%d.png" % (
|
62 |
+
self.outdir,
|
63 |
+
self.data[idx][0],
|
64 |
+
ns + 1,
|
65 |
+
)
|
66 |
+
att_w = self.get_attention_weight(idx, att_w, ns)
|
67 |
+
self._plot_and_save_attention(att_w, filename.format(trainer))
|
68 |
+
|
69 |
+
def log_attentions(self, logger, step):
|
70 |
+
"""Add image files of attention matrix to tensorboard."""
|
71 |
+
att_ws_sd = self.get_attention_weights()
|
72 |
+
for ns, att_ws in enumerate(att_ws_sd):
|
73 |
+
for idx, att_w in enumerate(att_ws):
|
74 |
+
att_w = self.get_attention_weight(idx, att_w, ns)
|
75 |
+
plot = self.draw_attention_plot(att_w)
|
76 |
+
logger.add_figure("%s" % (self.data[idx][0]), plot.gcf(), step)
|
77 |
+
plot.clf()
|
78 |
+
|
79 |
+
def get_attention_weights(self):
|
80 |
+
"""Return attention weights.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
arr_ws_sd (numpy.ndarray): attention weights. It's shape would be
|
84 |
+
differ from bachend.dtype=float
|
85 |
+
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax). 2)
|
86 |
+
other case => (B, Lmax, Tmax).
|
87 |
+
* chainer-> attention weights (B, Lmax, Tmax).
|
88 |
+
|
89 |
+
"""
|
90 |
+
batch = self.converter([self.converter.transform(self.data)], self.device)
|
91 |
+
att_ws_sd = self.att_vis_fn(*batch)
|
92 |
+
return att_ws_sd
|
93 |
+
|
94 |
+
def get_attention_weight(self, idx, att_w, spkr_idx):
|
95 |
+
"""Transform attention weight in regard to self.reverse."""
|
96 |
+
if self.reverse:
|
97 |
+
dec_len = int(self.data[idx][1]["input"][0]["shape"][0])
|
98 |
+
enc_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
|
99 |
+
else:
|
100 |
+
dec_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
|
101 |
+
enc_len = int(self.data[idx][1]["input"][0]["shape"][0])
|
102 |
+
if len(att_w.shape) == 3:
|
103 |
+
att_w = att_w[:, :dec_len, :enc_len]
|
104 |
+
else:
|
105 |
+
att_w = att_w[:dec_len, :enc_len]
|
106 |
+
return att_w
|
107 |
+
|
108 |
+
def draw_attention_plot(self, att_w):
|
109 |
+
"""Visualize attention weights matrix.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
att_w(Tensor): Attention weight matrix.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
matplotlib.pyplot: pyplot object with attention matrix image.
|
116 |
+
|
117 |
+
"""
|
118 |
+
import matplotlib.pyplot as plt
|
119 |
+
|
120 |
+
if len(att_w.shape) == 3:
|
121 |
+
for h, aw in enumerate(att_w, 1):
|
122 |
+
plt.subplot(1, len(att_w), h)
|
123 |
+
plt.imshow(aw, aspect="auto")
|
124 |
+
plt.xlabel("Encoder Index")
|
125 |
+
plt.ylabel("Decoder Index")
|
126 |
+
else:
|
127 |
+
plt.imshow(att_w, aspect="auto")
|
128 |
+
plt.xlabel("Encoder Index")
|
129 |
+
plt.ylabel("Decoder Index")
|
130 |
+
plt.tight_layout()
|
131 |
+
return plt
|
132 |
+
|
133 |
+
def _plot_and_save_attention(self, att_w, filename):
|
134 |
+
plt = self.draw_attention_plot(att_w)
|
135 |
+
plt.savefig(filename)
|
136 |
+
plt.close()
|
137 |
+
|
138 |
+
|
139 |
+
def add_results_to_json(js, nbest_hyps_sd, char_list):
|
140 |
+
"""Add N-best results to json.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
js (dict[str, Any]): Groundtruth utterance dict.
|
144 |
+
nbest_hyps_sd (list[dict[str, Any]]):
|
145 |
+
List of hypothesis for multi_speakers (# Utts x # Spkrs).
|
146 |
+
char_list (list[str]): List of characters.
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
dict[str, Any]: N-best results added utterance dict.
|
150 |
+
|
151 |
+
"""
|
152 |
+
# copy old json info
|
153 |
+
new_js = dict()
|
154 |
+
new_js["utt2spk"] = js["utt2spk"]
|
155 |
+
num_spkrs = len(nbest_hyps_sd)
|
156 |
+
new_js["output"] = []
|
157 |
+
|
158 |
+
for ns in range(num_spkrs):
|
159 |
+
tmp_js = []
|
160 |
+
nbest_hyps = nbest_hyps_sd[ns]
|
161 |
+
|
162 |
+
for n, hyp in enumerate(nbest_hyps, 1):
|
163 |
+
# parse hypothesis
|
164 |
+
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
|
165 |
+
|
166 |
+
# copy ground-truth
|
167 |
+
out_dic = dict(js["output"][ns].items())
|
168 |
+
|
169 |
+
# update name
|
170 |
+
out_dic["name"] += "[%d]" % n
|
171 |
+
|
172 |
+
# add recognition results
|
173 |
+
out_dic["rec_text"] = rec_text
|
174 |
+
out_dic["rec_token"] = rec_token
|
175 |
+
out_dic["rec_tokenid"] = rec_tokenid
|
176 |
+
out_dic["score"] = score
|
177 |
+
|
178 |
+
# add to list of N-best result dicts
|
179 |
+
tmp_js.append(out_dic)
|
180 |
+
|
181 |
+
# show 1-best result
|
182 |
+
if n == 1:
|
183 |
+
logging.info("groundtruth: %s" % out_dic["text"])
|
184 |
+
logging.info("prediction : %s" % out_dic["rec_text"])
|
185 |
+
|
186 |
+
new_js["output"].append(tmp_js)
|
187 |
+
return new_js
|
espnet/asr/asr_utils.py
ADDED
@@ -0,0 +1,1024 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import copy
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import shutil
|
10 |
+
import tempfile
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
|
15 |
+
|
16 |
+
# * -------------------- training iterator related -------------------- *
|
17 |
+
|
18 |
+
|
19 |
+
class CompareValueTrigger(object):
|
20 |
+
"""Trigger invoked when key value getting bigger or lower than before.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
key (str) : Key of value.
|
24 |
+
compare_fn ((float, float) -> bool) : Function to compare the values.
|
25 |
+
trigger (tuple(int, str)) : Trigger that decide the comparison interval.
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
|
30 |
+
from chainer import training
|
31 |
+
|
32 |
+
self._key = key
|
33 |
+
self._best_value = None
|
34 |
+
self._interval_trigger = training.util.get_trigger(trigger)
|
35 |
+
self._init_summary()
|
36 |
+
self._compare_fn = compare_fn
|
37 |
+
|
38 |
+
def __call__(self, trainer):
|
39 |
+
"""Get value related to the key and compare with current value."""
|
40 |
+
observation = trainer.observation
|
41 |
+
summary = self._summary
|
42 |
+
key = self._key
|
43 |
+
if key in observation:
|
44 |
+
summary.add({key: observation[key]})
|
45 |
+
|
46 |
+
if not self._interval_trigger(trainer):
|
47 |
+
return False
|
48 |
+
|
49 |
+
stats = summary.compute_mean()
|
50 |
+
value = float(stats[key]) # copy to CPU
|
51 |
+
self._init_summary()
|
52 |
+
|
53 |
+
if self._best_value is None:
|
54 |
+
# initialize best value
|
55 |
+
self._best_value = value
|
56 |
+
return False
|
57 |
+
elif self._compare_fn(self._best_value, value):
|
58 |
+
return True
|
59 |
+
else:
|
60 |
+
self._best_value = value
|
61 |
+
return False
|
62 |
+
|
63 |
+
def _init_summary(self):
|
64 |
+
import chainer
|
65 |
+
|
66 |
+
self._summary = chainer.reporter.DictSummary()
|
67 |
+
|
68 |
+
|
69 |
+
try:
|
70 |
+
from chainer.training import extension
|
71 |
+
except ImportError:
|
72 |
+
PlotAttentionReport = None
|
73 |
+
else:
|
74 |
+
|
75 |
+
class PlotAttentionReport(extension.Extension):
|
76 |
+
"""Plot attention reporter.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
|
80 |
+
Function of attention visualization.
|
81 |
+
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
82 |
+
outdir (str): Directory to save figures.
|
83 |
+
converter (espnet.asr.*_backend.asr.CustomConverter):
|
84 |
+
Function to convert data.
|
85 |
+
device (int | torch.device): Device.
|
86 |
+
reverse (bool): If True, input and output length are reversed.
|
87 |
+
ikey (str): Key to access input
|
88 |
+
(for ASR/ST ikey="input", for MT ikey="output".)
|
89 |
+
iaxis (int): Dimension to access input
|
90 |
+
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
91 |
+
okey (str): Key to access output
|
92 |
+
(for ASR/ST okey="input", MT okay="output".)
|
93 |
+
oaxis (int): Dimension to access output
|
94 |
+
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
95 |
+
subsampling_factor (int): subsampling factor in encoder
|
96 |
+
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
att_vis_fn,
|
102 |
+
data,
|
103 |
+
outdir,
|
104 |
+
converter,
|
105 |
+
transform,
|
106 |
+
device,
|
107 |
+
reverse=False,
|
108 |
+
ikey="input",
|
109 |
+
iaxis=0,
|
110 |
+
okey="output",
|
111 |
+
oaxis=0,
|
112 |
+
subsampling_factor=1,
|
113 |
+
):
|
114 |
+
self.att_vis_fn = att_vis_fn
|
115 |
+
self.data = copy.deepcopy(data)
|
116 |
+
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
117 |
+
# key is utterance ID
|
118 |
+
self.outdir = outdir
|
119 |
+
self.converter = converter
|
120 |
+
self.transform = transform
|
121 |
+
self.device = device
|
122 |
+
self.reverse = reverse
|
123 |
+
self.ikey = ikey
|
124 |
+
self.iaxis = iaxis
|
125 |
+
self.okey = okey
|
126 |
+
self.oaxis = oaxis
|
127 |
+
self.factor = subsampling_factor
|
128 |
+
if not os.path.exists(self.outdir):
|
129 |
+
os.makedirs(self.outdir)
|
130 |
+
|
131 |
+
def __call__(self, trainer):
|
132 |
+
"""Plot and save image file of att_ws matrix."""
|
133 |
+
att_ws, uttid_list = self.get_attention_weights()
|
134 |
+
if isinstance(att_ws, list): # multi-encoder case
|
135 |
+
num_encs = len(att_ws) - 1
|
136 |
+
# atts
|
137 |
+
for i in range(num_encs):
|
138 |
+
for idx, att_w in enumerate(att_ws[i]):
|
139 |
+
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
|
140 |
+
self.outdir,
|
141 |
+
uttid_list[idx],
|
142 |
+
i + 1,
|
143 |
+
)
|
144 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
145 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
|
146 |
+
self.outdir,
|
147 |
+
uttid_list[idx],
|
148 |
+
i + 1,
|
149 |
+
)
|
150 |
+
np.save(np_filename.format(trainer), att_w)
|
151 |
+
self._plot_and_save_attention(att_w, filename.format(trainer))
|
152 |
+
# han
|
153 |
+
for idx, att_w in enumerate(att_ws[num_encs]):
|
154 |
+
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
|
155 |
+
self.outdir,
|
156 |
+
uttid_list[idx],
|
157 |
+
)
|
158 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
159 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
|
160 |
+
self.outdir,
|
161 |
+
uttid_list[idx],
|
162 |
+
)
|
163 |
+
np.save(np_filename.format(trainer), att_w)
|
164 |
+
self._plot_and_save_attention(
|
165 |
+
att_w, filename.format(trainer), han_mode=True
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
for idx, att_w in enumerate(att_ws):
|
169 |
+
filename = "%s/%s.ep.{.updater.epoch}.png" % (
|
170 |
+
self.outdir,
|
171 |
+
uttid_list[idx],
|
172 |
+
)
|
173 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
174 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
175 |
+
self.outdir,
|
176 |
+
uttid_list[idx],
|
177 |
+
)
|
178 |
+
np.save(np_filename.format(trainer), att_w)
|
179 |
+
self._plot_and_save_attention(att_w, filename.format(trainer))
|
180 |
+
|
181 |
+
def log_attentions(self, logger, step):
|
182 |
+
"""Add image files of att_ws matrix to the tensorboard."""
|
183 |
+
att_ws, uttid_list = self.get_attention_weights()
|
184 |
+
if isinstance(att_ws, list): # multi-encoder case
|
185 |
+
num_encs = len(att_ws) - 1
|
186 |
+
# atts
|
187 |
+
for i in range(num_encs):
|
188 |
+
for idx, att_w in enumerate(att_ws[i]):
|
189 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
190 |
+
plot = self.draw_attention_plot(att_w)
|
191 |
+
logger.add_figure(
|
192 |
+
"%s_att%d" % (uttid_list[idx], i + 1),
|
193 |
+
plot.gcf(),
|
194 |
+
step,
|
195 |
+
)
|
196 |
+
# han
|
197 |
+
for idx, att_w in enumerate(att_ws[num_encs]):
|
198 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
199 |
+
plot = self.draw_han_plot(att_w)
|
200 |
+
logger.add_figure(
|
201 |
+
"%s_han" % (uttid_list[idx]),
|
202 |
+
plot.gcf(),
|
203 |
+
step,
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
for idx, att_w in enumerate(att_ws):
|
207 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
208 |
+
plot = self.draw_attention_plot(att_w)
|
209 |
+
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
210 |
+
|
211 |
+
def get_attention_weights(self):
|
212 |
+
"""Return attention weights.
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
numpy.ndarray: attention weights. float. Its shape would be
|
216 |
+
differ from backend.
|
217 |
+
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
|
218 |
+
other case => (B, Lmax, Tmax).
|
219 |
+
* chainer-> (B, Lmax, Tmax)
|
220 |
+
|
221 |
+
"""
|
222 |
+
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
223 |
+
batch = self.converter([return_batch], self.device)
|
224 |
+
if isinstance(batch, tuple):
|
225 |
+
att_ws = self.att_vis_fn(*batch)
|
226 |
+
else:
|
227 |
+
att_ws = self.att_vis_fn(**batch)
|
228 |
+
return att_ws, uttid_list
|
229 |
+
|
230 |
+
def trim_attention_weight(self, uttid, att_w):
|
231 |
+
"""Transform attention matrix with regard to self.reverse."""
|
232 |
+
if self.reverse:
|
233 |
+
enc_key, enc_axis = self.okey, self.oaxis
|
234 |
+
dec_key, dec_axis = self.ikey, self.iaxis
|
235 |
+
else:
|
236 |
+
enc_key, enc_axis = self.ikey, self.iaxis
|
237 |
+
dec_key, dec_axis = self.okey, self.oaxis
|
238 |
+
dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
|
239 |
+
enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
|
240 |
+
if self.factor > 1:
|
241 |
+
enc_len //= self.factor
|
242 |
+
if len(att_w.shape) == 3:
|
243 |
+
att_w = att_w[:, :dec_len, :enc_len]
|
244 |
+
else:
|
245 |
+
att_w = att_w[:dec_len, :enc_len]
|
246 |
+
return att_w
|
247 |
+
|
248 |
+
def draw_attention_plot(self, att_w):
|
249 |
+
"""Plot the att_w matrix.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
matplotlib.pyplot: pyplot object with attention matrix image.
|
253 |
+
|
254 |
+
"""
|
255 |
+
import matplotlib
|
256 |
+
|
257 |
+
matplotlib.use("Agg")
|
258 |
+
import matplotlib.pyplot as plt
|
259 |
+
|
260 |
+
plt.clf()
|
261 |
+
att_w = att_w.astype(np.float32)
|
262 |
+
if len(att_w.shape) == 3:
|
263 |
+
for h, aw in enumerate(att_w, 1):
|
264 |
+
plt.subplot(1, len(att_w), h)
|
265 |
+
plt.imshow(aw, aspect="auto")
|
266 |
+
plt.xlabel("Encoder Index")
|
267 |
+
plt.ylabel("Decoder Index")
|
268 |
+
else:
|
269 |
+
plt.imshow(att_w, aspect="auto")
|
270 |
+
plt.xlabel("Encoder Index")
|
271 |
+
plt.ylabel("Decoder Index")
|
272 |
+
plt.tight_layout()
|
273 |
+
return plt
|
274 |
+
|
275 |
+
def draw_han_plot(self, att_w):
|
276 |
+
"""Plot the att_w matrix for hierarchical attention.
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
matplotlib.pyplot: pyplot object with attention matrix image.
|
280 |
+
|
281 |
+
"""
|
282 |
+
import matplotlib
|
283 |
+
|
284 |
+
matplotlib.use("Agg")
|
285 |
+
import matplotlib.pyplot as plt
|
286 |
+
|
287 |
+
plt.clf()
|
288 |
+
if len(att_w.shape) == 3:
|
289 |
+
for h, aw in enumerate(att_w, 1):
|
290 |
+
legends = []
|
291 |
+
plt.subplot(1, len(att_w), h)
|
292 |
+
for i in range(aw.shape[1]):
|
293 |
+
plt.plot(aw[:, i])
|
294 |
+
legends.append("Att{}".format(i))
|
295 |
+
plt.ylim([0, 1.0])
|
296 |
+
plt.xlim([0, aw.shape[0]])
|
297 |
+
plt.grid(True)
|
298 |
+
plt.ylabel("Attention Weight")
|
299 |
+
plt.xlabel("Decoder Index")
|
300 |
+
plt.legend(legends)
|
301 |
+
else:
|
302 |
+
legends = []
|
303 |
+
for i in range(att_w.shape[1]):
|
304 |
+
plt.plot(att_w[:, i])
|
305 |
+
legends.append("Att{}".format(i))
|
306 |
+
plt.ylim([0, 1.0])
|
307 |
+
plt.xlim([0, att_w.shape[0]])
|
308 |
+
plt.grid(True)
|
309 |
+
plt.ylabel("Attention Weight")
|
310 |
+
plt.xlabel("Decoder Index")
|
311 |
+
plt.legend(legends)
|
312 |
+
plt.tight_layout()
|
313 |
+
return plt
|
314 |
+
|
315 |
+
def _plot_and_save_attention(self, att_w, filename, han_mode=False):
|
316 |
+
if han_mode:
|
317 |
+
plt = self.draw_han_plot(att_w)
|
318 |
+
else:
|
319 |
+
plt = self.draw_attention_plot(att_w)
|
320 |
+
plt.savefig(filename)
|
321 |
+
plt.close()
|
322 |
+
|
323 |
+
|
324 |
+
try:
|
325 |
+
from chainer.training import extension
|
326 |
+
except ImportError:
|
327 |
+
PlotCTCReport = None
|
328 |
+
else:
|
329 |
+
|
330 |
+
class PlotCTCReport(extension.Extension):
|
331 |
+
"""Plot CTC reporter.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
|
335 |
+
Function of CTC visualization.
|
336 |
+
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
337 |
+
outdir (str): Directory to save figures.
|
338 |
+
converter (espnet.asr.*_backend.asr.CustomConverter):
|
339 |
+
Function to convert data.
|
340 |
+
device (int | torch.device): Device.
|
341 |
+
reverse (bool): If True, input and output length are reversed.
|
342 |
+
ikey (str): Key to access input
|
343 |
+
(for ASR/ST ikey="input", for MT ikey="output".)
|
344 |
+
iaxis (int): Dimension to access input
|
345 |
+
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
346 |
+
okey (str): Key to access output
|
347 |
+
(for ASR/ST okey="input", MT okay="output".)
|
348 |
+
oaxis (int): Dimension to access output
|
349 |
+
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
350 |
+
subsampling_factor (int): subsampling factor in encoder
|
351 |
+
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(
|
355 |
+
self,
|
356 |
+
ctc_vis_fn,
|
357 |
+
data,
|
358 |
+
outdir,
|
359 |
+
converter,
|
360 |
+
transform,
|
361 |
+
device,
|
362 |
+
reverse=False,
|
363 |
+
ikey="input",
|
364 |
+
iaxis=0,
|
365 |
+
okey="output",
|
366 |
+
oaxis=0,
|
367 |
+
subsampling_factor=1,
|
368 |
+
):
|
369 |
+
self.ctc_vis_fn = ctc_vis_fn
|
370 |
+
self.data = copy.deepcopy(data)
|
371 |
+
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
372 |
+
# key is utterance ID
|
373 |
+
self.outdir = outdir
|
374 |
+
self.converter = converter
|
375 |
+
self.transform = transform
|
376 |
+
self.device = device
|
377 |
+
self.reverse = reverse
|
378 |
+
self.ikey = ikey
|
379 |
+
self.iaxis = iaxis
|
380 |
+
self.okey = okey
|
381 |
+
self.oaxis = oaxis
|
382 |
+
self.factor = subsampling_factor
|
383 |
+
if not os.path.exists(self.outdir):
|
384 |
+
os.makedirs(self.outdir)
|
385 |
+
|
386 |
+
def __call__(self, trainer):
|
387 |
+
"""Plot and save image file of ctc prob."""
|
388 |
+
ctc_probs, uttid_list = self.get_ctc_probs()
|
389 |
+
if isinstance(ctc_probs, list): # multi-encoder case
|
390 |
+
num_encs = len(ctc_probs) - 1
|
391 |
+
for i in range(num_encs):
|
392 |
+
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
393 |
+
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
|
394 |
+
self.outdir,
|
395 |
+
uttid_list[idx],
|
396 |
+
i + 1,
|
397 |
+
)
|
398 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
399 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
|
400 |
+
self.outdir,
|
401 |
+
uttid_list[idx],
|
402 |
+
i + 1,
|
403 |
+
)
|
404 |
+
np.save(np_filename.format(trainer), ctc_prob)
|
405 |
+
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
406 |
+
else:
|
407 |
+
for idx, ctc_prob in enumerate(ctc_probs):
|
408 |
+
filename = "%s/%s.ep.{.updater.epoch}.png" % (
|
409 |
+
self.outdir,
|
410 |
+
uttid_list[idx],
|
411 |
+
)
|
412 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
413 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
414 |
+
self.outdir,
|
415 |
+
uttid_list[idx],
|
416 |
+
)
|
417 |
+
np.save(np_filename.format(trainer), ctc_prob)
|
418 |
+
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
419 |
+
|
420 |
+
def log_ctc_probs(self, logger, step):
|
421 |
+
"""Add image files of ctc probs to the tensorboard."""
|
422 |
+
ctc_probs, uttid_list = self.get_ctc_probs()
|
423 |
+
if isinstance(ctc_probs, list): # multi-encoder case
|
424 |
+
num_encs = len(ctc_probs) - 1
|
425 |
+
for i in range(num_encs):
|
426 |
+
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
427 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
428 |
+
plot = self.draw_ctc_plot(ctc_prob)
|
429 |
+
logger.add_figure(
|
430 |
+
"%s_ctc%d" % (uttid_list[idx], i + 1),
|
431 |
+
plot.gcf(),
|
432 |
+
step,
|
433 |
+
)
|
434 |
+
else:
|
435 |
+
for idx, ctc_prob in enumerate(ctc_probs):
|
436 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
437 |
+
plot = self.draw_ctc_plot(ctc_prob)
|
438 |
+
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
439 |
+
|
440 |
+
def get_ctc_probs(self):
|
441 |
+
"""Return CTC probs.
|
442 |
+
|
443 |
+
Returns:
|
444 |
+
numpy.ndarray: CTC probs. float. Its shape would be
|
445 |
+
differ from backend. (B, Tmax, vocab).
|
446 |
+
|
447 |
+
"""
|
448 |
+
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
449 |
+
batch = self.converter([return_batch], self.device)
|
450 |
+
if isinstance(batch, tuple):
|
451 |
+
probs = self.ctc_vis_fn(*batch)
|
452 |
+
else:
|
453 |
+
probs = self.ctc_vis_fn(**batch)
|
454 |
+
return probs, uttid_list
|
455 |
+
|
456 |
+
def trim_ctc_prob(self, uttid, prob):
|
457 |
+
"""Trim CTC posteriors accoding to input lengths."""
|
458 |
+
enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
|
459 |
+
if self.factor > 1:
|
460 |
+
enc_len //= self.factor
|
461 |
+
prob = prob[:enc_len]
|
462 |
+
return prob
|
463 |
+
|
464 |
+
def draw_ctc_plot(self, ctc_prob):
|
465 |
+
"""Plot the ctc_prob matrix.
|
466 |
+
|
467 |
+
Returns:
|
468 |
+
matplotlib.pyplot: pyplot object with CTC prob matrix image.
|
469 |
+
|
470 |
+
"""
|
471 |
+
import matplotlib
|
472 |
+
|
473 |
+
matplotlib.use("Agg")
|
474 |
+
import matplotlib.pyplot as plt
|
475 |
+
|
476 |
+
ctc_prob = ctc_prob.astype(np.float32)
|
477 |
+
|
478 |
+
plt.clf()
|
479 |
+
topk_ids = np.argsort(ctc_prob, axis=1)
|
480 |
+
n_frames, vocab = ctc_prob.shape
|
481 |
+
times_probs = np.arange(n_frames)
|
482 |
+
|
483 |
+
plt.figure(figsize=(20, 8))
|
484 |
+
|
485 |
+
# NOTE: index 0 is reserved for blank
|
486 |
+
for idx in set(topk_ids.reshape(-1).tolist()):
|
487 |
+
if idx == 0:
|
488 |
+
plt.plot(
|
489 |
+
times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
|
490 |
+
)
|
491 |
+
else:
|
492 |
+
plt.plot(times_probs, ctc_prob[:, idx])
|
493 |
+
plt.xlabel(u"Input [frame]", fontsize=12)
|
494 |
+
plt.ylabel("Posteriors", fontsize=12)
|
495 |
+
plt.xticks(list(range(0, int(n_frames) + 1, 10)))
|
496 |
+
plt.yticks(list(range(0, 2, 1)))
|
497 |
+
plt.tight_layout()
|
498 |
+
return plt
|
499 |
+
|
500 |
+
def _plot_and_save_ctc(self, ctc_prob, filename):
|
501 |
+
plt = self.draw_ctc_plot(ctc_prob)
|
502 |
+
plt.savefig(filename)
|
503 |
+
plt.close()
|
504 |
+
|
505 |
+
|
506 |
+
def restore_snapshot(model, snapshot, load_fn=None):
|
507 |
+
"""Extension to restore snapshot.
|
508 |
+
|
509 |
+
Returns:
|
510 |
+
An extension function.
|
511 |
+
|
512 |
+
"""
|
513 |
+
import chainer
|
514 |
+
from chainer import training
|
515 |
+
|
516 |
+
if load_fn is None:
|
517 |
+
load_fn = chainer.serializers.load_npz
|
518 |
+
|
519 |
+
@training.make_extension(trigger=(1, "epoch"))
|
520 |
+
def restore_snapshot(trainer):
|
521 |
+
_restore_snapshot(model, snapshot, load_fn)
|
522 |
+
|
523 |
+
return restore_snapshot
|
524 |
+
|
525 |
+
|
526 |
+
def _restore_snapshot(model, snapshot, load_fn=None):
|
527 |
+
if load_fn is None:
|
528 |
+
import chainer
|
529 |
+
|
530 |
+
load_fn = chainer.serializers.load_npz
|
531 |
+
|
532 |
+
load_fn(snapshot, model)
|
533 |
+
logging.info("restored from " + str(snapshot))
|
534 |
+
|
535 |
+
|
536 |
+
def adadelta_eps_decay(eps_decay):
|
537 |
+
"""Extension to perform adadelta eps decay.
|
538 |
+
|
539 |
+
Args:
|
540 |
+
eps_decay (float): Decay rate of eps.
|
541 |
+
|
542 |
+
Returns:
|
543 |
+
An extension function.
|
544 |
+
|
545 |
+
"""
|
546 |
+
from chainer import training
|
547 |
+
|
548 |
+
@training.make_extension(trigger=(1, "epoch"))
|
549 |
+
def adadelta_eps_decay(trainer):
|
550 |
+
_adadelta_eps_decay(trainer, eps_decay)
|
551 |
+
|
552 |
+
return adadelta_eps_decay
|
553 |
+
|
554 |
+
|
555 |
+
def _adadelta_eps_decay(trainer, eps_decay):
|
556 |
+
optimizer = trainer.updater.get_optimizer("main")
|
557 |
+
# for chainer
|
558 |
+
if hasattr(optimizer, "eps"):
|
559 |
+
current_eps = optimizer.eps
|
560 |
+
setattr(optimizer, "eps", current_eps * eps_decay)
|
561 |
+
logging.info("adadelta eps decayed to " + str(optimizer.eps))
|
562 |
+
# pytorch
|
563 |
+
else:
|
564 |
+
for p in optimizer.param_groups:
|
565 |
+
p["eps"] *= eps_decay
|
566 |
+
logging.info("adadelta eps decayed to " + str(p["eps"]))
|
567 |
+
|
568 |
+
|
569 |
+
def adam_lr_decay(eps_decay):
|
570 |
+
"""Extension to perform adam lr decay.
|
571 |
+
|
572 |
+
Args:
|
573 |
+
eps_decay (float): Decay rate of lr.
|
574 |
+
|
575 |
+
Returns:
|
576 |
+
An extension function.
|
577 |
+
|
578 |
+
"""
|
579 |
+
from chainer import training
|
580 |
+
|
581 |
+
@training.make_extension(trigger=(1, "epoch"))
|
582 |
+
def adam_lr_decay(trainer):
|
583 |
+
_adam_lr_decay(trainer, eps_decay)
|
584 |
+
|
585 |
+
return adam_lr_decay
|
586 |
+
|
587 |
+
|
588 |
+
def _adam_lr_decay(trainer, eps_decay):
|
589 |
+
optimizer = trainer.updater.get_optimizer("main")
|
590 |
+
# for chainer
|
591 |
+
if hasattr(optimizer, "lr"):
|
592 |
+
current_lr = optimizer.lr
|
593 |
+
setattr(optimizer, "lr", current_lr * eps_decay)
|
594 |
+
logging.info("adam lr decayed to " + str(optimizer.lr))
|
595 |
+
# pytorch
|
596 |
+
else:
|
597 |
+
for p in optimizer.param_groups:
|
598 |
+
p["lr"] *= eps_decay
|
599 |
+
logging.info("adam lr decayed to " + str(p["lr"]))
|
600 |
+
|
601 |
+
|
602 |
+
def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
|
603 |
+
"""Extension to take snapshot of the trainer for pytorch.
|
604 |
+
|
605 |
+
Returns:
|
606 |
+
An extension function.
|
607 |
+
|
608 |
+
"""
|
609 |
+
from chainer.training import extension
|
610 |
+
|
611 |
+
@extension.make_extension(trigger=(1, "epoch"), priority=-100)
|
612 |
+
def torch_snapshot(trainer):
|
613 |
+
_torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun)
|
614 |
+
|
615 |
+
return torch_snapshot
|
616 |
+
|
617 |
+
|
618 |
+
def _torch_snapshot_object(trainer, target, filename, savefun):
|
619 |
+
from chainer.serializers import DictionarySerializer
|
620 |
+
|
621 |
+
# make snapshot_dict dictionary
|
622 |
+
s = DictionarySerializer()
|
623 |
+
s.save(trainer)
|
624 |
+
if hasattr(trainer.updater.model, "model"):
|
625 |
+
# (for TTS)
|
626 |
+
if hasattr(trainer.updater.model.model, "module"):
|
627 |
+
model_state_dict = trainer.updater.model.model.module.state_dict()
|
628 |
+
else:
|
629 |
+
model_state_dict = trainer.updater.model.model.state_dict()
|
630 |
+
else:
|
631 |
+
# (for ASR)
|
632 |
+
if hasattr(trainer.updater.model, "module"):
|
633 |
+
model_state_dict = trainer.updater.model.module.state_dict()
|
634 |
+
else:
|
635 |
+
model_state_dict = trainer.updater.model.state_dict()
|
636 |
+
snapshot_dict = {
|
637 |
+
"trainer": s.target,
|
638 |
+
"model": model_state_dict,
|
639 |
+
"optimizer": trainer.updater.get_optimizer("main").state_dict(),
|
640 |
+
}
|
641 |
+
|
642 |
+
# save snapshot dictionary
|
643 |
+
fn = filename.format(trainer)
|
644 |
+
prefix = "tmp" + fn
|
645 |
+
tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
|
646 |
+
tmppath = os.path.join(tmpdir, fn)
|
647 |
+
try:
|
648 |
+
savefun(snapshot_dict, tmppath)
|
649 |
+
shutil.move(tmppath, os.path.join(trainer.out, fn))
|
650 |
+
finally:
|
651 |
+
shutil.rmtree(tmpdir)
|
652 |
+
|
653 |
+
|
654 |
+
def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
|
655 |
+
"""Adds noise from a standard normal distribution to the gradients.
|
656 |
+
|
657 |
+
The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
|
658 |
+
`sigma` goes to zero (no noise) with more iterations.
|
659 |
+
|
660 |
+
Args:
|
661 |
+
model (torch.nn.model): Model.
|
662 |
+
iteration (int): Number of iterations.
|
663 |
+
duration (int) {100, 1000}:
|
664 |
+
Number of durations to control the interval of the `sigma` change.
|
665 |
+
eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
|
666 |
+
scale_factor (float) {0.55}: The scale of `sigma`.
|
667 |
+
"""
|
668 |
+
interval = (iteration // duration) + 1
|
669 |
+
sigma = eta / interval ** scale_factor
|
670 |
+
for param in model.parameters():
|
671 |
+
if param.grad is not None:
|
672 |
+
_shape = param.grad.size()
|
673 |
+
noise = sigma * torch.randn(_shape).to(param.device)
|
674 |
+
param.grad += noise
|
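Illustrative sketch (not part of the committed file): a small, self-contained reproduction of the annealing schedule used above, showing how `sigma` shrinks with the iteration count under the default hyper-parameters (eta=1.0, duration=100, scale_factor=0.55):

eta, duration, scale_factor = 1.0, 100, 0.55
for iteration in (0, 100, 1000, 10000):
    interval = (iteration // duration) + 1
    sigma = eta / interval ** scale_factor   # same formula as in add_gradient_noise
    print("iteration=%6d  sigma=%.4f" % (iteration, sigma))
# sigma is roughly 1.000, 0.683, 0.268, 0.079 for these iteration counts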
675 |
+
|
676 |
+
|
677 |
+
# * -------------------- general -------------------- *
|
678 |
+
def get_model_conf(model_path, conf_path=None):
|
679 |
+
"""Get model config information by reading a model config file (model.json).
|
680 |
+
|
681 |
+
Args:
|
682 |
+
model_path (str): Model path.
|
683 |
+
conf_path (str): Optional model config path.
|
684 |
+
|
685 |
+
Returns:
|
686 |
+
list[int, int, dict[str, Any]]: Config information loaded from json file.
|
687 |
+
|
688 |
+
"""
|
689 |
+
if conf_path is None:
|
690 |
+
model_conf = os.path.dirname(model_path) + "/model.json"
|
691 |
+
else:
|
692 |
+
model_conf = conf_path
|
693 |
+
with open(model_conf, "rb") as f:
|
694 |
+
logging.info("reading a config file from " + model_conf)
|
695 |
+
confs = json.load(f)
|
696 |
+
if isinstance(confs, dict):
|
697 |
+
# for lm
|
698 |
+
args = confs
|
699 |
+
return argparse.Namespace(**args)
|
700 |
+
else:
|
701 |
+
# for asr, tts, mt
|
702 |
+
idim, odim, args = confs
|
703 |
+
return idim, odim, argparse.Namespace(**args)
|
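Illustrative sketch (not part of the committed file, values invented): the two model.json layouts this function distinguishes. ASR/TTS/MT configs are stored as an [idim, odim, args] triple, while LM configs are a flat dict, which is why the function branches on isinstance(confs, dict):

import argparse
import json

asr_conf = json.loads('[83, 52, {"etype": "blstmp", "elayers": 4}]')  # hypothetical asr/tts/mt model.json
lm_conf = json.loads('{"layer": 2, "unit": 650}')                     # hypothetical lm model.json

idim, odim, asr_args = asr_conf[0], asr_conf[1], argparse.Namespace(**asr_conf[2])
lm_args = argparse.Namespace(**lm_conf)
print(idim, odim, asr_args.elayers)  # 83 52 4
print(lm_args.layer, lm_args.unit)   # 2 650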
704 |
+
|
705 |
+
|
706 |
+
def chainer_load(path, model):
|
707 |
+
"""Load chainer model parameters.
|
708 |
+
|
709 |
+
Args:
|
710 |
+
path (str): Model path or snapshot file path to be loaded.
|
711 |
+
model (chainer.Chain): Chainer model.
|
712 |
+
|
713 |
+
"""
|
714 |
+
import chainer
|
715 |
+
|
716 |
+
if "snapshot" in os.path.basename(path):
|
717 |
+
chainer.serializers.load_npz(path, model, path="updater/model:main/")
|
718 |
+
else:
|
719 |
+
chainer.serializers.load_npz(path, model)
|
720 |
+
|
721 |
+
|
722 |
+
def torch_save(path, model):
|
723 |
+
"""Save torch model states.
|
724 |
+
|
725 |
+
Args:
|
726 |
+
path (str): Model path to be saved.
|
727 |
+
model (torch.nn.Module): Torch model.
|
728 |
+
|
729 |
+
"""
|
730 |
+
if hasattr(model, "module"):
|
731 |
+
torch.save(model.module.state_dict(), path)
|
732 |
+
else:
|
733 |
+
torch.save(model.state_dict(), path)
|
734 |
+
|
735 |
+
|
736 |
+
def snapshot_object(target, filename):
|
737 |
+
"""Returns a trainer extension to take snapshots of a given object.
|
738 |
+
|
739 |
+
Args:
|
740 |
+
target (model): Object to serialize.
|
741 |
+
filename (str): Name of the file into which the object is serialized. It can
|
742 |
+
be a format string, where the trainer object is passed to
|
743 |
+
the :meth:`str.format` method. For example,
|
744 |
+
``'snapshot_{.updater.iteration}'`` is converted to
|
745 |
+
``'snapshot_10000'`` at the 10,000th iteration.
|
746 |
+
|
747 |
+
Returns:
|
748 |
+
An extension function.
|
749 |
+
|
750 |
+
"""
|
751 |
+
from chainer.training import extension
|
752 |
+
|
753 |
+
@extension.make_extension(trigger=(1, "epoch"), priority=-100)
|
754 |
+
def snapshot_object(trainer):
|
755 |
+
torch_save(os.path.join(trainer.out, filename.format(trainer)), target)
|
756 |
+
|
757 |
+
return snapshot_object
|
758 |
+
|
759 |
+
|
760 |
+
def torch_load(path, model):
|
761 |
+
"""Load torch model states.
|
762 |
+
|
763 |
+
Args:
|
764 |
+
path (str): Model path or snapshot file path to be loaded.
|
765 |
+
model (torch.nn.Module): Torch model.
|
766 |
+
|
767 |
+
"""
|
768 |
+
if "snapshot" in os.path.basename(path):
|
769 |
+
model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[
|
770 |
+
"model"
|
771 |
+
]
|
772 |
+
else:
|
773 |
+
model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)
|
774 |
+
|
775 |
+
if hasattr(model, "module"):
|
776 |
+
model.module.load_state_dict(model_state_dict)
|
777 |
+
else:
|
778 |
+
model.load_state_dict(model_state_dict)
|
779 |
+
|
780 |
+
del model_state_dict
|
781 |
+
|
782 |
+
|
783 |
+
def torch_resume(snapshot_path, trainer):
|
784 |
+
"""Resume from snapshot for pytorch.
|
785 |
+
|
786 |
+
Args:
|
787 |
+
snapshot_path (str): Snapshot file path.
|
788 |
+
trainer (chainer.training.Trainer): Chainer's trainer instance.
|
789 |
+
|
790 |
+
"""
|
791 |
+
from chainer.serializers import NpzDeserializer
|
792 |
+
|
793 |
+
# load snapshot
|
794 |
+
snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
|
795 |
+
|
796 |
+
# restore trainer states
|
797 |
+
d = NpzDeserializer(snapshot_dict["trainer"])
|
798 |
+
d.load(trainer)
|
799 |
+
|
800 |
+
# restore model states
|
801 |
+
if hasattr(trainer.updater.model, "model"):
|
802 |
+
# (for TTS model)
|
803 |
+
if hasattr(trainer.updater.model.model, "module"):
|
804 |
+
trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"])
|
805 |
+
else:
|
806 |
+
trainer.updater.model.model.load_state_dict(snapshot_dict["model"])
|
807 |
+
else:
|
808 |
+
# (for ASR model)
|
809 |
+
if hasattr(trainer.updater.model, "module"):
|
810 |
+
trainer.updater.model.module.load_state_dict(snapshot_dict["model"])
|
811 |
+
else:
|
812 |
+
trainer.updater.model.load_state_dict(snapshot_dict["model"])
|
813 |
+
|
814 |
+
# restore optimizer states
|
815 |
+
trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])
|
816 |
+
|
817 |
+
# delete opened snapshot
|
818 |
+
del snapshot_dict
|
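Illustrative sketch (not part of the committed file, placeholder values): the round trip between torch_snapshot and torch_resume. The snapshot written each epoch is a plain dict with three entries, and torch_resume restores each one in turn:

# Keys written by _torch_snapshot_object and read back by torch_resume (placeholders only):
snapshot_dict = {
    "trainer": {},    # chainer DictionarySerializer target holding trainer/updater state
    "model": {},      # model.state_dict() (module unwrapped if DataParallel is used)
    "optimizer": {},  # optimizer.state_dict()
}
# Resuming then amounts to:
#   NpzDeserializer(snapshot_dict["trainer"]).load(trainer)
#   model.load_state_dict(snapshot_dict["model"])
#   trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])
print(sorted(snapshot_dict.keys()))  # ['model', 'optimizer', 'trainer']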
819 |
+
|
820 |
+
|
821 |
+
# * ------------------ recognition related ------------------ *
|
822 |
+
def parse_hypothesis(hyp, char_list):
|
823 |
+
"""Parse hypothesis.
|
824 |
+
|
825 |
+
Args:
|
826 |
+
hyp (list[dict[str, Any]]): Recognition hypothesis.
|
827 |
+
char_list (list[str]): List of characters.
|
828 |
+
|
829 |
+
Returns:
|
830 |
+
tuple(str, str, str, float)
|
831 |
+
|
832 |
+
"""
|
833 |
+
# remove sos and get results
|
834 |
+
tokenid_as_list = list(map(int, hyp["yseq"][1:]))
|
835 |
+
token_as_list = [char_list[idx] for idx in tokenid_as_list]
|
836 |
+
score = float(hyp["score"])
|
837 |
+
|
838 |
+
# convert to string
|
839 |
+
tokenid = " ".join([str(idx) for idx in tokenid_as_list])
|
840 |
+
token = " ".join(token_as_list)
|
841 |
+
text = "".join(token_as_list).replace("<space>", " ")
|
842 |
+
|
843 |
+
return text, token, tokenid, score
|
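Illustrative sketch (not part of the committed file): a toy example of what parse_hypothesis produces; the hypothesis dict and character list below are invented for illustration:

char_list = ["<blank>", "<space>", "a", "b", "c", "<eos>"]   # hypothetical vocabulary
hyp = {"yseq": [5, 2, 1, 3], "score": -1.23}                 # yseq[0] is the sos id and is dropped

tokenid_as_list = list(map(int, hyp["yseq"][1:]))
token_as_list = [char_list[idx] for idx in tokenid_as_list]
text = "".join(token_as_list).replace("<space>", " ")
print(text)                                        # "a b"
print(" ".join(str(i) for i in tokenid_as_list))   # "2 1 3"
print(float(hyp["score"]))                         # -1.23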
844 |
+
|
845 |
+
|
846 |
+
def add_results_to_json(js, nbest_hyps, char_list):
|
847 |
+
"""Add N-best results to json.
|
848 |
+
|
849 |
+
Args:
|
850 |
+
js (dict[str, Any]): Groundtruth utterance dict.
|
851 |
+
nbest_hyps (list[dict[str, Any]]):
|
852 |
+
List of N-best recognition hypotheses.
|
853 |
+
char_list (list[str]): List of characters.
|
854 |
+
|
855 |
+
Returns:
|
856 |
+
dict[str, Any]: N-best results added utterance dict.
|
857 |
+
|
858 |
+
"""
|
859 |
+
# copy old json info
|
860 |
+
new_js = dict()
|
861 |
+
new_js["utt2spk"] = js["utt2spk"]
|
862 |
+
new_js["output"] = []
|
863 |
+
|
864 |
+
for n, hyp in enumerate(nbest_hyps, 1):
|
865 |
+
# parse hypothesis
|
866 |
+
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
|
867 |
+
|
868 |
+
# copy ground-truth
|
869 |
+
if len(js["output"]) > 0:
|
870 |
+
out_dic = dict(js["output"][0].items())
|
871 |
+
else:
|
872 |
+
# for no reference case (e.g., speech translation)
|
873 |
+
out_dic = {"name": ""}
|
874 |
+
|
875 |
+
# update name
|
876 |
+
out_dic["name"] += "[%d]" % n
|
877 |
+
|
878 |
+
# add recognition results
|
879 |
+
out_dic["rec_text"] = rec_text
|
880 |
+
out_dic["rec_token"] = rec_token
|
881 |
+
out_dic["rec_tokenid"] = rec_tokenid
|
882 |
+
out_dic["score"] = score
|
883 |
+
|
884 |
+
# add to list of N-best result dicts
|
885 |
+
new_js["output"].append(out_dic)
|
886 |
+
|
887 |
+
# show 1-best result
|
888 |
+
if n == 1:
|
889 |
+
if "text" in out_dic.keys():
|
890 |
+
logging.info("groundtruth: %s" % out_dic["text"])
|
891 |
+
logging.info("prediction : %s" % out_dic["rec_text"])
|
892 |
+
|
893 |
+
return new_js
|
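Illustrative sketch (not part of the committed file, values invented): each output entry copies the ground-truth fields from js["output"][0] and adds the rec_* fields per hypothesis, so a 1-best result written by the recognition script looks roughly like this:

import json

new_js = {
    "utt_0001": {                                  # hypothetical utterance id
        "utt2spk": "spk1",
        "output": [
            {
                "name": "target1[1]",              # "[1]" appended for the 1-best hypothesis
                "text": "HELLO WORLD",             # ground truth copied from the input json
                "rec_text": "hello world",
                "rec_token": "h e l l o <space> w o r l d",
                "rec_tokenid": "8 5 12 12 15 1 23 15 18 12 4",
                "score": -2.34,
            }
        ],
    }
}
print(json.dumps({"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True))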
894 |
+
|
895 |
+
|
896 |
+
def plot_spectrogram(
|
897 |
+
plt,
|
898 |
+
spec,
|
899 |
+
mode="db",
|
900 |
+
fs=None,
|
901 |
+
frame_shift=None,
|
902 |
+
bottom=True,
|
903 |
+
left=True,
|
904 |
+
right=True,
|
905 |
+
top=False,
|
906 |
+
labelbottom=True,
|
907 |
+
labelleft=True,
|
908 |
+
labelright=True,
|
909 |
+
labeltop=False,
|
910 |
+
cmap="inferno",
|
911 |
+
):
|
912 |
+
"""Plot spectrogram using matplotlib.
|
913 |
+
|
914 |
+
Args:
|
915 |
+
plt (matplotlib.pyplot): pyplot object.
|
916 |
+
spec (numpy.ndarray): Input stft (Freq, Time)
|
917 |
+
mode (str): db or linear.
|
918 |
+
fs (int): Sample frequency. To convert y-axis to kHz unit.
|
919 |
+
frame_shift (int): The frame shift of stft. To convert x-axis to second unit.
|
920 |
+
bottom (bool): Whether to draw the respective ticks.
|
921 |
+
left (bool):
|
922 |
+
right (bool):
|
923 |
+
top (bool):
|
924 |
+
labelbottom (bool): Whether to draw the respective tick labels.
|
925 |
+
labelleft (bool):
|
926 |
+
labelright (bool):
|
927 |
+
labeltop (bool):
|
928 |
+
cmap (str): Colormap defined in matplotlib.
|
929 |
+
|
930 |
+
"""
|
931 |
+
spec = np.abs(spec)
|
932 |
+
if mode == "db":
|
933 |
+
x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
|
934 |
+
elif mode == "linear":
|
935 |
+
x = spec
|
936 |
+
else:
|
937 |
+
raise ValueError(mode)
|
938 |
+
|
939 |
+
if fs is not None:
|
940 |
+
ytop = fs / 2000
|
941 |
+
ylabel = "kHz"
|
942 |
+
else:
|
943 |
+
ytop = x.shape[0]
|
944 |
+
ylabel = "bin"
|
945 |
+
|
946 |
+
if frame_shift is not None and fs is not None:
|
947 |
+
xtop = x.shape[1] * frame_shift / fs
|
948 |
+
xlabel = "s"
|
949 |
+
else:
|
950 |
+
xtop = x.shape[1]
|
951 |
+
xlabel = "frame"
|
952 |
+
|
953 |
+
extent = (0, xtop, 0, ytop)
|
954 |
+
plt.imshow(x[::-1], cmap=cmap, extent=extent)
|
955 |
+
|
956 |
+
if labelbottom:
|
957 |
+
plt.xlabel("time [{}]".format(xlabel))
|
958 |
+
if labelleft:
|
959 |
+
plt.ylabel("freq [{}]".format(ylabel))
|
960 |
+
plt.colorbar().set_label("{}".format(mode))
|
961 |
+
|
962 |
+
plt.tick_params(
|
963 |
+
bottom=bottom,
|
964 |
+
left=left,
|
965 |
+
right=right,
|
966 |
+
top=top,
|
967 |
+
labelbottom=labelbottom,
|
968 |
+
labelleft=labelleft,
|
969 |
+
labelright=labelright,
|
970 |
+
labeltop=labeltop,
|
971 |
+
)
|
972 |
+
plt.axis("auto")
|
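Illustrative sketch (not part of the committed file): a minimal driver for the helper above, with random data standing in for an STFT and hypothetical sampling parameters:

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

from espnet.asr.asr_utils import plot_spectrogram

spec = np.abs(np.random.randn(257, 200)).astype(np.float32)  # (freq bins, frames)
plot_spectrogram(plt, spec, mode="db", fs=16000, frame_shift=160)
plt.savefig("spectrogram_example.png")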
973 |
+
|
974 |
+
|
975 |
+
# * ------------------ recognition related ------------------ *
|
976 |
+
def format_mulenc_args(args):
|
977 |
+
"""Format args for multi-encoder setup.
|
978 |
+
|
979 |
+
It deals with the following situations (when args.num_encs=2):
|
980 |
+
1. args.elayers = None -> args.elayers = [4, 4];
|
981 |
+
2. args.elayers = 4 -> args.elayers = [4, 4];
|
982 |
+
3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].
|
983 |
+
|
984 |
+
"""
|
985 |
+
# default values when None is assigned.
|
986 |
+
default_dict = {
|
987 |
+
"etype": "blstmp",
|
988 |
+
"elayers": 4,
|
989 |
+
"eunits": 300,
|
990 |
+
"subsample": "1",
|
991 |
+
"dropout_rate": 0.0,
|
992 |
+
"atype": "dot",
|
993 |
+
"adim": 320,
|
994 |
+
"awin": 5,
|
995 |
+
"aheads": 4,
|
996 |
+
"aconv_chans": -1,
|
997 |
+
"aconv_filts": 100,
|
998 |
+
}
|
999 |
+
for k in default_dict.keys():
|
1000 |
+
if isinstance(vars(args)[k], list):
|
1001 |
+
if len(vars(args)[k]) != args.num_encs:
|
1002 |
+
logging.warning(
|
1003 |
+
"Length mismatch {}: Convert {} to {}.".format(
|
1004 |
+
k, vars(args)[k], vars(args)[k][: args.num_encs]
|
1005 |
+
)
|
1006 |
+
)
|
1007 |
+
vars(args)[k] = vars(args)[k][: args.num_encs]
|
1008 |
+
else:
|
1009 |
+
if not vars(args)[k]:
|
1010 |
+
# assign default value if it is None
|
1011 |
+
vars(args)[k] = default_dict[k]
|
1012 |
+
logging.warning(
|
1013 |
+
"{} is not specified, use default value {}.".format(
|
1014 |
+
k, default_dict[k]
|
1015 |
+
)
|
1016 |
+
)
|
1017 |
+
# duplicate
|
1018 |
+
logging.warning(
|
1019 |
+
"Type mismatch {}: Convert {} to {}.".format(
|
1020 |
+
k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)]
|
1021 |
+
)
|
1022 |
+
)
|
1023 |
+
vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)]
|
1024 |
+
return args
|
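Illustrative sketch (not part of the committed file): the duplication behaviour for args.num_encs=2, using a toy Namespace that carries only the fields listed in default_dict above (a real argparse namespace has many more):

import argparse

from espnet.asr.asr_utils import format_mulenc_args

args = argparse.Namespace(
    num_encs=2,
    etype="blstmp", elayers=4, eunits=300, subsample="1", dropout_rate=0.0,
    atype="dot", adim=320, awin=5, aheads=4, aconv_chans=-1, aconv_filts=100,
)
args = format_mulenc_args(args)
print(args.elayers)  # [4, 4] - scalar duplicated once per encoder
print(args.etype)    # ['blstmp', 'blstmp']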
espnet/asr/chainer_backend/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
"""Initialize sub package."""
|
espnet/asr/chainer_backend/asr.py
ADDED
@@ -0,0 +1,575 @@
|
1 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
"""Training/decoding definition for the speech recognition task."""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import six
|
10 |
+
|
11 |
+
# chainer related
|
12 |
+
import chainer
|
13 |
+
|
14 |
+
from chainer import training
|
15 |
+
|
16 |
+
from chainer.datasets import TransformDataset
|
17 |
+
from chainer.training import extensions
|
18 |
+
|
19 |
+
# espnet related
|
20 |
+
from espnet.asr.asr_utils import adadelta_eps_decay
|
21 |
+
from espnet.asr.asr_utils import add_results_to_json
|
22 |
+
from espnet.asr.asr_utils import chainer_load
|
23 |
+
from espnet.asr.asr_utils import CompareValueTrigger
|
24 |
+
from espnet.asr.asr_utils import get_model_conf
|
25 |
+
from espnet.asr.asr_utils import restore_snapshot
|
26 |
+
from espnet.nets.asr_interface import ASRInterface
|
27 |
+
from espnet.utils.deterministic_utils import set_deterministic_chainer
|
28 |
+
from espnet.utils.dynamic_import import dynamic_import
|
29 |
+
from espnet.utils.io_utils import LoadInputsAndTargets
|
30 |
+
from espnet.utils.training.batchfy import make_batchset
|
31 |
+
from espnet.utils.training.evaluator import BaseEvaluator
|
32 |
+
from espnet.utils.training.iterators import ShufflingEnabler
|
33 |
+
from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator
|
34 |
+
from espnet.utils.training.iterators import ToggleableShufflingSerialIterator
|
35 |
+
from espnet.utils.training.train_utils import check_early_stop
|
36 |
+
from espnet.utils.training.train_utils import set_early_stop
|
37 |
+
|
38 |
+
# rnnlm
|
39 |
+
import espnet.lm.chainer_backend.extlm as extlm_chainer
|
40 |
+
import espnet.lm.chainer_backend.lm as lm_chainer
|
41 |
+
|
42 |
+
# numpy related
|
43 |
+
import matplotlib
|
44 |
+
|
45 |
+
from espnet.utils.training.tensorboard_logger import TensorboardLogger
|
46 |
+
from tensorboardX import SummaryWriter
|
47 |
+
|
48 |
+
matplotlib.use("Agg")
|
49 |
+
|
50 |
+
|
51 |
+
def train(args):
|
52 |
+
"""Train with the given args.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
args (namespace): The program arguments.
|
56 |
+
|
57 |
+
"""
|
58 |
+
# display chainer version
|
59 |
+
logging.info("chainer version = " + chainer.__version__)
|
60 |
+
|
61 |
+
set_deterministic_chainer(args)
|
62 |
+
|
63 |
+
# check cuda and cudnn availability
|
64 |
+
if not chainer.cuda.available:
|
65 |
+
logging.warning("cuda is not available")
|
66 |
+
if not chainer.cuda.cudnn_enabled:
|
67 |
+
logging.warning("cudnn is not available")
|
68 |
+
|
69 |
+
# get input and output dimension info
|
70 |
+
with open(args.valid_json, "rb") as f:
|
71 |
+
valid_json = json.load(f)["utts"]
|
72 |
+
utts = list(valid_json.keys())
|
73 |
+
idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
|
74 |
+
odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
|
75 |
+
logging.info("#input dims : " + str(idim))
|
76 |
+
logging.info("#output dims: " + str(odim))
|
77 |
+
|
78 |
+
# specify attention, CTC, hybrid mode
|
79 |
+
if args.mtlalpha == 1.0:
|
80 |
+
mtl_mode = "ctc"
|
81 |
+
logging.info("Pure CTC mode")
|
82 |
+
elif args.mtlalpha == 0.0:
|
83 |
+
mtl_mode = "att"
|
84 |
+
logging.info("Pure attention mode")
|
85 |
+
else:
|
86 |
+
mtl_mode = "mtl"
|
87 |
+
logging.info("Multitask learning mode")
|
88 |
+
|
89 |
+
# specify model architecture
|
90 |
+
logging.info("import model module: " + args.model_module)
|
91 |
+
model_class = dynamic_import(args.model_module)
|
92 |
+
model = model_class(idim, odim, args, flag_return=False)
|
93 |
+
assert isinstance(model, ASRInterface)
|
94 |
+
total_subsampling_factor = model.get_total_subsampling_factor()
|
95 |
+
|
96 |
+
# write model config
|
97 |
+
if not os.path.exists(args.outdir):
|
98 |
+
os.makedirs(args.outdir)
|
99 |
+
model_conf = args.outdir + "/model.json"
|
100 |
+
with open(model_conf, "wb") as f:
|
101 |
+
logging.info("writing a model config file to " + model_conf)
|
102 |
+
f.write(
|
103 |
+
json.dumps(
|
104 |
+
(idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
|
105 |
+
).encode("utf_8")
|
106 |
+
)
|
107 |
+
for key in sorted(vars(args).keys()):
|
108 |
+
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
|
109 |
+
|
110 |
+
# Set gpu
|
111 |
+
ngpu = args.ngpu
|
112 |
+
if ngpu == 1:
|
113 |
+
gpu_id = 0
|
114 |
+
# Make a specified GPU current
|
115 |
+
chainer.cuda.get_device_from_id(gpu_id).use()
|
116 |
+
model.to_gpu() # Copy the model to the GPU
|
117 |
+
logging.info("single gpu calculation.")
|
118 |
+
elif ngpu > 1:
|
119 |
+
gpu_id = 0
|
120 |
+
devices = {"main": gpu_id}
|
121 |
+
for gid in six.moves.xrange(1, ngpu):
|
122 |
+
devices["sub_%d" % gid] = gid
|
123 |
+
logging.info("multi gpu calculation (#gpus = %d)." % ngpu)
|
124 |
+
logging.warning(
|
125 |
+
"batch size is automatically increased (%d -> %d)"
|
126 |
+
% (args.batch_size, args.batch_size * args.ngpu)
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
gpu_id = -1
|
130 |
+
logging.info("cpu calculation")
|
131 |
+
|
132 |
+
# Setup an optimizer
|
133 |
+
if args.opt == "adadelta":
|
134 |
+
optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
|
135 |
+
elif args.opt == "adam":
|
136 |
+
optimizer = chainer.optimizers.Adam()
|
137 |
+
elif args.opt == "noam":
|
138 |
+
optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
|
139 |
+
else:
|
140 |
+
raise NotImplementedError("args.opt={}".format(args.opt))
|
141 |
+
|
142 |
+
optimizer.setup(model)
|
143 |
+
optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))
|
144 |
+
|
145 |
+
# Setup a converter
|
146 |
+
converter = model.custom_converter(subsampling_factor=model.subsample[0])
|
147 |
+
|
148 |
+
# read json data
|
149 |
+
with open(args.train_json, "rb") as f:
|
150 |
+
train_json = json.load(f)["utts"]
|
151 |
+
with open(args.valid_json, "rb") as f:
|
152 |
+
valid_json = json.load(f)["utts"]
|
153 |
+
|
154 |
+
# set up training iterator and updater
|
155 |
+
load_tr = LoadInputsAndTargets(
|
156 |
+
mode="asr",
|
157 |
+
load_output=True,
|
158 |
+
preprocess_conf=args.preprocess_conf,
|
159 |
+
preprocess_args={"train": True}, # Switch the mode of preprocessing
|
160 |
+
)
|
161 |
+
load_cv = LoadInputsAndTargets(
|
162 |
+
mode="asr",
|
163 |
+
load_output=True,
|
164 |
+
preprocess_conf=args.preprocess_conf,
|
165 |
+
preprocess_args={"train": False}, # Switch the mode of preprocessing
|
166 |
+
)
|
167 |
+
|
168 |
+
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
|
169 |
+
accum_grad = args.accum_grad
|
170 |
+
if ngpu <= 1:
|
171 |
+
# make minibatch list (variable length)
|
172 |
+
train = make_batchset(
|
173 |
+
train_json,
|
174 |
+
args.batch_size,
|
175 |
+
args.maxlen_in,
|
176 |
+
args.maxlen_out,
|
177 |
+
args.minibatches,
|
178 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
179 |
+
shortest_first=use_sortagrad,
|
180 |
+
count=args.batch_count,
|
181 |
+
batch_bins=args.batch_bins,
|
182 |
+
batch_frames_in=args.batch_frames_in,
|
183 |
+
batch_frames_out=args.batch_frames_out,
|
184 |
+
batch_frames_inout=args.batch_frames_inout,
|
185 |
+
iaxis=0,
|
186 |
+
oaxis=0,
|
187 |
+
)
|
188 |
+
# hack to set the batchsize argument to 1
|
189 |
+
# actual batchsize is included in a list
|
190 |
+
if args.n_iter_processes > 0:
|
191 |
+
train_iters = [
|
192 |
+
ToggleableShufflingMultiprocessIterator(
|
193 |
+
TransformDataset(train, load_tr),
|
194 |
+
batch_size=1,
|
195 |
+
n_processes=args.n_iter_processes,
|
196 |
+
n_prefetch=8,
|
197 |
+
maxtasksperchild=20,
|
198 |
+
shuffle=not use_sortagrad,
|
199 |
+
)
|
200 |
+
]
|
201 |
+
else:
|
202 |
+
train_iters = [
|
203 |
+
ToggleableShufflingSerialIterator(
|
204 |
+
TransformDataset(train, load_tr),
|
205 |
+
batch_size=1,
|
206 |
+
shuffle=not use_sortagrad,
|
207 |
+
)
|
208 |
+
]
|
209 |
+
|
210 |
+
# set up updater
|
211 |
+
updater = model.custom_updater(
|
212 |
+
train_iters[0],
|
213 |
+
optimizer,
|
214 |
+
converter=converter,
|
215 |
+
device=gpu_id,
|
216 |
+
accum_grad=accum_grad,
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
|
220 |
+
raise NotImplementedError(
|
221 |
+
"--batch-count 'bin' and 'frame' are not implemented "
|
222 |
+
"in chainer multi gpu"
|
223 |
+
)
|
224 |
+
# set up minibatches
|
225 |
+
train_subsets = []
|
226 |
+
for gid in six.moves.xrange(ngpu):
|
227 |
+
# make subset
|
228 |
+
train_json_subset = {
|
229 |
+
k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid
|
230 |
+
}
|
231 |
+
# make minibatch list (variable length)
|
232 |
+
train_subsets += [
|
233 |
+
make_batchset(
|
234 |
+
train_json_subset,
|
235 |
+
args.batch_size,
|
236 |
+
args.maxlen_in,
|
237 |
+
args.maxlen_out,
|
238 |
+
args.minibatches,
|
239 |
+
)
|
240 |
+
]
|
241 |
+
|
242 |
+
# each subset must have same length for MultiprocessParallelUpdater
|
243 |
+
maxlen = max([len(train_subset) for train_subset in train_subsets])
|
244 |
+
for train_subset in train_subsets:
|
245 |
+
if maxlen != len(train_subset):
|
246 |
+
for i in six.moves.xrange(maxlen - len(train_subset)):
|
247 |
+
train_subset += [train_subset[i]]
|
248 |
+
|
249 |
+
# hack to set the batchsize argument to 1
|
250 |
+
# actual batchsize is included in a list
|
251 |
+
if args.n_iter_processes > 0:
|
252 |
+
train_iters = [
|
253 |
+
ToggleableShufflingMultiprocessIterator(
|
254 |
+
TransformDataset(train_subsets[gid], load_tr),
|
255 |
+
batch_size=1,
|
256 |
+
n_processes=args.n_iter_processes,
|
257 |
+
n_prefetch=8,
|
258 |
+
maxtasksperchild=20,
|
259 |
+
shuffle=not use_sortagrad,
|
260 |
+
)
|
261 |
+
for gid in six.moves.xrange(ngpu)
|
262 |
+
]
|
263 |
+
else:
|
264 |
+
train_iters = [
|
265 |
+
ToggleableShufflingSerialIterator(
|
266 |
+
TransformDataset(train_subsets[gid], load_tr),
|
267 |
+
batch_size=1,
|
268 |
+
shuffle=not use_sortagrad,
|
269 |
+
)
|
270 |
+
for gid in six.moves.xrange(ngpu)
|
271 |
+
]
|
272 |
+
|
273 |
+
# set up updater
|
274 |
+
updater = model.custom_parallel_updater(
|
275 |
+
train_iters, optimizer, converter=converter, devices=devices
|
276 |
+
)
|
277 |
+
|
278 |
+
# Set up a trainer
|
279 |
+
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
|
280 |
+
|
281 |
+
if use_sortagrad:
|
282 |
+
trainer.extend(
|
283 |
+
ShufflingEnabler(train_iters),
|
284 |
+
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
|
285 |
+
)
|
286 |
+
if args.opt == "noam":
|
287 |
+
from espnet.nets.chainer_backend.transformer.training import VaswaniRule
|
288 |
+
|
289 |
+
trainer.extend(
|
290 |
+
VaswaniRule(
|
291 |
+
"alpha",
|
292 |
+
d=args.adim,
|
293 |
+
warmup_steps=args.transformer_warmup_steps,
|
294 |
+
scale=args.transformer_lr,
|
295 |
+
),
|
296 |
+
trigger=(1, "iteration"),
|
297 |
+
)
|
298 |
+
# Resume from a snapshot
|
299 |
+
if args.resume:
|
300 |
+
chainer.serializers.load_npz(args.resume, trainer)
|
301 |
+
|
302 |
+
# set up validation iterator
|
303 |
+
valid = make_batchset(
|
304 |
+
valid_json,
|
305 |
+
args.batch_size,
|
306 |
+
args.maxlen_in,
|
307 |
+
args.maxlen_out,
|
308 |
+
args.minibatches,
|
309 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
310 |
+
count=args.batch_count,
|
311 |
+
batch_bins=args.batch_bins,
|
312 |
+
batch_frames_in=args.batch_frames_in,
|
313 |
+
batch_frames_out=args.batch_frames_out,
|
314 |
+
batch_frames_inout=args.batch_frames_inout,
|
315 |
+
iaxis=0,
|
316 |
+
oaxis=0,
|
317 |
+
)
|
318 |
+
|
319 |
+
if args.n_iter_processes > 0:
|
320 |
+
valid_iter = chainer.iterators.MultiprocessIterator(
|
321 |
+
TransformDataset(valid, load_cv),
|
322 |
+
batch_size=1,
|
323 |
+
repeat=False,
|
324 |
+
shuffle=False,
|
325 |
+
n_processes=args.n_iter_processes,
|
326 |
+
n_prefetch=8,
|
327 |
+
maxtasksperchild=20,
|
328 |
+
)
|
329 |
+
else:
|
330 |
+
valid_iter = chainer.iterators.SerialIterator(
|
331 |
+
TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False
|
332 |
+
)
|
333 |
+
|
334 |
+
# Evaluate the model with the test dataset for each epoch
|
335 |
+
trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))
|
336 |
+
|
337 |
+
# Save attention weight each epoch
|
338 |
+
if args.num_save_attention > 0 and args.mtlalpha != 1.0:
|
339 |
+
data = sorted(
|
340 |
+
list(valid_json.items())[: args.num_save_attention],
|
341 |
+
key=lambda x: int(x[1]["input"][0]["shape"][1]),
|
342 |
+
reverse=True,
|
343 |
+
)
|
344 |
+
if hasattr(model, "module"):
|
345 |
+
att_vis_fn = model.module.calculate_all_attentions
|
346 |
+
plot_class = model.module.attention_plot_class
|
347 |
+
else:
|
348 |
+
att_vis_fn = model.calculate_all_attentions
|
349 |
+
plot_class = model.attention_plot_class
|
350 |
+
logging.info("Using custom PlotAttentionReport")
|
351 |
+
att_reporter = plot_class(
|
352 |
+
att_vis_fn,
|
353 |
+
data,
|
354 |
+
args.outdir + "/att_ws",
|
355 |
+
converter=converter,
|
356 |
+
transform=load_cv,
|
357 |
+
device=gpu_id,
|
358 |
+
subsampling_factor=total_subsampling_factor,
|
359 |
+
)
|
360 |
+
trainer.extend(att_reporter, trigger=(1, "epoch"))
|
361 |
+
else:
|
362 |
+
att_reporter = None
|
363 |
+
|
364 |
+
# Take a snapshot for each specified epoch
|
365 |
+
trainer.extend(
|
366 |
+
extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"),
|
367 |
+
trigger=(1, "epoch"),
|
368 |
+
)
|
369 |
+
|
370 |
+
# Make a plot for training and validation values
|
371 |
+
trainer.extend(
|
372 |
+
extensions.PlotReport(
|
373 |
+
[
|
374 |
+
"main/loss",
|
375 |
+
"validation/main/loss",
|
376 |
+
"main/loss_ctc",
|
377 |
+
"validation/main/loss_ctc",
|
378 |
+
"main/loss_att",
|
379 |
+
"validation/main/loss_att",
|
380 |
+
],
|
381 |
+
"epoch",
|
382 |
+
file_name="loss.png",
|
383 |
+
)
|
384 |
+
)
|
385 |
+
trainer.extend(
|
386 |
+
extensions.PlotReport(
|
387 |
+
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
|
388 |
+
)
|
389 |
+
)
|
390 |
+
|
391 |
+
# Save best models
|
392 |
+
trainer.extend(
|
393 |
+
extensions.snapshot_object(model, "model.loss.best"),
|
394 |
+
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
|
395 |
+
)
|
396 |
+
if mtl_mode != "ctc":
|
397 |
+
trainer.extend(
|
398 |
+
extensions.snapshot_object(model, "model.acc.best"),
|
399 |
+
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
|
400 |
+
)
|
401 |
+
|
402 |
+
# epsilon decay in the optimizer
|
403 |
+
if args.opt == "adadelta":
|
404 |
+
if args.criterion == "acc" and mtl_mode != "ctc":
|
405 |
+
trainer.extend(
|
406 |
+
restore_snapshot(model, args.outdir + "/model.acc.best"),
|
407 |
+
trigger=CompareValueTrigger(
|
408 |
+
"validation/main/acc",
|
409 |
+
lambda best_value, current_value: best_value > current_value,
|
410 |
+
),
|
411 |
+
)
|
412 |
+
trainer.extend(
|
413 |
+
adadelta_eps_decay(args.eps_decay),
|
414 |
+
trigger=CompareValueTrigger(
|
415 |
+
"validation/main/acc",
|
416 |
+
lambda best_value, current_value: best_value > current_value,
|
417 |
+
),
|
418 |
+
)
|
419 |
+
elif args.criterion == "loss":
|
420 |
+
trainer.extend(
|
421 |
+
restore_snapshot(model, args.outdir + "/model.loss.best"),
|
422 |
+
trigger=CompareValueTrigger(
|
423 |
+
"validation/main/loss",
|
424 |
+
lambda best_value, current_value: best_value < current_value,
|
425 |
+
),
|
426 |
+
)
|
427 |
+
trainer.extend(
|
428 |
+
adadelta_eps_decay(args.eps_decay),
|
429 |
+
trigger=CompareValueTrigger(
|
430 |
+
"validation/main/loss",
|
431 |
+
lambda best_value, current_value: best_value < current_value,
|
432 |
+
),
|
433 |
+
)
|
434 |
+
|
435 |
+
# Write a log of evaluation statistics for each epoch
|
436 |
+
trainer.extend(
|
437 |
+
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
|
438 |
+
)
|
439 |
+
report_keys = [
|
440 |
+
"epoch",
|
441 |
+
"iteration",
|
442 |
+
"main/loss",
|
443 |
+
"main/loss_ctc",
|
444 |
+
"main/loss_att",
|
445 |
+
"validation/main/loss",
|
446 |
+
"validation/main/loss_ctc",
|
447 |
+
"validation/main/loss_att",
|
448 |
+
"main/acc",
|
449 |
+
"validation/main/acc",
|
450 |
+
"elapsed_time",
|
451 |
+
]
|
452 |
+
if args.opt == "adadelta":
|
453 |
+
trainer.extend(
|
454 |
+
extensions.observe_value(
|
455 |
+
"eps", lambda trainer: trainer.updater.get_optimizer("main").eps
|
456 |
+
),
|
457 |
+
trigger=(args.report_interval_iters, "iteration"),
|
458 |
+
)
|
459 |
+
report_keys.append("eps")
|
460 |
+
trainer.extend(
|
461 |
+
extensions.PrintReport(report_keys),
|
462 |
+
trigger=(args.report_interval_iters, "iteration"),
|
463 |
+
)
|
464 |
+
|
465 |
+
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
|
466 |
+
|
467 |
+
set_early_stop(trainer, args)
|
468 |
+
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
|
469 |
+
writer = SummaryWriter(args.tensorboard_dir)
|
470 |
+
trainer.extend(
|
471 |
+
TensorboardLogger(writer, att_reporter),
|
472 |
+
trigger=(args.report_interval_iters, "iteration"),
|
473 |
+
)
|
474 |
+
|
475 |
+
# Run the training
|
476 |
+
trainer.run()
|
477 |
+
check_early_stop(trainer, args.epochs)
|
478 |
+
|
479 |
+
|
480 |
+
def recog(args):
|
481 |
+
"""Decode with the given args.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
args (namespace): The program arguments.
|
485 |
+
|
486 |
+
"""
|
487 |
+
# display chainer version
|
488 |
+
logging.info("chainer version = " + chainer.__version__)
|
489 |
+
|
490 |
+
set_deterministic_chainer(args)
|
491 |
+
|
492 |
+
# read training config
|
493 |
+
idim, odim, train_args = get_model_conf(args.model, args.model_conf)
|
494 |
+
|
495 |
+
for key in sorted(vars(args).keys()):
|
496 |
+
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
|
497 |
+
|
498 |
+
# specify model architecture
|
499 |
+
logging.info("reading model parameters from " + args.model)
|
500 |
+
# To be compatible with v.0.3.0 models
|
501 |
+
if hasattr(train_args, "model_module"):
|
502 |
+
model_module = train_args.model_module
|
503 |
+
else:
|
504 |
+
model_module = "espnet.nets.chainer_backend.e2e_asr:E2E"
|
505 |
+
model_class = dynamic_import(model_module)
|
506 |
+
model = model_class(idim, odim, train_args)
|
507 |
+
assert isinstance(model, ASRInterface)
|
508 |
+
chainer_load(args.model, model)
|
509 |
+
|
510 |
+
# read rnnlm
|
511 |
+
if args.rnnlm:
|
512 |
+
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
|
513 |
+
rnnlm = lm_chainer.ClassifierWithState(
|
514 |
+
lm_chainer.RNNLM(
|
515 |
+
len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit
|
516 |
+
)
|
517 |
+
)
|
518 |
+
chainer_load(args.rnnlm, rnnlm)
|
519 |
+
else:
|
520 |
+
rnnlm = None
|
521 |
+
|
522 |
+
if args.word_rnnlm:
|
523 |
+
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
|
524 |
+
word_dict = rnnlm_args.char_list_dict
|
525 |
+
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
|
526 |
+
word_rnnlm = lm_chainer.ClassifierWithState(
|
527 |
+
lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
|
528 |
+
)
|
529 |
+
chainer_load(args.word_rnnlm, word_rnnlm)
|
530 |
+
|
531 |
+
if rnnlm is not None:
|
532 |
+
rnnlm = lm_chainer.ClassifierWithState(
|
533 |
+
extlm_chainer.MultiLevelLM(
|
534 |
+
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
|
535 |
+
)
|
536 |
+
)
|
537 |
+
else:
|
538 |
+
rnnlm = lm_chainer.ClassifierWithState(
|
539 |
+
extlm_chainer.LookAheadWordLM(
|
540 |
+
word_rnnlm.predictor, word_dict, char_dict
|
541 |
+
)
|
542 |
+
)
|
543 |
+
|
544 |
+
# read json data
|
545 |
+
with open(args.recog_json, "rb") as f:
|
546 |
+
js = json.load(f)["utts"]
|
547 |
+
|
548 |
+
load_inputs_and_targets = LoadInputsAndTargets(
|
549 |
+
mode="asr",
|
550 |
+
load_output=False,
|
551 |
+
sort_in_input_length=False,
|
552 |
+
preprocess_conf=train_args.preprocess_conf
|
553 |
+
if args.preprocess_conf is None
|
554 |
+
else args.preprocess_conf,
|
555 |
+
preprocess_args={"train": False}, # Switch the mode of preprocessing
|
556 |
+
)
|
557 |
+
|
558 |
+
# decode each utterance
|
559 |
+
new_js = {}
|
560 |
+
with chainer.no_backprop_mode():
|
561 |
+
for idx, name in enumerate(js.keys(), 1):
|
562 |
+
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
|
563 |
+
batch = [(name, js[name])]
|
564 |
+
feat = load_inputs_and_targets(batch)[0][0]
|
565 |
+
nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
|
566 |
+
new_js[name] = add_results_to_json(
|
567 |
+
js[name], nbest_hyps, train_args.char_list
|
568 |
+
)
|
569 |
+
|
570 |
+
with open(args.result_label, "wb") as f:
|
571 |
+
f.write(
|
572 |
+
json.dumps(
|
573 |
+
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
|
574 |
+
).encode("utf_8")
|
575 |
+
)
|
espnet/asr/pytorch_backend/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
"""Initialize sub package."""
|
espnet/asr/pytorch_backend/asr.py
ADDED
@@ -0,0 +1,1500 @@
|
1 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
"""Training/decoding definition for the speech recognition task."""
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
|
13 |
+
from chainer import reporter as reporter_module
|
14 |
+
from chainer import training
|
15 |
+
from chainer.training import extensions
|
16 |
+
from chainer.training.updater import StandardUpdater
|
17 |
+
import numpy as np
|
18 |
+
from tensorboardX import SummaryWriter
|
19 |
+
import torch
|
20 |
+
from torch.nn.parallel import data_parallel
|
21 |
+
|
22 |
+
from espnet.asr.asr_utils import adadelta_eps_decay
|
23 |
+
from espnet.asr.asr_utils import add_results_to_json
|
24 |
+
from espnet.asr.asr_utils import CompareValueTrigger
|
25 |
+
from espnet.asr.asr_utils import format_mulenc_args
|
26 |
+
from espnet.asr.asr_utils import get_model_conf
|
27 |
+
from espnet.asr.asr_utils import plot_spectrogram
|
28 |
+
from espnet.asr.asr_utils import restore_snapshot
|
29 |
+
from espnet.asr.asr_utils import snapshot_object
|
30 |
+
from espnet.asr.asr_utils import torch_load
|
31 |
+
from espnet.asr.asr_utils import torch_resume
|
32 |
+
from espnet.asr.asr_utils import torch_snapshot
|
33 |
+
from espnet.asr.pytorch_backend.asr_init import freeze_modules
|
34 |
+
from espnet.asr.pytorch_backend.asr_init import load_trained_model
|
35 |
+
from espnet.asr.pytorch_backend.asr_init import load_trained_modules
|
36 |
+
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
|
37 |
+
from espnet.nets.asr_interface import ASRInterface
|
38 |
+
from espnet.nets.beam_search_transducer import BeamSearchTransducer
|
39 |
+
from espnet.nets.pytorch_backend.e2e_asr import pad_list
|
40 |
+
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
|
41 |
+
from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
|
42 |
+
from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
|
43 |
+
from espnet.transform.spectrogram import IStft
|
44 |
+
from espnet.transform.transformation import Transformation
|
45 |
+
from espnet.utils.cli_writers import file_writer_helper
|
46 |
+
from espnet.utils.dataset import ChainerDataLoader
|
47 |
+
from espnet.utils.dataset import TransformDataset
|
48 |
+
from espnet.utils.deterministic_utils import set_deterministic_pytorch
|
49 |
+
from espnet.utils.dynamic_import import dynamic_import
|
50 |
+
from espnet.utils.io_utils import LoadInputsAndTargets
|
51 |
+
from espnet.utils.training.batchfy import make_batchset
|
52 |
+
from espnet.utils.training.evaluator import BaseEvaluator
|
53 |
+
from espnet.utils.training.iterators import ShufflingEnabler
|
54 |
+
from espnet.utils.training.tensorboard_logger import TensorboardLogger
|
55 |
+
from espnet.utils.training.train_utils import check_early_stop
|
56 |
+
from espnet.utils.training.train_utils import set_early_stop
|
57 |
+
|
58 |
+
import matplotlib
|
59 |
+
|
60 |
+
matplotlib.use("Agg")
|
61 |
+
|
62 |
+
if sys.version_info[0] == 2:
|
63 |
+
from itertools import izip_longest as zip_longest
|
64 |
+
else:
|
65 |
+
from itertools import zip_longest as zip_longest
|
66 |
+
|
67 |
+
|
68 |
+
def _recursive_to(xs, device):
|
69 |
+
if torch.is_tensor(xs):
|
70 |
+
return xs.to(device)
|
71 |
+
if isinstance(xs, tuple):
|
72 |
+
return tuple(_recursive_to(x, device) for x in xs)
|
73 |
+
return xs
|
74 |
+
|
75 |
+
|
76 |
+
class CustomEvaluator(BaseEvaluator):
|
77 |
+
"""Custom Evaluator for Pytorch.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
model (torch.nn.Module): The model to evaluate.
|
81 |
+
iterator (chainer.dataset.Iterator): The train iterator.
|
82 |
+
|
83 |
+
target (link | dict[str, link]): Link object or a dictionary of
|
84 |
+
links to evaluate. If this is just a link object, the link is
|
85 |
+
registered by the name ``'main'``.
|
86 |
+
|
87 |
+
device (torch.device): The device used.
|
88 |
+
ngpu (int): The number of GPUs.
|
89 |
+
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, model, iterator, target, device, ngpu=None):
|
93 |
+
super(CustomEvaluator, self).__init__(iterator, target)
|
94 |
+
self.model = model
|
95 |
+
self.device = device
|
96 |
+
if ngpu is not None:
|
97 |
+
self.ngpu = ngpu
|
98 |
+
elif device.type == "cpu":
|
99 |
+
self.ngpu = 0
|
100 |
+
else:
|
101 |
+
self.ngpu = 1
|
102 |
+
|
103 |
+
# The core part of the update routine can be customized by overriding
|
104 |
+
def evaluate(self):
|
105 |
+
"""Main evaluate routine for CustomEvaluator."""
|
106 |
+
iterator = self._iterators["main"]
|
107 |
+
|
108 |
+
if self.eval_hook:
|
109 |
+
self.eval_hook(self)
|
110 |
+
|
111 |
+
if hasattr(iterator, "reset"):
|
112 |
+
iterator.reset()
|
113 |
+
it = iterator
|
114 |
+
else:
|
115 |
+
it = copy.copy(iterator)
|
116 |
+
|
117 |
+
summary = reporter_module.DictSummary()
|
118 |
+
|
119 |
+
self.model.eval()
|
120 |
+
with torch.no_grad():
|
121 |
+
for batch in it:
|
122 |
+
x = _recursive_to(batch, self.device)
|
123 |
+
observation = {}
|
124 |
+
with reporter_module.report_scope(observation):
|
125 |
+
# read scp files
|
126 |
+
# x: original json with loaded features
|
127 |
+
# will be converted to chainer variable later
|
128 |
+
if self.ngpu == 0:
|
129 |
+
self.model(*x)
|
130 |
+
else:
|
131 |
+
# apex does not support torch.nn.DataParallel
|
132 |
+
data_parallel(self.model, x, range(self.ngpu))
|
133 |
+
|
134 |
+
summary.add(observation)
|
135 |
+
self.model.train()
|
136 |
+
|
137 |
+
return summary.compute_mean()
|
138 |
+
|
139 |
+
|
140 |
+
class CustomUpdater(StandardUpdater):
|
141 |
+
"""Custom Updater for Pytorch.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
model (torch.nn.Module): The model to update.
|
145 |
+
grad_clip_threshold (float): The gradient clipping value to use.
|
146 |
+
train_iter (chainer.dataset.Iterator): The training iterator.
|
147 |
+
optimizer (torch.optim.optimizer): The training optimizer.
|
148 |
+
|
149 |
+
device (torch.device): The device to use.
|
150 |
+
ngpu (int): The number of gpus to use.
|
151 |
+
use_apex (bool): The flag to use Apex in backprop.
|
152 |
+
|
153 |
+
"""
|
154 |
+
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
model,
|
158 |
+
grad_clip_threshold,
|
159 |
+
train_iter,
|
160 |
+
optimizer,
|
161 |
+
device,
|
162 |
+
ngpu,
|
163 |
+
grad_noise=False,
|
164 |
+
accum_grad=1,
|
165 |
+
use_apex=False,
|
166 |
+
):
|
167 |
+
super(CustomUpdater, self).__init__(train_iter, optimizer)
|
168 |
+
self.model = model
|
169 |
+
self.grad_clip_threshold = grad_clip_threshold
|
170 |
+
self.device = device
|
171 |
+
self.ngpu = ngpu
|
172 |
+
self.accum_grad = accum_grad
|
173 |
+
self.forward_count = 0
|
174 |
+
self.grad_noise = grad_noise
|
175 |
+
self.iteration = 0
|
176 |
+
self.use_apex = use_apex
|
177 |
+
|
178 |
+
# The core part of the update routine can be customized by overriding.
|
179 |
+
def update_core(self):
|
180 |
+
"""Main update routine of the CustomUpdater."""
|
181 |
+
# When we pass one iterator and optimizer to StandardUpdater.__init__,
|
182 |
+
# they are automatically named 'main'.
|
183 |
+
train_iter = self.get_iterator("main")
|
184 |
+
optimizer = self.get_optimizer("main")
|
185 |
+
epoch = train_iter.epoch
|
186 |
+
|
187 |
+
# Get the next batch (a list of json files)
|
188 |
+
batch = train_iter.next()
|
189 |
+
# self.iteration += 1 # Increase may result in early report,
|
190 |
+
# which is done in other place automatically.
|
191 |
+
x = _recursive_to(batch, self.device)
|
192 |
+
is_new_epoch = train_iter.epoch != epoch
|
193 |
+
# When the last minibatch in the current epoch is given,
|
194 |
+
# gradient accumulation is turned off in order to evaluate the model
|
195 |
+
# on the validation set in every epoch.
|
196 |
+
# see details in https://github.com/espnet/espnet/pull/1388
|
197 |
+
|
198 |
+
# Compute the loss at this time step and accumulate it
|
199 |
+
if self.ngpu == 0:
|
200 |
+
loss = self.model(*x).mean() / self.accum_grad
|
201 |
+
else:
|
202 |
+
# apex does not support torch.nn.DataParallel
|
203 |
+
loss = (
|
204 |
+
data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
|
205 |
+
)
|
206 |
+
if self.use_apex:
|
207 |
+
from apex import amp
|
208 |
+
|
209 |
+
# NOTE: for a compatibility with noam optimizer
|
210 |
+
opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
|
211 |
+
with amp.scale_loss(loss, opt) as scaled_loss:
|
212 |
+
scaled_loss.backward()
|
213 |
+
else:
|
214 |
+
loss.backward()
|
215 |
+
# gradient noise injection
|
216 |
+
if self.grad_noise:
|
217 |
+
from espnet.asr.asr_utils import add_gradient_noise
|
218 |
+
|
219 |
+
add_gradient_noise(
|
220 |
+
self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55
|
221 |
+
)
|
222 |
+
|
223 |
+
# update parameters
|
224 |
+
self.forward_count += 1
|
225 |
+
if not is_new_epoch and self.forward_count != self.accum_grad:
|
226 |
+
return
|
227 |
+
self.forward_count = 0
|
228 |
+
# compute the gradient norm to check if it is normal or not
|
229 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(
|
230 |
+
self.model.parameters(), self.grad_clip_threshold
|
231 |
+
)
|
232 |
+
logging.info("grad norm={}".format(grad_norm))
|
233 |
+
if math.isnan(grad_norm):
|
234 |
+
logging.warning("grad norm is nan. Do not update model.")
|
235 |
+
else:
|
236 |
+
optimizer.step()
|
237 |
+
optimizer.zero_grad()
|
238 |
+
|
239 |
+
def update(self):
|
240 |
+
self.update_core()
|
241 |
+
# #iterations with accum_grad > 1
|
242 |
+
# Ref.: https://github.com/espnet/espnet/issues/777
|
243 |
+
if self.forward_count == 0:
|
244 |
+
self.iteration += 1
|
245 |
+
|
246 |
+
|
247 |
+
class CustomConverter(object):
|
248 |
+
"""Custom batch converter for Pytorch.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
subsampling_factor (int): The subsampling factor.
|
252 |
+
dtype (torch.dtype): Data type to convert.
|
253 |
+
|
254 |
+
"""
|
255 |
+
|
256 |
+
def __init__(self, subsampling_factor=1, dtype=torch.float32):
|
257 |
+
"""Construct a CustomConverter object."""
|
258 |
+
self.subsampling_factor = subsampling_factor
|
259 |
+
self.ignore_id = -1
|
260 |
+
self.dtype = dtype
|
261 |
+
|
262 |
+
def __call__(self, batch, device=torch.device("cpu")):
|
263 |
+
"""Transform a batch and send it to a device.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
batch (list): The batch to transform.
|
267 |
+
device (torch.device): The device to send to.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
|
271 |
+
|
272 |
+
"""
|
273 |
+
# batch should be located in list
|
274 |
+
assert len(batch) == 1
|
275 |
+
xs, ys = batch[0]
|
276 |
+
|
277 |
+
# perform subsampling
|
278 |
+
if self.subsampling_factor > 1:
|
279 |
+
xs = [x[:: self.subsampling_factor, :] for x in xs]
|
280 |
+
|
281 |
+
# get batch of lengths of input sequences
|
282 |
+
ilens = np.array([x.shape[0] for x in xs])
|
283 |
+
|
284 |
+
# perform padding and convert to tensor
|
285 |
+
# currently only support real number
|
286 |
+
if xs[0].dtype.kind == "c":
|
287 |
+
xs_pad_real = pad_list(
|
288 |
+
[torch.from_numpy(x.real).float() for x in xs], 0
|
289 |
+
).to(device, dtype=self.dtype)
|
290 |
+
xs_pad_imag = pad_list(
|
291 |
+
[torch.from_numpy(x.imag).float() for x in xs], 0
|
292 |
+
).to(device, dtype=self.dtype)
|
293 |
+
# Note(kamo):
|
294 |
+
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
|
295 |
+
# Don't create ComplexTensor and give it E2E here
|
296 |
+
# because torch.nn.DataParallel can't handle it.
|
297 |
+
xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
|
298 |
+
else:
|
299 |
+
xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
|
300 |
+
device, dtype=self.dtype
|
301 |
+
)
|
302 |
+
|
303 |
+
ilens = torch.from_numpy(ilens).to(device)
|
304 |
+
# NOTE: this is for multi-output (e.g., speech translation)
|
305 |
+
ys_pad = pad_list(
|
306 |
+
[
|
307 |
+
torch.from_numpy(
|
308 |
+
np.array(y[0][:]) if isinstance(y, tuple) else y
|
309 |
+
).long()
|
310 |
+
for y in ys
|
311 |
+
],
|
312 |
+
self.ignore_id,
|
313 |
+
).to(device)
|
314 |
+
|
315 |
+
return xs_pad, ilens, ys_pad
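Illustrative sketch (not part of the committed file): the converter applied to a toy two-utterance batch (shapes invented), showing the zero-padded features, the length tensor, and the targets padded with ignore_id=-1:

import numpy as np
import torch

from espnet.asr.pytorch_backend.asr import CustomConverter

xs = [np.random.randn(7, 5).astype(np.float32),   # 7 frames of 5-dim features
      np.random.randn(4, 5).astype(np.float32)]   # 4 frames
ys = [np.array([1, 2, 3]), np.array([4, 5])]      # token id targets

converter = CustomConverter(subsampling_factor=1, dtype=torch.float32)
xs_pad, ilens, ys_pad = converter([(xs, ys)], device=torch.device("cpu"))
print(xs_pad.shape)  # torch.Size([2, 7, 5]) - zero padded to the longest utterance
print(ilens)         # tensor([7, 4])
print(ys_pad)        # second row padded with -1 (ignore_id)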
|
316 |
+
|
317 |
+
|
318 |
+
class CustomConverterMulEnc(object):
|
319 |
+
"""Custom batch converter for Pytorch in multi-encoder case.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
subsampling_factors (list): List of subsampling factors for each encoder.
|
323 |
+
dtype (torch.dtype): Data type to convert.
|
324 |
+
|
325 |
+
"""
|
326 |
+
|
327 |
+
def __init__(self, subsamping_factors=[1, 1], dtype=torch.float32):
|
328 |
+
"""Initialize the converter."""
|
329 |
+
self.subsamping_factors = subsamping_factors
|
330 |
+
self.ignore_id = -1
|
331 |
+
self.dtype = dtype
|
332 |
+
self.num_encs = len(subsamping_factors)
|
333 |
+
|
334 |
+
def __call__(self, batch, device=torch.device("cpu")):
|
335 |
+
"""Transform a batch and send it to a device.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
batch (list): The batch to transform.
|
339 |
+
device (torch.device): The device to send to.
|
340 |
+
|
341 |
+
Returns:
|
342 |
+
tuple( list(torch.Tensor), list(torch.Tensor), torch.Tensor)
|
343 |
+
|
344 |
+
"""
|
345 |
+
# batch should be located in list
|
346 |
+
assert len(batch) == 1
|
347 |
+
xs_list = batch[0][: self.num_encs]
|
348 |
+
ys = batch[0][-1]
|
349 |
+
|
350 |
+
# perform subsampling
|
351 |
+
if np.sum(self.subsampling_factors) > self.num_encs:
|
352 |
+
xs_list = [
|
353 |
+
[x[:: self.subsampling_factors[i], :] for x in xs_list[i]]
|
354 |
+
for i in range(self.num_encs)
|
355 |
+
]
|
356 |
+
|
357 |
+
# get batch of lengths of input sequences
|
358 |
+
ilens_list = [
|
359 |
+
np.array([x.shape[0] for x in xs_list[i]]) for i in range(self.num_encs)
|
360 |
+
]
|
361 |
+
|
362 |
+
# perform padding and convert to tensor
|
363 |
+
# currently only real numbers are supported
|
364 |
+
xs_list_pad = [
|
365 |
+
pad_list([torch.from_numpy(x).float() for x in xs_list[i]], 0).to(
|
366 |
+
device, dtype=self.dtype
|
367 |
+
)
|
368 |
+
for i in range(self.num_encs)
|
369 |
+
]
|
370 |
+
|
371 |
+
ilens_list = [
|
372 |
+
torch.from_numpy(ilens_list[i]).to(device) for i in range(self.num_encs)
|
373 |
+
]
|
374 |
+
# NOTE: this is for multi-task learning (e.g., speech translation)
|
375 |
+
ys_pad = pad_list(
|
376 |
+
[
|
377 |
+
torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long()
|
378 |
+
for y in ys
|
379 |
+
],
|
380 |
+
self.ignore_id,
|
381 |
+
).to(device)
|
382 |
+
|
383 |
+
return xs_list_pad, ilens_list, ys_pad
|
384 |
+
|
385 |
+
|
386 |
+
def train(args):
|
387 |
+
"""Train with the given args.
|
388 |
+
|
389 |
+
Args:
|
390 |
+
args (namespace): The program arguments.
|
391 |
+
|
392 |
+
"""
|
393 |
+
set_deterministic_pytorch(args)
|
394 |
+
if args.num_encs > 1:
|
395 |
+
args = format_mulenc_args(args)
|
396 |
+
|
397 |
+
# check cuda availability
|
398 |
+
if not torch.cuda.is_available():
|
399 |
+
logging.warning("cuda is not available")
|
400 |
+
|
401 |
+
# get input and output dimension info
|
402 |
+
with open(args.valid_json, "rb") as f:
|
403 |
+
valid_json = json.load(f)["utts"]
|
404 |
+
utts = list(valid_json.keys())
|
405 |
+
idim_list = [
|
406 |
+
int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
|
407 |
+
]
|
408 |
+
odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
|
409 |
+
for i in range(args.num_encs):
|
410 |
+
logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
|
411 |
+
logging.info("#output dims: " + str(odim))
|
412 |
+
|
413 |
+
# specify attention, CTC, hybrid mode
|
414 |
+
if "transducer" in args.model_module:
|
415 |
+
if (
|
416 |
+
getattr(args, "etype", False) == "custom"
|
417 |
+
or getattr(args, "dtype", False) == "custom"
|
418 |
+
):
|
419 |
+
mtl_mode = "custom_transducer"
|
420 |
+
else:
|
421 |
+
mtl_mode = "transducer"
|
422 |
+
logging.info("Pure transducer mode")
|
423 |
+
elif args.mtlalpha == 1.0:
|
424 |
+
mtl_mode = "ctc"
|
425 |
+
logging.info("Pure CTC mode")
|
426 |
+
elif args.mtlalpha == 0.0:
|
427 |
+
mtl_mode = "att"
|
428 |
+
logging.info("Pure attention mode")
|
429 |
+
else:
|
430 |
+
mtl_mode = "mtl"
|
431 |
+
logging.info("Multitask learning mode")
|
432 |
+
|
433 |
+
if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
|
434 |
+
model = load_trained_modules(idim_list[0], odim, args)
|
435 |
+
else:
|
436 |
+
model_class = dynamic_import(args.model_module)
|
437 |
+
model = model_class(
|
438 |
+
idim_list[0] if args.num_encs == 1 else idim_list, odim, args
|
439 |
+
)
|
440 |
+
assert isinstance(model, ASRInterface)
|
441 |
+
total_subsampling_factor = model.get_total_subsampling_factor()
|
442 |
+
|
443 |
+
logging.info(
|
444 |
+
" Total parameter of the model = "
|
445 |
+
+ str(sum(p.numel() for p in model.parameters()))
|
446 |
+
)
|
447 |
+
|
448 |
+
if args.rnnlm is not None:
|
449 |
+
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
|
450 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
451 |
+
lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
|
452 |
+
)
|
453 |
+
torch_load(args.rnnlm, rnnlm)
|
454 |
+
model.rnnlm = rnnlm
|
455 |
+
|
456 |
+
# write model config
|
457 |
+
if not os.path.exists(args.outdir):
|
458 |
+
os.makedirs(args.outdir)
|
459 |
+
model_conf = args.outdir + "/model.json"
|
460 |
+
with open(model_conf, "wb") as f:
|
461 |
+
logging.info("writing a model config file to " + model_conf)
|
462 |
+
f.write(
|
463 |
+
json.dumps(
|
464 |
+
(idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
|
465 |
+
indent=4,
|
466 |
+
ensure_ascii=False,
|
467 |
+
sort_keys=True,
|
468 |
+
).encode("utf_8")
|
469 |
+
)
|
470 |
+
for key in sorted(vars(args).keys()):
|
471 |
+
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
|
472 |
+
|
473 |
+
reporter = model.reporter
|
474 |
+
|
475 |
+
# check the use of multi-gpu
|
476 |
+
if args.ngpu > 1:
|
477 |
+
if args.batch_size != 0:
|
478 |
+
logging.warning(
|
479 |
+
"batch size is automatically increased (%d -> %d)"
|
480 |
+
% (args.batch_size, args.batch_size * args.ngpu)
|
481 |
+
)
|
482 |
+
args.batch_size *= args.ngpu
|
483 |
+
if args.num_encs > 1:
|
484 |
+
# TODO(ruizhili): implement data parallel for multi-encoder setup.
|
485 |
+
raise NotImplementedError(
|
486 |
+
"Data parallel is not supported for multi-encoder setup."
|
487 |
+
)
|
488 |
+
|
489 |
+
# set torch device
|
490 |
+
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
|
491 |
+
if args.train_dtype in ("float16", "float32", "float64"):
|
492 |
+
dtype = getattr(torch, args.train_dtype)
|
493 |
+
else:
|
494 |
+
dtype = torch.float32
|
495 |
+
model = model.to(device=device, dtype=dtype)
|
496 |
+
|
497 |
+
if args.freeze_mods:
|
498 |
+
model, model_params = freeze_modules(model, args.freeze_mods)
|
499 |
+
else:
|
500 |
+
model_params = model.parameters()
|
501 |
+
|
502 |
+
logging.warning(
|
503 |
+
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
|
504 |
+
sum(p.numel() for p in model.parameters()),
|
505 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
506 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad)
|
507 |
+
* 100.0
|
508 |
+
/ sum(p.numel() for p in model.parameters()),
|
509 |
+
)
|
510 |
+
)
|
511 |
+
|
512 |
+
# Setup an optimizer
|
513 |
+
if args.opt == "adadelta":
|
514 |
+
optimizer = torch.optim.Adadelta(
|
515 |
+
model_params, rho=0.95, eps=args.eps, weight_decay=args.weight_decay
|
516 |
+
)
|
517 |
+
elif args.opt == "adam":
|
518 |
+
optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay)
|
519 |
+
elif args.opt == "noam":
|
520 |
+
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
|
521 |
+
|
522 |
+
# For transformer-transducer, adim declaration is within the block definition.
|
523 |
+
# Thus, we need to retrieve the most dominant value (d_hidden) for the Noam scheduler.
|
524 |
+
if hasattr(args, "enc_block_arch") or hasattr(args, "dec_block_arch"):
|
525 |
+
adim = model.most_dom_dim
|
526 |
+
else:
|
527 |
+
adim = args.adim
|
528 |
+
|
529 |
+
optimizer = get_std_opt(
|
530 |
+
model_params, adim, args.transformer_warmup_steps, args.transformer_lr
|
531 |
+
)
|
532 |
+
else:
|
533 |
+
raise NotImplementedError("unknown optimizer: " + args.opt)
|
534 |
+
|
535 |
+
# setup apex.amp
|
536 |
+
if args.train_dtype in ("O0", "O1", "O2", "O3"):
|
537 |
+
try:
|
538 |
+
from apex import amp
|
539 |
+
except ImportError as e:
|
540 |
+
logging.error(
|
541 |
+
f"You need to install apex for --train-dtype {args.train_dtype}. "
|
542 |
+
"See https://github.com/NVIDIA/apex#linux"
|
543 |
+
)
|
544 |
+
raise e
|
545 |
+
if args.opt == "noam":
|
546 |
+
model, optimizer.optimizer = amp.initialize(
|
547 |
+
model, optimizer.optimizer, opt_level=args.train_dtype
|
548 |
+
)
|
549 |
+
else:
|
550 |
+
model, optimizer = amp.initialize(
|
551 |
+
model, optimizer, opt_level=args.train_dtype
|
552 |
+
)
|
553 |
+
use_apex = True
|
554 |
+
|
555 |
+
from espnet.nets.pytorch_backend.ctc import CTC
|
556 |
+
|
557 |
+
amp.register_float_function(CTC, "loss_fn")
|
558 |
+
amp.init()
|
559 |
+
logging.warning("register ctc as float function")
|
560 |
+
else:
|
561 |
+
use_apex = False
|
562 |
+
|
563 |
+
# FIXME: TOO DIRTY HACK
|
564 |
+
setattr(optimizer, "target", reporter)
|
565 |
+
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
|
566 |
+
|
567 |
+
# Setup a converter
|
568 |
+
if args.num_encs == 1:
|
569 |
+
converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
|
570 |
+
else:
|
571 |
+
converter = CustomConverterMulEnc(
|
572 |
+
[i[0] for i in model.subsample_list], dtype=dtype
|
573 |
+
)
|
574 |
+
|
575 |
+
# read json data
|
576 |
+
with open(args.train_json, "rb") as f:
|
577 |
+
train_json = json.load(f)["utts"]
|
578 |
+
with open(args.valid_json, "rb") as f:
|
579 |
+
valid_json = json.load(f)["utts"]
|
580 |
+
|
581 |
+
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
|
582 |
+
# make minibatch list (variable length)
|
583 |
+
train = make_batchset(
|
584 |
+
train_json,
|
585 |
+
args.batch_size,
|
586 |
+
args.maxlen_in,
|
587 |
+
args.maxlen_out,
|
588 |
+
args.minibatches,
|
589 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
590 |
+
shortest_first=use_sortagrad,
|
591 |
+
count=args.batch_count,
|
592 |
+
batch_bins=args.batch_bins,
|
593 |
+
batch_frames_in=args.batch_frames_in,
|
594 |
+
batch_frames_out=args.batch_frames_out,
|
595 |
+
batch_frames_inout=args.batch_frames_inout,
|
596 |
+
iaxis=0,
|
597 |
+
oaxis=0,
|
598 |
+
)
|
599 |
+
valid = make_batchset(
|
600 |
+
valid_json,
|
601 |
+
args.batch_size,
|
602 |
+
args.maxlen_in,
|
603 |
+
args.maxlen_out,
|
604 |
+
args.minibatches,
|
605 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
606 |
+
count=args.batch_count,
|
607 |
+
batch_bins=args.batch_bins,
|
608 |
+
batch_frames_in=args.batch_frames_in,
|
609 |
+
batch_frames_out=args.batch_frames_out,
|
610 |
+
batch_frames_inout=args.batch_frames_inout,
|
611 |
+
iaxis=0,
|
612 |
+
oaxis=0,
|
613 |
+
)
|
614 |
+
|
615 |
+
load_tr = LoadInputsAndTargets(
|
616 |
+
mode="asr",
|
617 |
+
load_output=True,
|
618 |
+
preprocess_conf=args.preprocess_conf,
|
619 |
+
preprocess_args={"train": True}, # Switch the mode of preprocessing
|
620 |
+
)
|
621 |
+
load_cv = LoadInputsAndTargets(
|
622 |
+
mode="asr",
|
623 |
+
load_output=True,
|
624 |
+
preprocess_conf=args.preprocess_conf,
|
625 |
+
preprocess_args={"train": False}, # Switch the mode of preprocessing
|
626 |
+
)
|
627 |
+
# hack to make batchsize argument as 1
|
628 |
+
# actual bathsize is included in a list
|
629 |
+
# default collate function converts numpy array to pytorch tensor
|
630 |
+
# we used an empty collate function instead which returns list
|
631 |
+
train_iter = ChainerDataLoader(
|
632 |
+
dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
|
633 |
+
batch_size=1,
|
634 |
+
num_workers=args.n_iter_processes,
|
635 |
+
shuffle=not use_sortagrad,
|
636 |
+
collate_fn=lambda x: x[0],
|
637 |
+
)
|
638 |
+
valid_iter = ChainerDataLoader(
|
639 |
+
dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
|
640 |
+
batch_size=1,
|
641 |
+
shuffle=False,
|
642 |
+
collate_fn=lambda x: x[0],
|
643 |
+
num_workers=args.n_iter_processes,
|
644 |
+
)
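# --- Illustrative sketch (assumption: shown with a plain torch DataLoader for
# comparison; the code above wraps the same idea in ChainerDataLoader) ---
# Every dataset item is already a complete minibatch, so the loader runs with
# batch_size=1 and an identity-like collate_fn that simply unwraps the list.
from torch.utils.data import DataLoader

def _make_batch_loader(dataset, num_workers=0, shuffle=True):
    return DataLoader(
        dataset,
        batch_size=1,               # one pre-built minibatch per iteration
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=lambda x: x[0],  # unwrap [minibatch] -> minibatch
    )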
|
645 |
+
|
646 |
+
# Set up a trainer
|
647 |
+
updater = CustomUpdater(
|
648 |
+
model,
|
649 |
+
args.grad_clip,
|
650 |
+
{"main": train_iter},
|
651 |
+
optimizer,
|
652 |
+
device,
|
653 |
+
args.ngpu,
|
654 |
+
args.grad_noise,
|
655 |
+
args.accum_grad,
|
656 |
+
use_apex=use_apex,
|
657 |
+
)
|
658 |
+
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
|
659 |
+
|
660 |
+
if use_sortagrad:
|
661 |
+
trainer.extend(
|
662 |
+
ShufflingEnabler([train_iter]),
|
663 |
+
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
|
664 |
+
)
|
665 |
+
|
666 |
+
# Resume from a snapshot
|
667 |
+
if args.resume:
|
668 |
+
logging.info("resumed from %s" % args.resume)
|
669 |
+
torch_resume(args.resume, trainer)
|
670 |
+
|
671 |
+
# Evaluate the model with the test dataset for each epoch
|
672 |
+
if args.save_interval_iters > 0:
|
673 |
+
trainer.extend(
|
674 |
+
CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
|
675 |
+
trigger=(args.save_interval_iters, "iteration"),
|
676 |
+
)
|
677 |
+
else:
|
678 |
+
trainer.extend(
|
679 |
+
CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
|
680 |
+
)
|
681 |
+
|
682 |
+
# Save attention weight each epoch
|
683 |
+
is_attn_plot = (
|
684 |
+
"transformer" in args.model_module
|
685 |
+
or "conformer" in args.model_module
|
686 |
+
or mtl_mode in ["att", "mtl", "custom_transducer"]
|
687 |
+
)
|
688 |
+
|
689 |
+
if args.num_save_attention > 0 and is_attn_plot:
|
690 |
+
data = sorted(
|
691 |
+
list(valid_json.items())[: args.num_save_attention],
|
692 |
+
key=lambda x: int(x[1]["input"][0]["shape"][1]),
|
693 |
+
reverse=True,
|
694 |
+
)
|
695 |
+
if hasattr(model, "module"):
|
696 |
+
att_vis_fn = model.module.calculate_all_attentions
|
697 |
+
plot_class = model.module.attention_plot_class
|
698 |
+
else:
|
699 |
+
att_vis_fn = model.calculate_all_attentions
|
700 |
+
plot_class = model.attention_plot_class
|
701 |
+
att_reporter = plot_class(
|
702 |
+
att_vis_fn,
|
703 |
+
data,
|
704 |
+
args.outdir + "/att_ws",
|
705 |
+
converter=converter,
|
706 |
+
transform=load_cv,
|
707 |
+
device=device,
|
708 |
+
subsampling_factor=total_subsampling_factor,
|
709 |
+
)
|
710 |
+
trainer.extend(att_reporter, trigger=(1, "epoch"))
|
711 |
+
else:
|
712 |
+
att_reporter = None
|
713 |
+
|
714 |
+
# Save CTC prob at each epoch
|
715 |
+
if mtl_mode in ["ctc", "mtl"] and args.num_save_ctc > 0:
|
716 |
+
# NOTE: sort it by output lengths
|
717 |
+
data = sorted(
|
718 |
+
list(valid_json.items())[: args.num_save_ctc],
|
719 |
+
key=lambda x: int(x[1]["output"][0]["shape"][0]),
|
720 |
+
reverse=True,
|
721 |
+
)
|
722 |
+
if hasattr(model, "module"):
|
723 |
+
ctc_vis_fn = model.module.calculate_all_ctc_probs
|
724 |
+
plot_class = model.module.ctc_plot_class
|
725 |
+
else:
|
726 |
+
ctc_vis_fn = model.calculate_all_ctc_probs
|
727 |
+
plot_class = model.ctc_plot_class
|
728 |
+
ctc_reporter = plot_class(
|
729 |
+
ctc_vis_fn,
|
730 |
+
data,
|
731 |
+
args.outdir + "/ctc_prob",
|
732 |
+
converter=converter,
|
733 |
+
transform=load_cv,
|
734 |
+
device=device,
|
735 |
+
subsampling_factor=total_subsampling_factor,
|
736 |
+
)
|
737 |
+
trainer.extend(ctc_reporter, trigger=(1, "epoch"))
|
738 |
+
else:
|
739 |
+
ctc_reporter = None
|
740 |
+
|
741 |
+
# Make a plot for training and validation values
|
742 |
+
if args.num_encs > 1:
|
743 |
+
report_keys_loss_ctc = [
|
744 |
+
"main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
|
745 |
+
] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)]
|
746 |
+
report_keys_cer_ctc = [
|
747 |
+
"main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
|
748 |
+
] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)]
|
749 |
+
|
750 |
+
if hasattr(model, "is_rnnt"):
|
751 |
+
trainer.extend(
|
752 |
+
extensions.PlotReport(
|
753 |
+
[
|
754 |
+
"main/loss",
|
755 |
+
"validation/main/loss",
|
756 |
+
"main/loss_trans",
|
757 |
+
"validation/main/loss_trans",
|
758 |
+
"main/loss_ctc",
|
759 |
+
"validation/main/loss_ctc",
|
760 |
+
"main/loss_lm",
|
761 |
+
"validation/main/loss_lm",
|
762 |
+
"main/loss_aux_trans",
|
763 |
+
"validation/main/loss_aux_trans",
|
764 |
+
"main/loss_aux_symm_kl",
|
765 |
+
"validation/main/loss_aux_symm_kl",
|
766 |
+
],
|
767 |
+
"epoch",
|
768 |
+
file_name="loss.png",
|
769 |
+
)
|
770 |
+
)
|
771 |
+
else:
|
772 |
+
trainer.extend(
|
773 |
+
extensions.PlotReport(
|
774 |
+
[
|
775 |
+
"main/loss",
|
776 |
+
"validation/main/loss",
|
777 |
+
"main/loss_ctc",
|
778 |
+
"validation/main/loss_ctc",
|
779 |
+
"main/loss_att",
|
780 |
+
"validation/main/loss_att",
|
781 |
+
]
|
782 |
+
+ ([] if args.num_encs == 1 else report_keys_loss_ctc),
|
783 |
+
"epoch",
|
784 |
+
file_name="loss.png",
|
785 |
+
)
|
786 |
+
)
|
787 |
+
|
788 |
+
trainer.extend(
|
789 |
+
extensions.PlotReport(
|
790 |
+
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
|
791 |
+
)
|
792 |
+
)
|
793 |
+
trainer.extend(
|
794 |
+
extensions.PlotReport(
|
795 |
+
["main/cer_ctc", "validation/main/cer_ctc"]
|
796 |
+
+ ([] if args.num_encs == 1 else report_keys_loss_ctc),
|
797 |
+
"epoch",
|
798 |
+
file_name="cer.png",
|
799 |
+
)
|
800 |
+
)
|
801 |
+
|
802 |
+
# Save best models
|
803 |
+
trainer.extend(
|
804 |
+
snapshot_object(model, "model.loss.best"),
|
805 |
+
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
|
806 |
+
)
|
807 |
+
if mtl_mode not in ["ctc", "transducer", "custom_transducer"]:
|
808 |
+
trainer.extend(
|
809 |
+
snapshot_object(model, "model.acc.best"),
|
810 |
+
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
|
811 |
+
)
|
812 |
+
|
813 |
+
# save snapshot which contains model and optimizer states
|
814 |
+
if args.save_interval_iters > 0:
|
815 |
+
trainer.extend(
|
816 |
+
torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
|
817 |
+
trigger=(args.save_interval_iters, "iteration"),
|
818 |
+
)
|
819 |
+
|
820 |
+
# save snapshot at every epoch - for model averaging
|
821 |
+
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
|
822 |
+
|
823 |
+
# epsilon decay in the optimizer
|
824 |
+
if args.opt == "adadelta":
|
825 |
+
if args.criterion == "acc" and mtl_mode != "ctc":
|
826 |
+
trainer.extend(
|
827 |
+
restore_snapshot(
|
828 |
+
model, args.outdir + "/model.acc.best", load_fn=torch_load
|
829 |
+
),
|
830 |
+
trigger=CompareValueTrigger(
|
831 |
+
"validation/main/acc",
|
832 |
+
lambda best_value, current_value: best_value > current_value,
|
833 |
+
),
|
834 |
+
)
|
835 |
+
trainer.extend(
|
836 |
+
adadelta_eps_decay(args.eps_decay),
|
837 |
+
trigger=CompareValueTrigger(
|
838 |
+
"validation/main/acc",
|
839 |
+
lambda best_value, current_value: best_value > current_value,
|
840 |
+
),
|
841 |
+
)
|
842 |
+
elif args.criterion == "loss":
|
843 |
+
trainer.extend(
|
844 |
+
restore_snapshot(
|
845 |
+
model, args.outdir + "/model.loss.best", load_fn=torch_load
|
846 |
+
),
|
847 |
+
trigger=CompareValueTrigger(
|
848 |
+
"validation/main/loss",
|
849 |
+
lambda best_value, current_value: best_value < current_value,
|
850 |
+
),
|
851 |
+
)
|
852 |
+
trainer.extend(
|
853 |
+
adadelta_eps_decay(args.eps_decay),
|
854 |
+
trigger=CompareValueTrigger(
|
855 |
+
"validation/main/loss",
|
856 |
+
lambda best_value, current_value: best_value < current_value,
|
857 |
+
),
|
858 |
+
)
|
859 |
+
# NOTE: In some cases, it may take more than one epoch for the model's loss
|
860 |
+
# to escape from a local minimum.
|
861 |
+
# Thus, restore_snapshot extension is not used here.
|
862 |
+
# see details in https://github.com/espnet/espnet/pull/2171
|
863 |
+
elif args.criterion == "loss_eps_decay_only":
|
864 |
+
trainer.extend(
|
865 |
+
adadelta_eps_decay(args.eps_decay),
|
866 |
+
trigger=CompareValueTrigger(
|
867 |
+
"validation/main/loss",
|
868 |
+
lambda best_value, current_value: best_value < current_value,
|
869 |
+
),
|
870 |
+
)
|
871 |
+
|
872 |
+
# Write a log of evaluation statistics for each epoch
|
873 |
+
trainer.extend(
|
874 |
+
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
|
875 |
+
)
|
876 |
+
|
877 |
+
if hasattr(model, "is_rnnt"):
|
878 |
+
report_keys = [
|
879 |
+
"epoch",
|
880 |
+
"iteration",
|
881 |
+
"main/loss",
|
882 |
+
"main/loss_trans",
|
883 |
+
"main/loss_ctc",
|
884 |
+
"main/loss_lm",
|
885 |
+
"main/loss_aux_trans",
|
886 |
+
"main/loss_aux_symm_kl",
|
887 |
+
"validation/main/loss",
|
888 |
+
"validation/main/loss_trans",
|
889 |
+
"validation/main/loss_ctc",
|
890 |
+
"validation/main/loss_lm",
|
891 |
+
"validation/main/loss_aux_trans",
|
892 |
+
"validation/main/loss_aux_symm_kl",
|
893 |
+
"elapsed_time",
|
894 |
+
]
|
895 |
+
else:
|
896 |
+
report_keys = [
|
897 |
+
"epoch",
|
898 |
+
"iteration",
|
899 |
+
"main/loss",
|
900 |
+
"main/loss_ctc",
|
901 |
+
"main/loss_att",
|
902 |
+
"validation/main/loss",
|
903 |
+
"validation/main/loss_ctc",
|
904 |
+
"validation/main/loss_att",
|
905 |
+
"main/acc",
|
906 |
+
"validation/main/acc",
|
907 |
+
"main/cer_ctc",
|
908 |
+
"validation/main/cer_ctc",
|
909 |
+
"elapsed_time",
|
910 |
+
] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc)
|
911 |
+
|
912 |
+
if args.opt == "adadelta":
|
913 |
+
trainer.extend(
|
914 |
+
extensions.observe_value(
|
915 |
+
"eps",
|
916 |
+
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
|
917 |
+
"eps"
|
918 |
+
],
|
919 |
+
),
|
920 |
+
trigger=(args.report_interval_iters, "iteration"),
|
921 |
+
)
|
922 |
+
report_keys.append("eps")
|
923 |
+
if args.report_cer:
|
924 |
+
report_keys.append("validation/main/cer")
|
925 |
+
if args.report_wer:
|
926 |
+
report_keys.append("validation/main/wer")
|
927 |
+
trainer.extend(
|
928 |
+
extensions.PrintReport(report_keys),
|
929 |
+
trigger=(args.report_interval_iters, "iteration"),
|
930 |
+
)
|
931 |
+
|
932 |
+
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
|
933 |
+
set_early_stop(trainer, args)
|
934 |
+
|
935 |
+
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
|
936 |
+
trainer.extend(
|
937 |
+
TensorboardLogger(
|
938 |
+
SummaryWriter(args.tensorboard_dir),
|
939 |
+
att_reporter=att_reporter,
|
940 |
+
ctc_reporter=ctc_reporter,
|
941 |
+
),
|
942 |
+
trigger=(args.report_interval_iters, "iteration"),
|
943 |
+
)
|
944 |
+
# Run the training
|
945 |
+
trainer.run()
|
946 |
+
check_early_stop(trainer, args.epochs)
|
947 |
+
|
948 |
+
|
949 |
+
def recog(args):
|
950 |
+
"""Decode with the given args.
|
951 |
+
|
952 |
+
Args:
|
953 |
+
args (namespace): The program arguments.
|
954 |
+
|
955 |
+
"""
|
956 |
+
set_deterministic_pytorch(args)
|
957 |
+
model, train_args = load_trained_model(args.model, training=False)
|
958 |
+
assert isinstance(model, ASRInterface)
|
959 |
+
model.recog_args = args
|
960 |
+
|
961 |
+
if args.streaming_mode and "transformer" in train_args.model_module:
|
962 |
+
raise NotImplementedError("streaming mode for transformer is not implemented")
|
963 |
+
logging.info(
|
964 |
+
" Total parameter of the model = "
|
965 |
+
+ str(sum(p.numel() for p in model.parameters()))
|
966 |
+
)
|
967 |
+
|
968 |
+
# read rnnlm
|
969 |
+
if args.rnnlm:
|
970 |
+
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
|
971 |
+
if getattr(rnnlm_args, "model_module", "default") != "default":
|
972 |
+
raise ValueError(
|
973 |
+
"use '--api v2' option to decode with non-default language model"
|
974 |
+
)
|
975 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
976 |
+
lm_pytorch.RNNLM(
|
977 |
+
len(train_args.char_list),
|
978 |
+
rnnlm_args.layer,
|
979 |
+
rnnlm_args.unit,
|
980 |
+
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
|
981 |
+
)
|
982 |
+
)
|
983 |
+
torch_load(args.rnnlm, rnnlm)
|
984 |
+
rnnlm.eval()
|
985 |
+
else:
|
986 |
+
rnnlm = None
|
987 |
+
|
988 |
+
if args.word_rnnlm:
|
989 |
+
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
|
990 |
+
word_dict = rnnlm_args.char_list_dict
|
991 |
+
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
|
992 |
+
word_rnnlm = lm_pytorch.ClassifierWithState(
|
993 |
+
lm_pytorch.RNNLM(
|
994 |
+
len(word_dict),
|
995 |
+
rnnlm_args.layer,
|
996 |
+
rnnlm_args.unit,
|
997 |
+
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
|
998 |
+
)
|
999 |
+
)
|
1000 |
+
torch_load(args.word_rnnlm, word_rnnlm)
|
1001 |
+
word_rnnlm.eval()
|
1002 |
+
|
1003 |
+
if rnnlm is not None:
|
1004 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
1005 |
+
extlm_pytorch.MultiLevelLM(
|
1006 |
+
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
|
1007 |
+
)
|
1008 |
+
)
|
1009 |
+
else:
|
1010 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
1011 |
+
extlm_pytorch.LookAheadWordLM(
|
1012 |
+
word_rnnlm.predictor, word_dict, char_dict
|
1013 |
+
)
|
1014 |
+
)
|
1015 |
+
|
1016 |
+
# gpu
|
1017 |
+
if args.ngpu == 1:
|
1018 |
+
gpu_id = list(range(args.ngpu))
|
1019 |
+
logging.info("gpu id: " + str(gpu_id))
|
1020 |
+
model.cuda()
|
1021 |
+
if rnnlm:
|
1022 |
+
rnnlm.cuda()
|
1023 |
+
|
1024 |
+
# read json data
|
1025 |
+
with open(args.recog_json, "rb") as f:
|
1026 |
+
js = json.load(f)["utts"]
|
1027 |
+
new_js = {}
|
1028 |
+
|
1029 |
+
load_inputs_and_targets = LoadInputsAndTargets(
|
1030 |
+
mode="asr",
|
1031 |
+
load_output=False,
|
1032 |
+
sort_in_input_length=False,
|
1033 |
+
preprocess_conf=train_args.preprocess_conf
|
1034 |
+
if args.preprocess_conf is None
|
1035 |
+
else args.preprocess_conf,
|
1036 |
+
preprocess_args={"train": False},
|
1037 |
+
)
|
1038 |
+
|
1039 |
+
# load transducer beam search
|
1040 |
+
if hasattr(model, "is_rnnt"):
|
1041 |
+
if hasattr(model, "dec"):
|
1042 |
+
trans_decoder = model.dec
|
1043 |
+
else:
|
1044 |
+
trans_decoder = model.decoder
|
1045 |
+
joint_network = model.joint_network
|
1046 |
+
|
1047 |
+
beam_search_transducer = BeamSearchTransducer(
|
1048 |
+
decoder=trans_decoder,
|
1049 |
+
joint_network=joint_network,
|
1050 |
+
beam_size=args.beam_size,
|
1051 |
+
nbest=args.nbest,
|
1052 |
+
lm=rnnlm,
|
1053 |
+
lm_weight=args.lm_weight,
|
1054 |
+
search_type=args.search_type,
|
1055 |
+
max_sym_exp=args.max_sym_exp,
|
1056 |
+
u_max=args.u_max,
|
1057 |
+
nstep=args.nstep,
|
1058 |
+
prefix_alpha=args.prefix_alpha,
|
1059 |
+
score_norm=args.score_norm,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
if args.batchsize == 0:
|
1063 |
+
with torch.no_grad():
|
1064 |
+
for idx, name in enumerate(js.keys(), 1):
|
1065 |
+
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
|
1066 |
+
batch = [(name, js[name])]
|
1067 |
+
feat = load_inputs_and_targets(batch)
|
1068 |
+
feat = (
|
1069 |
+
feat[0][0]
|
1070 |
+
if args.num_encs == 1
|
1071 |
+
else [feat[idx][0] for idx in range(model.num_encs)]
|
1072 |
+
)
|
1073 |
+
if args.streaming_mode == "window" and args.num_encs == 1:
|
1074 |
+
logging.info(
|
1075 |
+
"Using streaming recognizer with window size %d frames",
|
1076 |
+
args.streaming_window,
|
1077 |
+
)
|
1078 |
+
se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
|
1079 |
+
for i in range(0, feat.shape[0], args.streaming_window):
|
1080 |
+
logging.info(
|
1081 |
+
"Feeding frames %d - %d", i, i + args.streaming_window
|
1082 |
+
)
|
1083 |
+
se2e.accept_input(feat[i : i + args.streaming_window])
|
1084 |
+
logging.info("Running offline attention decoder")
|
1085 |
+
se2e.decode_with_attention_offline()
|
1086 |
+
logging.info("Offline attention decoder finished")
|
1087 |
+
nbest_hyps = se2e.retrieve_recognition()
|
1088 |
+
elif args.streaming_mode == "segment" and args.num_encs == 1:
|
1089 |
+
logging.info(
|
1090 |
+
"Using streaming recognizer with threshold value %d",
|
1091 |
+
args.streaming_min_blank_dur,
|
1092 |
+
)
|
1093 |
+
nbest_hyps = []
|
1094 |
+
for n in range(args.nbest):
|
1095 |
+
nbest_hyps.append({"yseq": [], "score": 0.0})
|
1096 |
+
se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
|
1097 |
+
r = np.prod(model.subsample)
|
1098 |
+
for i in range(0, feat.shape[0], r):
|
1099 |
+
hyps = se2e.accept_input(feat[i : i + r])
|
1100 |
+
if hyps is not None:
|
1101 |
+
text = "".join(
|
1102 |
+
[
|
1103 |
+
train_args.char_list[int(x)]
|
1104 |
+
for x in hyps[0]["yseq"][1:-1]
|
1105 |
+
if int(x) != -1
|
1106 |
+
]
|
1107 |
+
)
|
1108 |
+
text = text.replace(
|
1109 |
+
"\u2581", " "
|
1110 |
+
).strip() # for SentencePiece
|
1111 |
+
text = text.replace(model.space, " ")
|
1112 |
+
text = text.replace(model.blank, "")
|
1113 |
+
logging.info(text)
|
1114 |
+
for n in range(args.nbest):
|
1115 |
+
nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"])
|
1116 |
+
nbest_hyps[n]["score"] += hyps[n]["score"]
|
1117 |
+
elif hasattr(model, "is_rnnt"):
|
1118 |
+
nbest_hyps = model.recognize(feat, beam_search_transducer)
|
1119 |
+
else:
|
1120 |
+
nbest_hyps = model.recognize(
|
1121 |
+
feat, args, train_args.char_list, rnnlm
|
1122 |
+
)
|
1123 |
+
new_js[name] = add_results_to_json(
|
1124 |
+
js[name], nbest_hyps, train_args.char_list
|
1125 |
+
)
|
1126 |
+
|
1127 |
+
else:
|
1128 |
+
|
1129 |
+
def grouper(n, iterable, fillvalue=None):
|
1130 |
+
kargs = [iter(iterable)] * n
|
1131 |
+
return zip_longest(*kargs, fillvalue=fillvalue)
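# --- Illustrative example (hypothetical utterance ids) ---
# list(grouper(2, ["utt1", "utt2", "utt3"]))
#   -> [("utt1", "utt2"), ("utt3", None)]
# The trailing None added by the fill value is dropped just below via
# `names = [name for name in names if name]`.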
|
1132 |
+
|
1133 |
+
# sort data if batchsize > 1
|
1134 |
+
keys = list(js.keys())
|
1135 |
+
if args.batchsize > 1:
|
1136 |
+
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
|
1137 |
+
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
|
1138 |
+
keys = [keys[i] for i in sorted_index]
|
1139 |
+
|
1140 |
+
with torch.no_grad():
|
1141 |
+
for names in grouper(args.batchsize, keys, None):
|
1142 |
+
names = [name for name in names if name]
|
1143 |
+
batch = [(name, js[name]) for name in names]
|
1144 |
+
feats = (
|
1145 |
+
load_inputs_and_targets(batch)[0]
|
1146 |
+
if args.num_encs == 1
|
1147 |
+
else load_inputs_and_targets(batch)
|
1148 |
+
)
|
1149 |
+
if args.streaming_mode == "window" and args.num_encs == 1:
|
1150 |
+
raise NotImplementedError
|
1151 |
+
elif args.streaming_mode == "segment" and args.num_encs == 1:
|
1152 |
+
if args.batchsize > 1:
|
1153 |
+
raise NotImplementedError
|
1154 |
+
feat = feats[0]
|
1155 |
+
nbest_hyps = []
|
1156 |
+
for n in range(args.nbest):
|
1157 |
+
nbest_hyps.append({"yseq": [], "score": 0.0})
|
1158 |
+
se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
|
1159 |
+
r = np.prod(model.subsample)
|
1160 |
+
for i in range(0, feat.shape[0], r):
|
1161 |
+
hyps = se2e.accept_input(feat[i : i + r])
|
1162 |
+
if hyps is not None:
|
1163 |
+
text = "".join(
|
1164 |
+
[
|
1165 |
+
train_args.char_list[int(x)]
|
1166 |
+
for x in hyps[0]["yseq"][1:-1]
|
1167 |
+
if int(x) != -1
|
1168 |
+
]
|
1169 |
+
)
|
1170 |
+
text = text.replace(
|
1171 |
+
"\u2581", " "
|
1172 |
+
).strip() # for SentencePiece
|
1173 |
+
text = text.replace(model.space, " ")
|
1174 |
+
text = text.replace(model.blank, "")
|
1175 |
+
logging.info(text)
|
1176 |
+
for n in range(args.nbest):
|
1177 |
+
nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"])
|
1178 |
+
nbest_hyps[n]["score"] += hyps[n]["score"]
|
1179 |
+
nbest_hyps = [nbest_hyps]
|
1180 |
+
else:
|
1181 |
+
nbest_hyps = model.recognize_batch(
|
1182 |
+
feats, args, train_args.char_list, rnnlm=rnnlm
|
1183 |
+
)
|
1184 |
+
|
1185 |
+
for i, nbest_hyp in enumerate(nbest_hyps):
|
1186 |
+
name = names[i]
|
1187 |
+
new_js[name] = add_results_to_json(
|
1188 |
+
js[name], nbest_hyp, train_args.char_list
|
1189 |
+
)
|
1190 |
+
|
1191 |
+
with open(args.result_label, "wb") as f:
|
1192 |
+
f.write(
|
1193 |
+
json.dumps(
|
1194 |
+
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
|
1195 |
+
).encode("utf_8")
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
|
1199 |
+
def enhance(args):
|
1200 |
+
"""Dumping enhanced speech and mask.
|
1201 |
+
|
1202 |
+
Args:
|
1203 |
+
args (namespace): The program arguments.
|
1204 |
+
"""
|
1205 |
+
set_deterministic_pytorch(args)
|
1206 |
+
# read training config
|
1207 |
+
idim, odim, train_args = get_model_conf(args.model, args.model_conf)
|
1208 |
+
|
1209 |
+
# TODO(ruizhili): implement enhance for multi-encoder model
|
1210 |
+
assert args.num_encs == 1, "number of encoders should be 1 ({} is given)".format(
|
1211 |
+
args.num_encs
|
1212 |
+
)
|
1213 |
+
|
1214 |
+
# load trained model parameters
|
1215 |
+
logging.info("reading model parameters from " + args.model)
|
1216 |
+
model_class = dynamic_import(train_args.model_module)
|
1217 |
+
model = model_class(idim, odim, train_args)
|
1218 |
+
assert isinstance(model, ASRInterface)
|
1219 |
+
torch_load(args.model, model)
|
1220 |
+
model.recog_args = args
|
1221 |
+
|
1222 |
+
# gpu
|
1223 |
+
if args.ngpu == 1:
|
1224 |
+
gpu_id = list(range(args.ngpu))
|
1225 |
+
logging.info("gpu id: " + str(gpu_id))
|
1226 |
+
model.cuda()
|
1227 |
+
|
1228 |
+
# read json data
|
1229 |
+
with open(args.recog_json, "rb") as f:
|
1230 |
+
js = json.load(f)["utts"]
|
1231 |
+
|
1232 |
+
load_inputs_and_targets = LoadInputsAndTargets(
|
1233 |
+
mode="asr",
|
1234 |
+
load_output=False,
|
1235 |
+
sort_in_input_length=False,
|
1236 |
+
preprocess_conf=None, # Apply pre_process in outer func
|
1237 |
+
)
|
1238 |
+
if args.batchsize == 0:
|
1239 |
+
args.batchsize = 1
|
1240 |
+
|
1241 |
+
# Creates writers for outputs from the network
|
1242 |
+
if args.enh_wspecifier is not None:
|
1243 |
+
enh_writer = file_writer_helper(args.enh_wspecifier, filetype=args.enh_filetype)
|
1244 |
+
else:
|
1245 |
+
enh_writer = None
|
1246 |
+
|
1247 |
+
# Creates a Transformation instance
|
1248 |
+
preprocess_conf = (
|
1249 |
+
train_args.preprocess_conf
|
1250 |
+
if args.preprocess_conf is None
|
1251 |
+
else args.preprocess_conf
|
1252 |
+
)
|
1253 |
+
if preprocess_conf is not None:
|
1254 |
+
logging.info(f"Use preprocessing: {preprocess_conf}")
|
1255 |
+
transform = Transformation(preprocess_conf)
|
1256 |
+
else:
|
1257 |
+
transform = None
|
1258 |
+
|
1259 |
+
# Creates a IStft instance
|
1260 |
+
istft = None
|
1261 |
+
frame_shift = args.istft_n_shift # Used for plot the spectrogram
|
1262 |
+
if args.apply_istft:
|
1263 |
+
if preprocess_conf is not None:
|
1264 |
+
# Read the config file and find the stft setting
|
1265 |
+
with open(preprocess_conf) as f:
|
1266 |
+
# Json format: e.g.
|
1267 |
+
# {"process": [{"type": "stft",
|
1268 |
+
# "win_length": 400,
|
1269 |
+
# "n_fft": 512, "n_shift": 160,
|
1270 |
+
# "window": "han"},
|
1271 |
+
# {"type": "foo", ...}, ...]}
|
1272 |
+
conf = json.load(f)
|
1273 |
+
assert "process" in conf, conf
|
1274 |
+
# Find stft setting
|
1275 |
+
for p in conf["process"]:
|
1276 |
+
if p["type"] == "stft":
|
1277 |
+
istft = IStft(
|
1278 |
+
win_length=p["win_length"],
|
1279 |
+
n_shift=p["n_shift"],
|
1280 |
+
window=p.get("window", "hann"),
|
1281 |
+
)
|
1282 |
+
logging.info(
|
1283 |
+
"stft is found in {}. "
|
1284 |
+
"Setting istft config from it\n{}".format(
|
1285 |
+
preprocess_conf, istft
|
1286 |
+
)
|
1287 |
+
)
|
1288 |
+
frame_shift = p["n_shift"]
|
1289 |
+
break
|
1290 |
+
if istft is None:
|
1291 |
+
# Set from command line arguments
|
1292 |
+
istft = IStft(
|
1293 |
+
win_length=args.istft_win_length,
|
1294 |
+
n_shift=args.istft_n_shift,
|
1295 |
+
window=args.istft_window,
|
1296 |
+
)
|
1297 |
+
logging.info(
|
1298 |
+
"Setting istft config from the command line args\n{}".format(istft)
|
1299 |
+
)
|
1300 |
+
|
1301 |
+
# sort data
|
1302 |
+
keys = list(js.keys())
|
1303 |
+
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
|
1304 |
+
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
|
1305 |
+
keys = [keys[i] for i in sorted_index]
|
1306 |
+
|
1307 |
+
def grouper(n, iterable, fillvalue=None):
|
1308 |
+
kargs = [iter(iterable)] * n
|
1309 |
+
return zip_longest(*kargs, fillvalue=fillvalue)
|
1310 |
+
|
1311 |
+
num_images = 0
|
1312 |
+
if not os.path.exists(args.image_dir):
|
1313 |
+
os.makedirs(args.image_dir)
|
1314 |
+
|
1315 |
+
for names in grouper(args.batchsize, keys, None):
|
1316 |
+
batch = [(name, js[name]) for name in names]
|
1317 |
+
|
1318 |
+
# May be in time region: (Batch, [Time, Channel])
|
1319 |
+
org_feats = load_inputs_and_targets(batch)[0]
|
1320 |
+
if transform is not None:
|
1321 |
+
# May be in time-freq region: (Batch, [Time, Channel, Freq])
|
1322 |
+
feats = transform(org_feats, train=False)
|
1323 |
+
else:
|
1324 |
+
feats = org_feats
|
1325 |
+
|
1326 |
+
with torch.no_grad():
|
1327 |
+
enhanced, mask, ilens = model.enhance(feats)
|
1328 |
+
|
1329 |
+
for idx, name in enumerate(names):
|
1330 |
+
# Assuming mask, feats: [Batch, Time, Channel, Freq]
|
1331 |
+
# enhanced : [Batch, Time, Freq]
|
1332 |
+
enh = enhanced[idx][: ilens[idx]]
|
1333 |
+
mas = mask[idx][: ilens[idx]]
|
1334 |
+
feat = feats[idx]
|
1335 |
+
|
1336 |
+
# Plot spectrogram
|
1337 |
+
if args.image_dir is not None and num_images < args.num_images:
|
1338 |
+
import matplotlib.pyplot as plt
|
1339 |
+
|
1340 |
+
num_images += 1
|
1341 |
+
ref_ch = 0
|
1342 |
+
|
1343 |
+
plt.figure(figsize=(20, 10))
|
1344 |
+
plt.subplot(4, 1, 1)
|
1345 |
+
plt.title("Mask [ref={}ch]".format(ref_ch))
|
1346 |
+
plot_spectrogram(
|
1347 |
+
plt,
|
1348 |
+
mas[:, ref_ch].T,
|
1349 |
+
fs=args.fs,
|
1350 |
+
mode="linear",
|
1351 |
+
frame_shift=frame_shift,
|
1352 |
+
bottom=False,
|
1353 |
+
labelbottom=False,
|
1354 |
+
)
|
1355 |
+
|
1356 |
+
plt.subplot(4, 1, 2)
|
1357 |
+
plt.title("Noisy speech [ref={}ch]".format(ref_ch))
|
1358 |
+
plot_spectrogram(
|
1359 |
+
plt,
|
1360 |
+
feat[:, ref_ch].T,
|
1361 |
+
fs=args.fs,
|
1362 |
+
mode="db",
|
1363 |
+
frame_shift=frame_shift,
|
1364 |
+
bottom=False,
|
1365 |
+
labelbottom=False,
|
1366 |
+
)
|
1367 |
+
|
1368 |
+
plt.subplot(4, 1, 3)
|
1369 |
+
plt.title("Masked speech [ref={}ch]".format(ref_ch))
|
1370 |
+
plot_spectrogram(
|
1371 |
+
plt,
|
1372 |
+
(feat[:, ref_ch] * mas[:, ref_ch]).T,
|
1373 |
+
frame_shift=frame_shift,
|
1374 |
+
fs=args.fs,
|
1375 |
+
mode="db",
|
1376 |
+
bottom=False,
|
1377 |
+
labelbottom=False,
|
1378 |
+
)
|
1379 |
+
|
1380 |
+
plt.subplot(4, 1, 4)
|
1381 |
+
plt.title("Enhanced speech")
|
1382 |
+
plot_spectrogram(
|
1383 |
+
plt, enh.T, fs=args.fs, mode="db", frame_shift=frame_shift
|
1384 |
+
)
|
1385 |
+
|
1386 |
+
plt.savefig(os.path.join(args.image_dir, name + ".png"))
|
1387 |
+
plt.clf()
|
1388 |
+
|
1389 |
+
# Write enhanced wave files
|
1390 |
+
if enh_writer is not None:
|
1391 |
+
if istft is not None:
|
1392 |
+
enh = istft(enh)
|
1393 |
+
else:
|
1394 |
+
enh = enh
|
1395 |
+
|
1396 |
+
if args.keep_length:
|
1397 |
+
if len(org_feats[idx]) < len(enh):
|
1398 |
+
# Truncate the frames added by stft padding
|
1399 |
+
enh = enh[: len(org_feats[idx])]
|
1400 |
+
elif len(org_feats[idx]) > len(enh):
|
1401 |
+
padwidth = [(0, (len(org_feats[idx]) - len(enh)))] + [
|
1402 |
+
(0, 0)
|
1403 |
+
] * (enh.ndim - 1)
|
1404 |
+
enh = np.pad(enh, padwidth, mode="constant")
|
1405 |
+
|
1406 |
+
if args.enh_filetype in ("sound", "sound.hdf5"):
|
1407 |
+
enh_writer[name] = (args.fs, enh)
|
1408 |
+
else:
|
1409 |
+
# Hint: To dump stft_signal, mask or etc,
|
1410 |
+
# enh_filetype='hdf5' might be convenient.
|
1411 |
+
enh_writer[name] = enh
|
1412 |
+
|
1413 |
+
if num_images >= args.num_images and enh_writer is None:
|
1414 |
+
logging.info("Breaking the process.")
|
1415 |
+
break
|
1416 |
+
|
1417 |
+
|
1418 |
+
def ctc_align(args):
|
1419 |
+
"""CTC forced alignments with the given args.
|
1420 |
+
|
1421 |
+
Args:
|
1422 |
+
args (namespace): The program arguments.
|
1423 |
+
"""
|
1424 |
+
|
1425 |
+
def add_alignment_to_json(js, alignment, char_list):
|
1426 |
+
"""Add N-best results to json.
|
1427 |
+
|
1428 |
+
Args:
|
1429 |
+
js (dict[str, Any]): Groundtruth utterance dict.
|
1430 |
+
alignment (list[int]): List of alignment.
|
1431 |
+
char_list (list[str]): List of characters.
|
1432 |
+
|
1433 |
+
Returns:
|
1434 |
+
dict[str, Any]: Utterance dict with the alignment added.
|
1435 |
+
|
1436 |
+
"""
|
1437 |
+
# create a new dict for the alignment info
|
1438 |
+
new_js = dict()
|
1439 |
+
new_js["ctc_alignment"] = []
|
1440 |
+
|
1441 |
+
alignment_tokens = []
|
1442 |
+
for idx, a in enumerate(alignment):
|
1443 |
+
alignment_tokens.append(char_list[a])
|
1444 |
+
alignment_tokens = " ".join(alignment_tokens)
|
1445 |
+
|
1446 |
+
new_js["ctc_alignment"] = alignment_tokens
|
1447 |
+
|
1448 |
+
return new_js
|
1449 |
+
|
1450 |
+
set_deterministic_pytorch(args)
|
1451 |
+
model, train_args = load_trained_model(args.model)
|
1452 |
+
assert isinstance(model, ASRInterface)
|
1453 |
+
model.eval()
|
1454 |
+
|
1455 |
+
load_inputs_and_targets = LoadInputsAndTargets(
|
1456 |
+
mode="asr",
|
1457 |
+
load_output=True,
|
1458 |
+
sort_in_input_length=False,
|
1459 |
+
preprocess_conf=train_args.preprocess_conf
|
1460 |
+
if args.preprocess_conf is None
|
1461 |
+
else args.preprocess_conf,
|
1462 |
+
preprocess_args={"train": False},
|
1463 |
+
)
|
1464 |
+
|
1465 |
+
if args.ngpu > 1:
|
1466 |
+
raise NotImplementedError("only single GPU decoding is supported")
|
1467 |
+
if args.ngpu == 1:
|
1468 |
+
device = "cuda"
|
1469 |
+
else:
|
1470 |
+
device = "cpu"
|
1471 |
+
dtype = getattr(torch, args.dtype)
|
1472 |
+
logging.info(f"Decoding device={device}, dtype={dtype}")
|
1473 |
+
model.to(device=device, dtype=dtype).eval()
|
1474 |
+
|
1475 |
+
# read json data
|
1476 |
+
with open(args.align_json, "rb") as f:
|
1477 |
+
js = json.load(f)["utts"]
|
1478 |
+
new_js = {}
|
1479 |
+
if args.batchsize == 0:
|
1480 |
+
with torch.no_grad():
|
1481 |
+
for idx, name in enumerate(js.keys(), 1):
|
1482 |
+
logging.info("(%d/%d) aligning " + name, idx, len(js.keys()))
|
1483 |
+
batch = [(name, js[name])]
|
1484 |
+
feat, label = load_inputs_and_targets(batch)
|
1485 |
+
feat = feat[0]
|
1486 |
+
label = label[0]
|
1487 |
+
enc = model.encode(torch.as_tensor(feat).to(device)).unsqueeze(0)
|
1488 |
+
alignment = model.ctc.forced_align(enc, label)
|
1489 |
+
new_js[name] = add_alignment_to_json(
|
1490 |
+
js[name], alignment, train_args.char_list
|
1491 |
+
)
|
1492 |
+
else:
|
1493 |
+
raise NotImplementedError("Align_batch is not implemented.")
|
1494 |
+
|
1495 |
+
with open(args.result_label, "wb") as f:
|
1496 |
+
f.write(
|
1497 |
+
json.dumps(
|
1498 |
+
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
|
1499 |
+
).encode("utf_8")
|
1500 |
+
)
|
espnet/asr/pytorch_backend/asr_init.py
ADDED
@@ -0,0 +1,282 @@
1 |
+
"""Finetuning methods."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from collections import OrderedDict
|
8 |
+
|
9 |
+
from espnet.asr.asr_utils import get_model_conf
|
10 |
+
from espnet.asr.asr_utils import torch_load
|
11 |
+
from espnet.nets.asr_interface import ASRInterface
|
12 |
+
from espnet.nets.mt_interface import MTInterface
|
13 |
+
from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load
|
14 |
+
from espnet.nets.tts_interface import TTSInterface
|
15 |
+
from espnet.utils.dynamic_import import dynamic_import
|
16 |
+
|
17 |
+
|
18 |
+
def freeze_modules(model, modules):
|
19 |
+
"""Freeze model parameters according to modules list.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
model (torch.nn.Module): main model to update
|
23 |
+
modules (list): specified module list for freezing
|
24 |
+
|
25 |
+
Return:
|
26 |
+
model (torch.nn.Module): updated model
|
27 |
+
model_params (filter): filtered model parameters
|
28 |
+
|
29 |
+
"""
|
30 |
+
for mod, param in model.named_parameters():
|
31 |
+
if any(mod.startswith(m) for m in modules):
|
32 |
+
logging.info(f"freezing {mod}, it will not be updated.")
|
33 |
+
param.requires_grad = False
|
34 |
+
|
35 |
+
model_params = filter(lambda x: x.requires_grad, model.parameters())
|
36 |
+
|
37 |
+
return model, model_params
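# --- Illustrative usage sketch (the module prefix "enc." is hypothetical) ---
#   model, model_params = freeze_modules(model, ["enc."])
#   optimizer = torch.optim.Adam(model_params, lr=1e-3)
# Note that model_params is a filter object, so it should be consumed once
# (e.g. handed directly to the optimizer) rather than iterated repeatedly.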
|
38 |
+
|
39 |
+
|
40 |
+
def transfer_verification(model_state_dict, partial_state_dict, modules):
|
41 |
+
"""Verify tuples (key, shape) for input model modules match specified modules.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
model_state_dict (OrderedDict): the initial model state_dict
|
45 |
+
partial_state_dict (OrderedDict): the trained model state_dict
|
46 |
+
modules (list): specified module list for transfer
|
47 |
+
|
48 |
+
Return:
|
49 |
+
(boolean): allow transfer
|
50 |
+
|
51 |
+
"""
|
52 |
+
modules_model = []
|
53 |
+
partial_modules = []
|
54 |
+
|
55 |
+
for key_p, value_p in partial_state_dict.items():
|
56 |
+
if any(key_p.startswith(m) for m in modules):
|
57 |
+
partial_modules += [(key_p, value_p.shape)]
|
58 |
+
|
59 |
+
for key_m, value_m in model_state_dict.items():
|
60 |
+
if any(key_m.startswith(m) for m in modules):
|
61 |
+
modules_model += [(key_m, value_m.shape)]
|
62 |
+
|
63 |
+
len_match = len(modules_model) == len(partial_modules)
|
64 |
+
|
65 |
+
module_match = sorted(modules_model, key=lambda x: (x[0], x[1])) == sorted(
|
66 |
+
partial_modules, key=lambda x: (x[0], x[1])
|
67 |
+
)
|
68 |
+
|
69 |
+
return len_match and module_match
|
70 |
+
|
71 |
+
|
72 |
+
def get_partial_state_dict(model_state_dict, modules):
|
73 |
+
"""Create state_dict with specified modules matching input model modules.
|
74 |
+
|
75 |
+
Note that get_partial_lm_state_dict is used if an LM is specified.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
model_state_dict (OrderedDict): trained model state_dict
|
79 |
+
modules (list): specified module list for transfer
|
80 |
+
|
81 |
+
Return:
|
82 |
+
new_state_dict (OrderedDict): the updated state_dict
|
83 |
+
|
84 |
+
"""
|
85 |
+
new_state_dict = OrderedDict()
|
86 |
+
|
87 |
+
for key, value in model_state_dict.items():
|
88 |
+
if any(key.startswith(m) for m in modules):
|
89 |
+
new_state_dict[key] = value
|
90 |
+
|
91 |
+
return new_state_dict
|
92 |
+
|
93 |
+
|
94 |
+
def get_lm_state_dict(lm_state_dict):
|
95 |
+
"""Create compatible ASR decoder state dict from LM state dict.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
lm_state_dict (OrderedDict): pre-trained LM state_dict
|
99 |
+
|
100 |
+
Return:
|
101 |
+
new_state_dict (OrderedDict): LM state_dict with updated keys
|
102 |
+
|
103 |
+
"""
|
104 |
+
new_state_dict = OrderedDict()
|
105 |
+
|
106 |
+
for key, value in list(lm_state_dict.items()):
|
107 |
+
if key == "predictor.embed.weight":
|
108 |
+
new_state_dict["dec.embed.weight"] = value
|
109 |
+
elif key.startswith("predictor.rnn."):
|
110 |
+
_split = key.split(".")
|
111 |
+
|
112 |
+
new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0"
|
113 |
+
new_state_dict[new_key] = value
|
114 |
+
|
115 |
+
return new_state_dict
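# --- Illustrative key mapping (hypothetical RNN layer index) ---
#   "predictor.embed.weight"     -> "dec.embed.weight"
#   "predictor.rnn.0.weight_ih"  -> "dec.decoder.0.weight_ih_l0"
# i.e. the pre-trained LM embedding/RNN weights are renamed so they can be
# loaded into the attention decoder of the ASR model.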
|
116 |
+
|
117 |
+
|
118 |
+
def filter_modules(model_state_dict, modules):
|
119 |
+
"""Filter non-matched modules in module_state_dict.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
model_state_dict (OrderedDict): trained model state_dict
|
123 |
+
modules (list): specified module list for transfer
|
124 |
+
|
125 |
+
Return:
|
126 |
+
new_mods (list): the updated module list
|
127 |
+
|
128 |
+
"""
|
129 |
+
new_mods = []
|
130 |
+
incorrect_mods = []
|
131 |
+
|
132 |
+
mods_model = list(model_state_dict.keys())
|
133 |
+
for mod in modules:
|
134 |
+
if any(key.startswith(mod) for key in mods_model):
|
135 |
+
new_mods += [mod]
|
136 |
+
else:
|
137 |
+
incorrect_mods += [mod]
|
138 |
+
|
139 |
+
if incorrect_mods:
|
140 |
+
logging.warning(
|
141 |
+
"module(s) %s don't match or (partially match) "
|
142 |
+
"available modules in model.",
|
143 |
+
incorrect_mods,
|
144 |
+
)
|
145 |
+
logging.warning("for information, the existing modules in model are:")
|
146 |
+
logging.warning("%s", mods_model)
|
147 |
+
|
148 |
+
return new_mods
|
149 |
+
|
150 |
+
|
151 |
+
def load_trained_model(model_path, training=True):
|
152 |
+
"""Load the trained model for recognition.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
model_path (str): Path to model.***.best
|
156 |
+
|
157 |
+
"""
|
158 |
+
idim, odim, train_args = get_model_conf(
|
159 |
+
model_path, os.path.join(os.path.dirname(model_path), "model.json")
|
160 |
+
)
|
161 |
+
|
162 |
+
logging.warning("reading model parameters from " + model_path)
|
163 |
+
|
164 |
+
if hasattr(train_args, "model_module"):
|
165 |
+
model_module = train_args.model_module
|
166 |
+
else:
|
167 |
+
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
|
168 |
+
# CTC Loss is not needed, default to builtin to prevent import errors
|
169 |
+
if hasattr(train_args, "ctc_type"):
|
170 |
+
train_args.ctc_type = "builtin"
|
171 |
+
|
172 |
+
model_class = dynamic_import(model_module)
|
173 |
+
|
174 |
+
if "transducer" in model_module:
|
175 |
+
model = model_class(idim, odim, train_args, training=training)
|
176 |
+
custom_torch_load(model_path, model, training=training)
|
177 |
+
else:
|
178 |
+
model = model_class(idim, odim, train_args)
|
179 |
+
torch_load(model_path, model)
|
180 |
+
|
181 |
+
return model, train_args
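# --- Illustrative usage sketch (the model path is hypothetical) ---
#   model, train_args = load_trained_model("exp/train/results/model.acc.best")
#   model.eval()
# The companion model.json is located automatically in the same directory as
# the snapshot via get_model_conf.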
|
182 |
+
|
183 |
+
|
184 |
+
def get_trained_model_state_dict(model_path):
|
185 |
+
"""Extract the trained model state dict for pre-initialization.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
model_path (str): Path to model.***.best
|
189 |
+
|
190 |
+
Return:
|
191 |
+
model.state_dict() (OrderedDict): the loaded model state_dict
|
192 |
+
(bool): Boolean defining whether the model is an LM
|
193 |
+
|
194 |
+
"""
|
195 |
+
conf_path = os.path.join(os.path.dirname(model_path), "model.json")
|
196 |
+
if "rnnlm" in model_path:
|
197 |
+
logging.warning("reading model parameters from %s", model_path)
|
198 |
+
|
199 |
+
return get_lm_state_dict(torch.load(model_path))
|
200 |
+
|
201 |
+
idim, odim, args = get_model_conf(model_path, conf_path)
|
202 |
+
|
203 |
+
logging.warning("reading model parameters from " + model_path)
|
204 |
+
|
205 |
+
if hasattr(args, "model_module"):
|
206 |
+
model_module = args.model_module
|
207 |
+
else:
|
208 |
+
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
|
209 |
+
|
210 |
+
model_class = dynamic_import(model_module)
|
211 |
+
model = model_class(idim, odim, args)
|
212 |
+
torch_load(model_path, model)
|
213 |
+
assert (
|
214 |
+
isinstance(model, MTInterface)
|
215 |
+
or isinstance(model, ASRInterface)
|
216 |
+
or isinstance(model, TTSInterface)
|
217 |
+
)
|
218 |
+
|
219 |
+
return model.state_dict()
|
220 |
+
|
221 |
+
|
222 |
+
def load_trained_modules(idim, odim, args, interface=ASRInterface):
|
223 |
+
"""Load model encoder or/and decoder modules with ESPNET pre-trained model(s).
|
224 |
+
|
225 |
+
Args:
|
226 |
+
idim (int): initial input dimension.
|
227 |
+
odim (int): initial output dimension.
|
228 |
+
args (Namespace): The initial model arguments.
|
229 |
+
interface (Interface): ASRInterface or STInterface or TTSInterface.
|
230 |
+
|
231 |
+
Return:
|
232 |
+
model (torch.nn.Module): The model with pretrained modules.
|
233 |
+
|
234 |
+
"""
|
235 |
+
|
236 |
+
def print_new_keys(state_dict, modules, model_path):
|
237 |
+
logging.warning("loading %s from model: %s", modules, model_path)
|
238 |
+
|
239 |
+
for k in state_dict.keys():
|
240 |
+
logging.warning("override %s" % k)
|
241 |
+
|
242 |
+
enc_model_path = args.enc_init
|
243 |
+
dec_model_path = args.dec_init
|
244 |
+
enc_modules = args.enc_init_mods
|
245 |
+
dec_modules = args.dec_init_mods
|
246 |
+
|
247 |
+
model_class = dynamic_import(args.model_module)
|
248 |
+
main_model = model_class(idim, odim, args)
|
249 |
+
assert isinstance(main_model, interface)
|
250 |
+
|
251 |
+
main_state_dict = main_model.state_dict()
|
252 |
+
|
253 |
+
logging.warning("model(s) found for pre-initialization")
|
254 |
+
for model_path, modules in [
|
255 |
+
(enc_model_path, enc_modules),
|
256 |
+
(dec_model_path, dec_modules),
|
257 |
+
]:
|
258 |
+
if model_path is not None:
|
259 |
+
if os.path.isfile(model_path):
|
260 |
+
model_state_dict = get_trained_model_state_dict(model_path)
|
261 |
+
|
262 |
+
modules = filter_modules(model_state_dict, modules)
|
263 |
+
|
264 |
+
partial_state_dict = get_partial_state_dict(model_state_dict, modules)
|
265 |
+
|
266 |
+
if partial_state_dict:
|
267 |
+
if transfer_verification(
|
268 |
+
main_state_dict, partial_state_dict, modules
|
269 |
+
):
|
270 |
+
print_new_keys(partial_state_dict, modules, model_path)
|
271 |
+
main_state_dict.update(partial_state_dict)
|
272 |
+
else:
|
273 |
+
logging.warning(
|
274 |
+
f"modules {modules} in model {model_path} "
|
275 |
+
f"don't match your training config",
|
276 |
+
)
|
277 |
+
else:
|
278 |
+
logging.warning("model was not found : %s", model_path)
|
279 |
+
|
280 |
+
main_model.load_state_dict(main_state_dict)
|
281 |
+
|
282 |
+
return main_model
|
espnet/asr/pytorch_backend/asr_mix.py
ADDED
@@ -0,0 +1,654 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
This script is used for multi-speaker speech recognition.
|
5 |
+
|
6 |
+
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
7 |
+
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
8 |
+
"""
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
|
13 |
+
# chainer related
|
14 |
+
from chainer import training
|
15 |
+
from chainer.training import extensions
|
16 |
+
from itertools import zip_longest as zip_longest
|
17 |
+
import numpy as np
|
18 |
+
from tensorboardX import SummaryWriter
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from espnet.asr.asr_mix_utils import add_results_to_json
|
22 |
+
from espnet.asr.asr_utils import adadelta_eps_decay
|
23 |
+
|
24 |
+
from espnet.asr.asr_utils import CompareValueTrigger
|
25 |
+
from espnet.asr.asr_utils import get_model_conf
|
26 |
+
from espnet.asr.asr_utils import restore_snapshot
|
27 |
+
from espnet.asr.asr_utils import snapshot_object
|
28 |
+
from espnet.asr.asr_utils import torch_load
|
29 |
+
from espnet.asr.asr_utils import torch_resume
|
30 |
+
from espnet.asr.asr_utils import torch_snapshot
|
31 |
+
from espnet.asr.pytorch_backend.asr import CustomEvaluator
|
32 |
+
from espnet.asr.pytorch_backend.asr import CustomUpdater
|
33 |
+
from espnet.asr.pytorch_backend.asr import load_trained_model
|
34 |
+
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
|
35 |
+
from espnet.nets.asr_interface import ASRInterface
|
36 |
+
from espnet.nets.pytorch_backend.e2e_asr_mix import pad_list
|
37 |
+
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
|
38 |
+
from espnet.utils.dataset import ChainerDataLoader
|
39 |
+
from espnet.utils.dataset import TransformDataset
|
40 |
+
from espnet.utils.deterministic_utils import set_deterministic_pytorch
|
41 |
+
from espnet.utils.dynamic_import import dynamic_import
|
42 |
+
from espnet.utils.io_utils import LoadInputsAndTargets
|
43 |
+
from espnet.utils.training.batchfy import make_batchset
|
44 |
+
from espnet.utils.training.iterators import ShufflingEnabler
|
45 |
+
from espnet.utils.training.tensorboard_logger import TensorboardLogger
|
46 |
+
from espnet.utils.training.train_utils import check_early_stop
|
47 |
+
from espnet.utils.training.train_utils import set_early_stop
|
48 |
+
|
49 |
+
import matplotlib
|
50 |
+
|
51 |
+
matplotlib.use("Agg")
|
52 |
+
|
53 |
+
|
54 |
+
class CustomConverter(object):
|
55 |
+
"""Custom batch converter for Pytorch.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
subsampling_factor (int): The subsampling factor.
|
59 |
+
dtype (torch.dtype): Data type to convert.
|
60 |
+
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self, subsampling_factor=1, dtype=torch.float32, num_spkrs=2):
|
64 |
+
"""Initialize the converter."""
|
65 |
+
self.subsampling_factor = subsampling_factor
|
66 |
+
self.ignore_id = -1
|
67 |
+
self.dtype = dtype
|
68 |
+
self.num_spkrs = num_spkrs
|
69 |
+
|
70 |
+
def __call__(self, batch, device=torch.device("cpu")):
|
71 |
+
"""Transform a batch and send it to a device.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
batch (list(tuple(str, dict[str, dict[str, Any]]))): The batch to transform.
|
75 |
+
device (torch.device): The device to send to.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
tuple(torch.Tensor, torch.Tensor, torch.Tensor): Transformed batch.
|
79 |
+
|
80 |
+
"""
|
81 |
+
# batch should be located in list
|
82 |
+
assert len(batch) == 1
|
83 |
+
xs, ys = batch[0][0], batch[0][-self.num_spkrs :]
|
84 |
+
|
85 |
+
# perform subsampling
|
86 |
+
if self.subsampling_factor > 1:
|
87 |
+
xs = [x[:: self.subsampling_factor, :] for x in xs]
|
88 |
+
|
89 |
+
# get batch of lengths of input sequences
|
90 |
+
ilens = np.array([x.shape[0] for x in xs])
|
91 |
+
|
92 |
+
# perform padding and convert to tensor
|
93 |
+
# currently only real numbers are supported directly; complex inputs are split below
|
94 |
+
if xs[0].dtype.kind == "c":
|
95 |
+
xs_pad_real = pad_list(
|
96 |
+
[torch.from_numpy(x.real).float() for x in xs], 0
|
97 |
+
).to(device, dtype=self.dtype)
|
98 |
+
xs_pad_imag = pad_list(
|
99 |
+
[torch.from_numpy(x.imag).float() for x in xs], 0
|
100 |
+
).to(device, dtype=self.dtype)
|
101 |
+
# Note(kamo):
|
102 |
+
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
|
103 |
+
# Don't create ComplexTensor and give it to E2E here
|
104 |
+
# because torch.nn.DataParallel can't handle it.
|
105 |
+
xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
|
106 |
+
else:
|
107 |
+
xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
|
108 |
+
device, dtype=self.dtype
|
109 |
+
)
|
110 |
+
|
111 |
+
ilens = torch.from_numpy(ilens).to(device)
|
112 |
+
if not isinstance(ys[0], np.ndarray):
|
113 |
+
ys_pad = []
|
114 |
+
for i in range(len(ys)): # speakers
|
115 |
+
ys_pad += [torch.from_numpy(y).long() for y in ys[i]]
|
116 |
+
ys_pad = pad_list(ys_pad, self.ignore_id)
|
117 |
+
ys_pad = (
|
118 |
+
ys_pad.view(self.num_spkrs, -1, ys_pad.size(1))
|
119 |
+
.transpose(0, 1)
|
120 |
+
.to(device)
|
121 |
+
) # (B, num_spkrs, Tmax)
|
122 |
+
else:
|
123 |
+
ys_pad = pad_list(
|
124 |
+
[torch.from_numpy(y).long() for y in ys], self.ignore_id
|
125 |
+
).to(device)
|
126 |
+
|
127 |
+
return xs_pad, ilens, ys_pad
|
128 |
+
|
129 |
+
|
130 |
+
def train(args):
|
131 |
+
"""Train with the given args.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
args (namespace): The program arguments.
|
135 |
+
|
136 |
+
"""
|
137 |
+
set_deterministic_pytorch(args)
|
138 |
+
|
139 |
+
# check cuda availability
|
140 |
+
if not torch.cuda.is_available():
|
141 |
+
logging.warning("cuda is not available")
|
142 |
+
|
143 |
+
# get input and output dimension info
|
144 |
+
with open(args.valid_json, "rb") as f:
|
145 |
+
valid_json = json.load(f)["utts"]
|
146 |
+
utts = list(valid_json.keys())
|
147 |
+
idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
|
148 |
+
odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
|
149 |
+
logging.info("#input dims : " + str(idim))
|
150 |
+
logging.info("#output dims: " + str(odim))
|
151 |
+
|
152 |
+
# specify attention, CTC, hybrid mode
|
153 |
+
if args.mtlalpha == 1.0:
|
154 |
+
mtl_mode = "ctc"
|
155 |
+
logging.info("Pure CTC mode")
|
156 |
+
elif args.mtlalpha == 0.0:
|
157 |
+
mtl_mode = "att"
|
158 |
+
logging.info("Pure attention mode")
|
159 |
+
else:
|
160 |
+
mtl_mode = "mtl"
|
161 |
+
logging.info("Multitask learning mode")
|
162 |
+
|
163 |
+
# specify model architecture
|
164 |
+
model_class = dynamic_import(args.model_module)
|
165 |
+
model = model_class(idim, odim, args)
|
166 |
+
assert isinstance(model, ASRInterface)
|
167 |
+
subsampling_factor = model.subsample[0]
|
168 |
+
|
169 |
+
if args.rnnlm is not None:
|
170 |
+
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
|
171 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
172 |
+
lm_pytorch.RNNLM(
|
173 |
+
len(args.char_list),
|
174 |
+
rnnlm_args.layer,
|
175 |
+
rnnlm_args.unit,
|
176 |
+
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
|
177 |
+
)
|
178 |
+
)
|
179 |
+
torch_load(args.rnnlm, rnnlm)
|
180 |
+
model.rnnlm = rnnlm
|
181 |
+
|
182 |
+
# write model config
|
183 |
+
if not os.path.exists(args.outdir):
|
184 |
+
os.makedirs(args.outdir)
|
185 |
+
model_conf = args.outdir + "/model.json"
|
186 |
+
with open(model_conf, "wb") as f:
|
187 |
+
logging.info("writing a model config file to " + model_conf)
|
188 |
+
f.write(
|
189 |
+
json.dumps(
|
190 |
+
(idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
|
191 |
+
).encode("utf_8")
|
192 |
+
)
|
193 |
+
for key in sorted(vars(args).keys()):
|
194 |
+
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
|
195 |
+
|
196 |
+
reporter = model.reporter
|
197 |
+
|
198 |
+
# check the use of multi-gpu
|
199 |
+
if args.ngpu > 1:
|
200 |
+
if args.batch_size != 0:
|
201 |
+
logging.warning(
|
202 |
+
"batch size is automatically increased (%d -> %d)"
|
203 |
+
% (args.batch_size, args.batch_size * args.ngpu)
|
204 |
+
)
|
205 |
+
args.batch_size *= args.ngpu
|
206 |
+
|
207 |
+
# set torch device
|
208 |
+
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
|
209 |
+
if args.train_dtype in ("float16", "float32", "float64"):
|
210 |
+
dtype = getattr(torch, args.train_dtype)
|
211 |
+
else:
|
212 |
+
dtype = torch.float32
|
213 |
+
model = model.to(device=device, dtype=dtype)
|
214 |
+
|
215 |
+
logging.warning(
|
216 |
+
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
|
217 |
+
sum(p.numel() for p in model.parameters()),
|
218 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
219 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad)
|
220 |
+
* 100.0
|
221 |
+
/ sum(p.numel() for p in model.parameters()),
|
222 |
+
)
|
223 |
+
)
|
224 |
+
|
225 |
+
# Setup an optimizer
|
226 |
+
if args.opt == "adadelta":
|
227 |
+
optimizer = torch.optim.Adadelta(
|
228 |
+
model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
|
229 |
+
)
|
230 |
+
elif args.opt == "adam":
|
231 |
+
optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
|
232 |
+
elif args.opt == "noam":
|
233 |
+
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
|
234 |
+
|
235 |
+
optimizer = get_std_opt(
|
236 |
+
model.parameters(),
|
237 |
+
args.adim,
|
238 |
+
args.transformer_warmup_steps,
|
239 |
+
args.transformer_lr,
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
raise NotImplementedError("unknown optimizer: " + args.opt)
|
243 |
+
|
244 |
+
# setup apex.amp
|
245 |
+
if args.train_dtype in ("O0", "O1", "O2", "O3"):
|
246 |
+
try:
|
247 |
+
from apex import amp
|
248 |
+
except ImportError as e:
|
249 |
+
logging.error(
|
250 |
+
f"You need to install apex for --train-dtype {args.train_dtype}. "
|
251 |
+
"See https://github.com/NVIDIA/apex#linux"
|
252 |
+
)
|
253 |
+
raise e
|
254 |
+
if args.opt == "noam":
|
255 |
+
model, optimizer.optimizer = amp.initialize(
|
256 |
+
model, optimizer.optimizer, opt_level=args.train_dtype
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
model, optimizer = amp.initialize(
|
260 |
+
model, optimizer, opt_level=args.train_dtype
|
261 |
+
)
|
262 |
+
use_apex = True
|
263 |
+
else:
|
264 |
+
use_apex = False
|
265 |
+
|
266 |
+
# FIXME: TOO DIRTY HACK
|
267 |
+
setattr(optimizer, "target", reporter)
|
268 |
+
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
|
269 |
+
|
270 |
+
# Setup a converter
|
271 |
+
converter = CustomConverter(
|
272 |
+
subsampling_factor=subsampling_factor, dtype=dtype, num_spkrs=args.num_spkrs
|
273 |
+
)
|
274 |
+
|
275 |
+
# read json data
|
276 |
+
with open(args.train_json, "rb") as f:
|
277 |
+
train_json = json.load(f)["utts"]
|
278 |
+
with open(args.valid_json, "rb") as f:
|
279 |
+
valid_json = json.load(f)["utts"]
|
280 |
+
|
281 |
+
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
|
282 |
+
# make minibatch list (variable length)
|
283 |
+
train = make_batchset(
|
284 |
+
train_json,
|
285 |
+
args.batch_size,
|
286 |
+
args.maxlen_in,
|
287 |
+
args.maxlen_out,
|
288 |
+
args.minibatches,
|
289 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
290 |
+
shortest_first=use_sortagrad,
|
291 |
+
count=args.batch_count,
|
292 |
+
batch_bins=args.batch_bins,
|
293 |
+
batch_frames_in=args.batch_frames_in,
|
294 |
+
batch_frames_out=args.batch_frames_out,
|
295 |
+
batch_frames_inout=args.batch_frames_inout,
|
296 |
+
iaxis=0,
|
297 |
+
oaxis=-1,
|
298 |
+
)
|
299 |
+
valid = make_batchset(
|
300 |
+
valid_json,
|
301 |
+
args.batch_size,
|
302 |
+
args.maxlen_in,
|
303 |
+
args.maxlen_out,
|
304 |
+
args.minibatches,
|
305 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
306 |
+
count=args.batch_count,
|
307 |
+
batch_bins=args.batch_bins,
|
308 |
+
batch_frames_in=args.batch_frames_in,
|
309 |
+
batch_frames_out=args.batch_frames_out,
|
310 |
+
batch_frames_inout=args.batch_frames_inout,
|
311 |
+
iaxis=0,
|
312 |
+
oaxis=-1,
|
313 |
+
)
|
314 |
+
|
315 |
+
load_tr = LoadInputsAndTargets(
|
316 |
+
mode="asr",
|
317 |
+
load_output=True,
|
318 |
+
preprocess_conf=args.preprocess_conf,
|
319 |
+
preprocess_args={"train": True}, # Switch the mode of preprocessing
|
320 |
+
)
|
321 |
+
load_cv = LoadInputsAndTargets(
|
322 |
+
mode="asr",
|
323 |
+
load_output=True,
|
324 |
+
preprocess_conf=args.preprocess_conf,
|
325 |
+
preprocess_args={"train": False}, # Switch the mode of preprocessing
|
326 |
+
)
|
327 |
+
# hack to make the batchsize argument 1
|
328 |
+
# the actual batch size is included in a list
|
329 |
+
# the default collate function converts numpy arrays to pytorch tensors
|
330 |
+
# we use an empty collate function instead, which returns a list
|
331 |
+
train_iter = {
|
332 |
+
"main": ChainerDataLoader(
|
333 |
+
dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
|
334 |
+
batch_size=1,
|
335 |
+
num_workers=args.n_iter_processes,
|
336 |
+
shuffle=True,
|
337 |
+
collate_fn=lambda x: x[0],
|
338 |
+
)
|
339 |
+
}
|
340 |
+
valid_iter = {
|
341 |
+
"main": ChainerDataLoader(
|
342 |
+
dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
|
343 |
+
batch_size=1,
|
344 |
+
shuffle=False,
|
345 |
+
collate_fn=lambda x: x[0],
|
346 |
+
num_workers=args.n_iter_processes,
|
347 |
+
)
|
348 |
+
}
|
349 |
+
|
350 |
+
# Set up a trainer
|
351 |
+
updater = CustomUpdater(
|
352 |
+
model,
|
353 |
+
args.grad_clip,
|
354 |
+
train_iter,
|
355 |
+
optimizer,
|
356 |
+
device,
|
357 |
+
args.ngpu,
|
358 |
+
args.grad_noise,
|
359 |
+
args.accum_grad,
|
360 |
+
use_apex=use_apex,
|
361 |
+
)
|
362 |
+
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
|
363 |
+
|
364 |
+
if use_sortagrad:
|
365 |
+
trainer.extend(
|
366 |
+
ShufflingEnabler([train_iter]),
|
367 |
+
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
|
368 |
+
)
|
369 |
+
|
370 |
+
# Resume from a snapshot
|
371 |
+
if args.resume:
|
372 |
+
logging.info("resumed from %s" % args.resume)
|
373 |
+
torch_resume(args.resume, trainer)
|
374 |
+
|
375 |
+
# Evaluate the model with the test dataset for each epoch
|
376 |
+
trainer.extend(CustomEvaluator(model, valid_iter, reporter, device, args.ngpu))
|
377 |
+
|
378 |
+
# Save attention weight each epoch
|
379 |
+
if args.num_save_attention > 0 and args.mtlalpha != 1.0:
|
380 |
+
data = sorted(
|
381 |
+
list(valid_json.items())[: args.num_save_attention],
|
382 |
+
key=lambda x: int(x[1]["input"][0]["shape"][1]),
|
383 |
+
reverse=True,
|
384 |
+
)
|
385 |
+
if hasattr(model, "module"):
|
386 |
+
att_vis_fn = model.module.calculate_all_attentions
|
387 |
+
plot_class = model.module.attention_plot_class
|
388 |
+
else:
|
389 |
+
att_vis_fn = model.calculate_all_attentions
|
390 |
+
plot_class = model.attention_plot_class
|
391 |
+
att_reporter = plot_class(
|
392 |
+
att_vis_fn,
|
393 |
+
data,
|
394 |
+
args.outdir + "/att_ws",
|
395 |
+
converter=converter,
|
396 |
+
transform=load_cv,
|
397 |
+
device=device,
|
398 |
+
)
|
399 |
+
trainer.extend(att_reporter, trigger=(1, "epoch"))
|
400 |
+
else:
|
401 |
+
att_reporter = None
|
402 |
+
|
403 |
+
# Make a plot for training and validation values
|
404 |
+
trainer.extend(
|
405 |
+
extensions.PlotReport(
|
406 |
+
[
|
407 |
+
"main/loss",
|
408 |
+
"validation/main/loss",
|
409 |
+
"main/loss_ctc",
|
410 |
+
"validation/main/loss_ctc",
|
411 |
+
"main/loss_att",
|
412 |
+
"validation/main/loss_att",
|
413 |
+
],
|
414 |
+
"epoch",
|
415 |
+
file_name="loss.png",
|
416 |
+
)
|
417 |
+
)
|
418 |
+
trainer.extend(
|
419 |
+
extensions.PlotReport(
|
420 |
+
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
|
421 |
+
)
|
422 |
+
)
|
423 |
+
trainer.extend(
|
424 |
+
extensions.PlotReport(
|
425 |
+
["main/cer_ctc", "validation/main/cer_ctc"], "epoch", file_name="cer.png"
|
426 |
+
)
|
427 |
+
)
|
428 |
+
|
429 |
+
# Save best models
|
430 |
+
trainer.extend(
|
431 |
+
snapshot_object(model, "model.loss.best"),
|
432 |
+
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
|
433 |
+
)
|
434 |
+
if mtl_mode != "ctc":
|
435 |
+
trainer.extend(
|
436 |
+
snapshot_object(model, "model.acc.best"),
|
437 |
+
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
|
438 |
+
)
|
439 |
+
|
440 |
+
# save snapshot which contains model and optimizer states
|
441 |
+
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
|
442 |
+
|
443 |
+
# epsilon decay in the optimizer
|
444 |
+
if args.opt == "adadelta":
|
445 |
+
if args.criterion == "acc" and mtl_mode != "ctc":
|
446 |
+
trainer.extend(
|
447 |
+
restore_snapshot(
|
448 |
+
model, args.outdir + "/model.acc.best", load_fn=torch_load
|
449 |
+
),
|
450 |
+
trigger=CompareValueTrigger(
|
451 |
+
"validation/main/acc",
|
452 |
+
lambda best_value, current_value: best_value > current_value,
|
453 |
+
),
|
454 |
+
)
|
455 |
+
trainer.extend(
|
456 |
+
adadelta_eps_decay(args.eps_decay),
|
457 |
+
trigger=CompareValueTrigger(
|
458 |
+
"validation/main/acc",
|
459 |
+
lambda best_value, current_value: best_value > current_value,
|
460 |
+
),
|
461 |
+
)
|
462 |
+
elif args.criterion == "loss":
|
463 |
+
trainer.extend(
|
464 |
+
restore_snapshot(
|
465 |
+
model, args.outdir + "/model.loss.best", load_fn=torch_load
|
466 |
+
),
|
467 |
+
trigger=CompareValueTrigger(
|
468 |
+
"validation/main/loss",
|
469 |
+
lambda best_value, current_value: best_value < current_value,
|
470 |
+
),
|
471 |
+
)
|
472 |
+
trainer.extend(
|
473 |
+
adadelta_eps_decay(args.eps_decay),
|
474 |
+
trigger=CompareValueTrigger(
|
475 |
+
"validation/main/loss",
|
476 |
+
lambda best_value, current_value: best_value < current_value,
|
477 |
+
),
|
478 |
+
)
|
479 |
+
|
480 |
+
# Write a log of evaluation statistics for each epoch
|
481 |
+
trainer.extend(
|
482 |
+
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
|
483 |
+
)
|
484 |
+
report_keys = [
|
485 |
+
"epoch",
|
486 |
+
"iteration",
|
487 |
+
"main/loss",
|
488 |
+
"main/loss_ctc",
|
489 |
+
"main/loss_att",
|
490 |
+
"validation/main/loss",
|
491 |
+
"validation/main/loss_ctc",
|
492 |
+
"validation/main/loss_att",
|
493 |
+
"main/acc",
|
494 |
+
"validation/main/acc",
|
495 |
+
"main/cer_ctc",
|
496 |
+
"validation/main/cer_ctc",
|
497 |
+
"elapsed_time",
|
498 |
+
]
|
499 |
+
if args.opt == "adadelta":
|
500 |
+
trainer.extend(
|
501 |
+
extensions.observe_value(
|
502 |
+
"eps",
|
503 |
+
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
|
504 |
+
"eps"
|
505 |
+
],
|
506 |
+
),
|
507 |
+
trigger=(args.report_interval_iters, "iteration"),
|
508 |
+
)
|
509 |
+
report_keys.append("eps")
|
510 |
+
if args.report_cer:
|
511 |
+
report_keys.append("validation/main/cer")
|
512 |
+
if args.report_wer:
|
513 |
+
report_keys.append("validation/main/wer")
|
514 |
+
trainer.extend(
|
515 |
+
extensions.PrintReport(report_keys),
|
516 |
+
trigger=(args.report_interval_iters, "iteration"),
|
517 |
+
)
|
518 |
+
|
519 |
+
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
|
520 |
+
set_early_stop(trainer, args)
|
521 |
+
|
522 |
+
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
|
523 |
+
trainer.extend(
|
524 |
+
TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
|
525 |
+
trigger=(args.report_interval_iters, "iteration"),
|
526 |
+
)
|
527 |
+
# Run the training
|
528 |
+
trainer.run()
|
529 |
+
check_early_stop(trainer, args.epochs)
|
530 |
+
|
531 |
+
|
532 |
+
def recog(args):
|
533 |
+
"""Decode with the given args.
|
534 |
+
|
535 |
+
Args:
|
536 |
+
args (namespace): The program arguments.
|
537 |
+
|
538 |
+
"""
|
539 |
+
set_deterministic_pytorch(args)
|
540 |
+
model, train_args = load_trained_model(args.model)
|
541 |
+
assert isinstance(model, ASRInterface)
|
542 |
+
model.recog_args = args
|
543 |
+
|
544 |
+
# read rnnlm
|
545 |
+
if args.rnnlm:
|
546 |
+
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
|
547 |
+
if getattr(rnnlm_args, "model_module", "default") != "default":
|
548 |
+
raise ValueError(
|
549 |
+
"use '--api v2' option to decode with non-default language model"
|
550 |
+
)
|
551 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
552 |
+
lm_pytorch.RNNLM(
|
553 |
+
len(train_args.char_list),
|
554 |
+
rnnlm_args.layer,
|
555 |
+
rnnlm_args.unit,
|
556 |
+
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
|
557 |
+
)
|
558 |
+
)
|
559 |
+
torch_load(args.rnnlm, rnnlm)
|
560 |
+
rnnlm.eval()
|
561 |
+
else:
|
562 |
+
rnnlm = None
|
563 |
+
|
564 |
+
if args.word_rnnlm:
|
565 |
+
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
|
566 |
+
word_dict = rnnlm_args.char_list_dict
|
567 |
+
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
|
568 |
+
word_rnnlm = lm_pytorch.ClassifierWithState(
|
569 |
+
lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
|
570 |
+
)
|
571 |
+
torch_load(args.word_rnnlm, word_rnnlm)
|
572 |
+
word_rnnlm.eval()
|
573 |
+
|
574 |
+
if rnnlm is not None:
|
575 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
576 |
+
extlm_pytorch.MultiLevelLM(
|
577 |
+
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
|
578 |
+
)
|
579 |
+
)
|
580 |
+
else:
|
581 |
+
rnnlm = lm_pytorch.ClassifierWithState(
|
582 |
+
extlm_pytorch.LookAheadWordLM(
|
583 |
+
word_rnnlm.predictor, word_dict, char_dict
|
584 |
+
)
|
585 |
+
)
|
586 |
+
|
587 |
+
# gpu
|
588 |
+
if args.ngpu == 1:
|
589 |
+
gpu_id = list(range(args.ngpu))
|
590 |
+
logging.info("gpu id: " + str(gpu_id))
|
591 |
+
model.cuda()
|
592 |
+
if rnnlm:
|
593 |
+
rnnlm.cuda()
|
594 |
+
|
595 |
+
# read json data
|
596 |
+
with open(args.recog_json, "rb") as f:
|
597 |
+
js = json.load(f)["utts"]
|
598 |
+
new_js = {}
|
599 |
+
|
600 |
+
load_inputs_and_targets = LoadInputsAndTargets(
|
601 |
+
mode="asr",
|
602 |
+
load_output=False,
|
603 |
+
sort_in_input_length=False,
|
604 |
+
preprocess_conf=train_args.preprocess_conf
|
605 |
+
if args.preprocess_conf is None
|
606 |
+
else args.preprocess_conf,
|
607 |
+
preprocess_args={"train": False},
|
608 |
+
)
|
609 |
+
|
610 |
+
if args.batchsize == 0:
|
611 |
+
with torch.no_grad():
|
612 |
+
for idx, name in enumerate(js.keys(), 1):
|
613 |
+
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
|
614 |
+
batch = [(name, js[name])]
|
615 |
+
feat = load_inputs_and_targets(batch)[0][0]
|
616 |
+
nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
|
617 |
+
new_js[name] = add_results_to_json(
|
618 |
+
js[name], nbest_hyps, train_args.char_list
|
619 |
+
)
|
620 |
+
|
621 |
+
else:
|
622 |
+
|
623 |
+
def grouper(n, iterable, fillvalue=None):
|
624 |
+
kargs = [iter(iterable)] * n
|
625 |
+
return zip_longest(*kargs, fillvalue=fillvalue)
|
626 |
+
|
627 |
+
# sort data if batchsize > 1
|
628 |
+
keys = list(js.keys())
|
629 |
+
if args.batchsize > 1:
|
630 |
+
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
|
631 |
+
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
|
632 |
+
keys = [keys[i] for i in sorted_index]
|
633 |
+
|
634 |
+
with torch.no_grad():
|
635 |
+
for names in grouper(args.batchsize, keys, None):
|
636 |
+
names = [name for name in names if name]
|
637 |
+
batch = [(name, js[name]) for name in names]
|
638 |
+
feats = load_inputs_and_targets(batch)[0]
|
639 |
+
nbest_hyps = model.recognize_batch(
|
640 |
+
feats, args, train_args.char_list, rnnlm=rnnlm
|
641 |
+
)
|
642 |
+
|
643 |
+
for i, name in enumerate(names):
|
644 |
+
nbest_hyp = [hyp[i] for hyp in nbest_hyps]
|
645 |
+
new_js[name] = add_results_to_json(
|
646 |
+
js[name], nbest_hyp, train_args.char_list
|
647 |
+
)
|
648 |
+
|
649 |
+
with open(args.result_label, "wb") as f:
|
650 |
+
f.write(
|
651 |
+
json.dumps(
|
652 |
+
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
|
653 |
+
).encode("utf_8")
|
654 |
+
)
|
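To make the batch conversion in asr_mix.py above concrete, the following standalone sketch (not part of this commit) reproduces the padding that CustomConverter applies to a two-speaker toy batch. The local pad_list helper and the dummy shapes are assumptions for illustration only, not the espnet implementation.

# Standalone sketch only: mimics CustomConverter's padding for num_spkrs=2.
# pad_list here is a local stand-in, not espnet.nets.pytorch_backend.e2e_asr_mix.pad_list.
import numpy as np
import torch

def pad_list(xs, pad_value):
    """Right-pad a list of tensors along dim 0 to the longest one."""
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new_full((n_batch, max_len, *xs[0].size()[1:]), pad_value)
    for i, x in enumerate(xs):
        pad[i, : x.size(0)] = x
    return pad

# Toy batch: two utterances of 83-dim features, one transcript per speaker.
xs = [np.random.randn(120, 83).astype(np.float32),
      np.random.randn(90, 83).astype(np.float32)]
ys_spk1 = [np.array([1, 2, 3]), np.array([4, 5])]
ys_spk2 = [np.array([6, 7]), np.array([8, 9, 10])]

xs_pad = pad_list([torch.from_numpy(x) for x in xs], 0)      # (B, Tmax, D)
ilens = torch.tensor([x.shape[0] for x in xs])               # (B,)
ys_all = [torch.from_numpy(y).long() for y in ys_spk1 + ys_spk2]
ys_pad = pad_list(ys_all, -1)                                 # (num_spkrs * B, Lmax)
ys_pad = ys_pad.view(2, -1, ys_pad.size(1)).transpose(0, 1)   # (B, num_spkrs, Lmax)
print(xs_pad.shape, ilens.tolist(), ys_pad.shape)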
espnet/asr/pytorch_backend/recog.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
|
2 |
+
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from espnet.asr.asr_utils import add_results_to_json
|
9 |
+
from espnet.asr.asr_utils import get_model_conf
|
10 |
+
from espnet.asr.asr_utils import torch_load
|
11 |
+
from espnet.asr.pytorch_backend.asr import load_trained_model
|
12 |
+
from espnet.nets.asr_interface import ASRInterface
|
13 |
+
from espnet.nets.batch_beam_search import BatchBeamSearch
|
14 |
+
from espnet.nets.beam_search import BeamSearch
|
15 |
+
from espnet.nets.lm_interface import dynamic_import_lm
|
16 |
+
from espnet.nets.scorer_interface import BatchScorerInterface
|
17 |
+
from espnet.nets.scorers.length_bonus import LengthBonus
|
18 |
+
from espnet.utils.deterministic_utils import set_deterministic_pytorch
|
19 |
+
from espnet.utils.io_utils import LoadInputsAndTargets
|
20 |
+
|
21 |
+
|
22 |
+
def recog_v2(args):
|
23 |
+
"""Decode with custom models that implements ScorerInterface.
|
24 |
+
|
25 |
+
Notes:
|
26 |
+
The previous backend espnet.asr.pytorch_backend.asr.recog
|
27 |
+
only supports E2E and RNNLM
|
28 |
+
|
29 |
+
Args:
|
30 |
+
args (namespace): The program arguments.
|
31 |
+
See py:func:`espnet.bin.asr_recog.get_parser` for details
|
32 |
+
|
33 |
+
"""
|
34 |
+
logging.warning("experimental API for custom LMs is selected by --api v2")
|
35 |
+
if args.batchsize > 1:
|
36 |
+
raise NotImplementedError("multi-utt batch decoding is not implemented")
|
37 |
+
if args.streaming_mode is not None:
|
38 |
+
raise NotImplementedError("streaming mode is not implemented")
|
39 |
+
if args.word_rnnlm:
|
40 |
+
raise NotImplementedError("word LM is not implemented")
|
41 |
+
|
42 |
+
set_deterministic_pytorch(args)
|
43 |
+
model, train_args = load_trained_model(args.model)
|
44 |
+
assert isinstance(model, ASRInterface)
|
45 |
+
model.eval()
|
46 |
+
|
47 |
+
load_inputs_and_targets = LoadInputsAndTargets(
|
48 |
+
mode="asr",
|
49 |
+
load_output=False,
|
50 |
+
sort_in_input_length=False,
|
51 |
+
preprocess_conf=train_args.preprocess_conf
|
52 |
+
if args.preprocess_conf is None
|
53 |
+
else args.preprocess_conf,
|
54 |
+
preprocess_args={"train": False},
|
55 |
+
)
|
56 |
+
|
57 |
+
if args.rnnlm:
|
58 |
+
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
|
59 |
+
# NOTE: for a compatibility with less than 0.5.0 version models
|
60 |
+
lm_model_module = getattr(lm_args, "model_module", "default")
|
61 |
+
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
|
62 |
+
lm = lm_class(len(train_args.char_list), lm_args)
|
63 |
+
torch_load(args.rnnlm, lm)
|
64 |
+
lm.eval()
|
65 |
+
else:
|
66 |
+
lm = None
|
67 |
+
|
68 |
+
if args.ngram_model:
|
69 |
+
from espnet.nets.scorers.ngram import NgramFullScorer
|
70 |
+
from espnet.nets.scorers.ngram import NgramPartScorer
|
71 |
+
|
72 |
+
if args.ngram_scorer == "full":
|
73 |
+
ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
|
74 |
+
else:
|
75 |
+
ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
|
76 |
+
else:
|
77 |
+
ngram = None
|
78 |
+
|
79 |
+
scorers = model.scorers()
|
80 |
+
scorers["lm"] = lm
|
81 |
+
scorers["ngram"] = ngram
|
82 |
+
scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
|
83 |
+
weights = dict(
|
84 |
+
decoder=1.0 - args.ctc_weight,
|
85 |
+
ctc=args.ctc_weight,
|
86 |
+
lm=args.lm_weight,
|
87 |
+
ngram=args.ngram_weight,
|
88 |
+
length_bonus=args.penalty,
|
89 |
+
)
|
90 |
+
beam_search = BeamSearch(
|
91 |
+
beam_size=args.beam_size,
|
92 |
+
vocab_size=len(train_args.char_list),
|
93 |
+
weights=weights,
|
94 |
+
scorers=scorers,
|
95 |
+
sos=model.sos,
|
96 |
+
eos=model.eos,
|
97 |
+
token_list=train_args.char_list,
|
98 |
+
pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
|
99 |
+
)
|
100 |
+
# TODO(karita): make all scorers batchfied
|
101 |
+
if args.batchsize == 1:
|
102 |
+
non_batch = [
|
103 |
+
k
|
104 |
+
for k, v in beam_search.full_scorers.items()
|
105 |
+
if not isinstance(v, BatchScorerInterface)
|
106 |
+
]
|
107 |
+
if len(non_batch) == 0:
|
108 |
+
beam_search.__class__ = BatchBeamSearch
|
109 |
+
logging.info("BatchBeamSearch implementation is selected.")
|
110 |
+
else:
|
111 |
+
logging.warning(
|
112 |
+
f"As non-batch scorers {non_batch} are found, "
|
113 |
+
f"fall back to non-batch implementation."
|
114 |
+
)
|
115 |
+
|
116 |
+
if args.ngpu > 1:
|
117 |
+
raise NotImplementedError("only single GPU decoding is supported")
|
118 |
+
if args.ngpu == 1:
|
119 |
+
device = "cuda"
|
120 |
+
else:
|
121 |
+
device = "cpu"
|
122 |
+
dtype = getattr(torch, args.dtype)
|
123 |
+
logging.info(f"Decoding device={device}, dtype={dtype}")
|
124 |
+
model.to(device=device, dtype=dtype).eval()
|
125 |
+
beam_search.to(device=device, dtype=dtype).eval()
|
126 |
+
|
127 |
+
# read json data
|
128 |
+
with open(args.recog_json, "rb") as f:
|
129 |
+
js = json.load(f)["utts"]
|
130 |
+
new_js = {}
|
131 |
+
with torch.no_grad():
|
132 |
+
for idx, name in enumerate(js.keys(), 1):
|
133 |
+
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
|
134 |
+
batch = [(name, js[name])]
|
135 |
+
feat = load_inputs_and_targets(batch)[0][0]
|
136 |
+
enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype))
|
137 |
+
nbest_hyps = beam_search(
|
138 |
+
x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
|
139 |
+
)
|
140 |
+
nbest_hyps = [
|
141 |
+
h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
|
142 |
+
]
|
143 |
+
new_js[name] = add_results_to_json(
|
144 |
+
js[name], nbest_hyps, train_args.char_list
|
145 |
+
)
|
146 |
+
|
147 |
+
with open(args.result_label, "wb") as f:
|
148 |
+
f.write(
|
149 |
+
json.dumps(
|
150 |
+
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
|
151 |
+
).encode("utf_8")
|
152 |
+
)
|
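The weights dict built in recog_v2 above follows a standard log-linear combination of scorer outputs. The following standalone sketch shows that general idea with dummy scores; it is an illustration of the weighting scheme only, not the actual espnet BeamSearch internals.

# Illustrative sketch of a log-linear score combination, in the spirit of the
# weights dict built in recog_v2 (decoder/ctc/lm/length_bonus). This is NOT
# the BeamSearch implementation, just the underlying idea.
import torch

vocab = 50
weights = {"decoder": 0.7, "ctc": 0.3, "lm": 0.5, "length_bonus": 0.1}

# Dummy per-token log-probabilities from each scorer for one hypothesis.
scores = {
    "decoder": torch.log_softmax(torch.randn(vocab), dim=-1),
    "ctc": torch.log_softmax(torch.randn(vocab), dim=-1),
    "lm": torch.log_softmax(torch.randn(vocab), dim=-1),
    "length_bonus": torch.ones(vocab),  # constant bonus per emitted token
}

combined = sum(weights[k] * scores[k] for k in weights)
best_token = int(combined.argmax())
print("best next token id:", best_token)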
espnet/bin/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/bin/asr_align.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2020 Johns Hopkins University (Xuankai Chang)
|
5 |
+
# 2020, Technische Universität München; Dominik Winkelbauer, Ludwig Kürzinger
|
6 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
7 |
+
|
8 |
+
"""
|
9 |
+
This program performs CTC segmentation to align utterances within audio files.
|
10 |
+
|
11 |
+
Inputs:
|
12 |
+
`--data-json`:
|
13 |
+
A json containing list of utterances and audio files
|
14 |
+
`--model`:
|
15 |
+
An already trained ASR model
|
16 |
+
|
17 |
+
Output:
|
18 |
+
`--output`:
|
19 |
+
A plain `segments` file with utterance positions in the audio files.
|
20 |
+
|
21 |
+
Selected parameters:
|
22 |
+
`--min-window-size`:
|
23 |
+
Minimum window size considered for a single utterance. The current default value
|
24 |
+
should be OK in most cases. Larger values might give better results; too large
|
25 |
+
values cause IndexErrors.
|
26 |
+
`--subsampling-factor`:
|
27 |
+
If the encoder sub-samples its input, the number of frames at the CTC layer is
|
28 |
+
reduced by this factor.
|
29 |
+
`--frame-duration`:
|
30 |
+
This is the non-overlapping duration of a single frame in milliseconds (the
|
31 |
+
inverse of frames per millisecond).
|
32 |
+
`--set-blank`:
|
33 |
+
In the rare case that the blank token does not have index 0 in the character
|
34 |
+
dictionary, this parameter sets the index of the blank token.
|
35 |
+
`--gratis-blank`:
|
36 |
+
Sets the transition cost for blank tokens to zero. Useful if there are long
|
37 |
+
unrelated audio sections between the utterances to align.
|
38 |
+
`--replace-spaces-with-blanks`:
|
39 |
+
Spaces are replaced with blanks. Helps to model pauses between words. May
|
40 |
+
increase length of ground truth. May lead to misaligned segments when combined
|
41 |
+
with the option `--gratis-blank`.
|
42 |
+
"""
|
43 |
+
|
44 |
+
import configargparse
|
45 |
+
import logging
|
46 |
+
import os
|
47 |
+
import sys
|
48 |
+
|
49 |
+
# imports for inference
|
50 |
+
from espnet.asr.pytorch_backend.asr_init import load_trained_model
|
51 |
+
from espnet.nets.asr_interface import ASRInterface
|
52 |
+
from espnet.utils.io_utils import LoadInputsAndTargets
|
53 |
+
import json
|
54 |
+
import torch
|
55 |
+
|
56 |
+
# imports for CTC segmentation
|
57 |
+
from ctc_segmentation import ctc_segmentation
|
58 |
+
from ctc_segmentation import CtcSegmentationParameters
|
59 |
+
from ctc_segmentation import determine_utterance_segments
|
60 |
+
from ctc_segmentation import prepare_text
|
61 |
+
|
62 |
+
|
63 |
+
# NOTE: you need this func to generate our sphinx doc
|
64 |
+
def get_parser():
|
65 |
+
"""Get default arguments."""
|
66 |
+
parser = configargparse.ArgumentParser(
|
67 |
+
description="Align text to audio using CTC segmentation."
|
68 |
+
"using a pre-trained speech recognition model.",
|
69 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
70 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
71 |
+
)
|
72 |
+
# general configuration
|
73 |
+
parser.add("--config", is_config_file=True, help="Decoding config file path.")
|
74 |
+
parser.add_argument(
|
75 |
+
"--ngpu", type=int, default=0, help="Number of GPUs (max. 1 is supported)"
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--dtype",
|
79 |
+
choices=("float16", "float32", "float64"),
|
80 |
+
default="float32",
|
81 |
+
help="Float precision (only available in --api v2)",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--backend",
|
85 |
+
type=str,
|
86 |
+
default="pytorch",
|
87 |
+
choices=["pytorch"],
|
88 |
+
help="Backend library",
|
89 |
+
)
|
90 |
+
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
|
91 |
+
parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
|
92 |
+
parser.add_argument(
|
93 |
+
"--preprocess-conf",
|
94 |
+
type=str,
|
95 |
+
default=None,
|
96 |
+
help="The configuration file for the pre-processing",
|
97 |
+
)
|
98 |
+
# task related
|
99 |
+
parser.add_argument(
|
100 |
+
"--data-json", type=str, help="Json of recognition data for audio and text"
|
101 |
+
)
|
102 |
+
parser.add_argument("--utt-text", type=str, help="Text separated into utterances")
|
103 |
+
# model (parameter) related
|
104 |
+
parser.add_argument(
|
105 |
+
"--model", type=str, required=True, help="Model file parameters to read"
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--model-conf", type=str, default=None, help="Model config file"
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--num-encs", default=1, type=int, help="Number of encoders in the model."
|
112 |
+
)
|
113 |
+
# ctc-segmentation related
|
114 |
+
parser.add_argument(
|
115 |
+
"--subsampling-factor",
|
116 |
+
type=int,
|
117 |
+
default=None,
|
118 |
+
help="Subsampling factor."
|
119 |
+
" If the encoder sub-samples its input, the number of frames at the CTC layer"
|
120 |
+
" is reduced by this factor. For example, a BLSTMP with subsampling 1_2_2_1_1"
|
121 |
+
" has a subsampling factor of 4.",
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--frame-duration",
|
125 |
+
type=int,
|
126 |
+
default=None,
|
127 |
+
help="Non-overlapping duration of a single frame in milliseconds.",
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--min-window-size",
|
131 |
+
type=int,
|
132 |
+
default=None,
|
133 |
+
help="Minimum window size considered for utterance.",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--max-window-size",
|
137 |
+
type=int,
|
138 |
+
default=None,
|
139 |
+
help="Maximum window size considered for utterance.",
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--use-dict-blank",
|
143 |
+
type=int,
|
144 |
+
default=None,
|
145 |
+
help="DEPRECATED.",
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
"--set-blank",
|
149 |
+
type=int,
|
150 |
+
default=None,
|
151 |
+
help="Index of model dictionary for blank token (default: 0).",
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
"--gratis-blank",
|
155 |
+
type=int,
|
156 |
+
default=None,
|
157 |
+
help="Set the transition cost of the blank token to zero. Audio sections"
|
158 |
+
" labeled with blank tokens can then be skipped without penalty. Useful"
|
159 |
+
" if there are unrelated audio segments between utterances.",
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--replace-spaces-with-blanks",
|
163 |
+
type=int,
|
164 |
+
default=None,
|
165 |
+
help="Fill blanks in between words to better model pauses between words."
|
166 |
+
" Segments can be misaligned if this option is combined with --gratis-blank."
|
167 |
+
" May increase length of ground truth.",
|
168 |
+
)
|
169 |
+
parser.add_argument(
|
170 |
+
"--scoring-length",
|
171 |
+
type=int,
|
172 |
+
default=None,
|
173 |
+
help="Changes partitioning length L for calculation of the confidence score.",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--output",
|
177 |
+
type=configargparse.FileType("w"),
|
178 |
+
required=True,
|
179 |
+
help="Output segments file",
|
180 |
+
)
|
181 |
+
return parser
|
182 |
+
|
183 |
+
|
184 |
+
def main(args):
|
185 |
+
"""Run the main decoding function."""
|
186 |
+
parser = get_parser()
|
187 |
+
args, extra = parser.parse_known_args(args)
|
188 |
+
# logging info
|
189 |
+
if args.verbose == 1:
|
190 |
+
logging.basicConfig(
|
191 |
+
level=logging.INFO,
|
192 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
193 |
+
)
|
194 |
+
elif args.verbose == 2:
|
195 |
+
logging.basicConfig(
|
196 |
+
level=logging.DEBUG,
|
197 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
198 |
+
)
|
199 |
+
else:
|
200 |
+
logging.basicConfig(
|
201 |
+
level=logging.WARN,
|
202 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
203 |
+
)
|
204 |
+
logging.warning("Skip DEBUG/INFO messages")
|
205 |
+
if args.ngpu == 0 and args.dtype == "float16":
|
206 |
+
raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
|
207 |
+
# check CUDA_VISIBLE_DEVICES
|
208 |
+
device = "cpu"
|
209 |
+
if args.ngpu == 1:
|
210 |
+
device = "cuda"
|
211 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
212 |
+
if cvd is None:
|
213 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
214 |
+
elif args.ngpu > 1:
|
215 |
+
logging.error("Decoding only supports ngpu=1.")
|
216 |
+
sys.exit(1)
|
217 |
+
# display PYTHONPATH
|
218 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
219 |
+
# recog
|
220 |
+
logging.info("backend = " + args.backend)
|
221 |
+
if args.backend == "pytorch":
|
222 |
+
ctc_align(args, device)
|
223 |
+
else:
|
224 |
+
raise ValueError("Only pytorch is supported.")
|
225 |
+
sys.exit(0)
|
226 |
+
|
227 |
+
|
228 |
+
def ctc_align(args, device):
|
229 |
+
"""ESPnet-specific interface for CTC segmentation.
|
230 |
+
|
231 |
+
Parses configuration, infers the CTC posterior probabilities,
|
232 |
+
and then aligns start and end of utterances using CTC segmentation.
|
233 |
+
Results are written to the output file given in the args.
|
234 |
+
|
235 |
+
:param args: given configuration
|
236 |
+
:param device: for inference; one of ['cuda', 'cpu']
|
237 |
+
:return: 0 on success
|
238 |
+
"""
|
239 |
+
model, train_args = load_trained_model(args.model)
|
240 |
+
assert isinstance(model, ASRInterface)
|
241 |
+
load_inputs_and_targets = LoadInputsAndTargets(
|
242 |
+
mode="asr",
|
243 |
+
load_output=True,
|
244 |
+
sort_in_input_length=False,
|
245 |
+
preprocess_conf=train_args.preprocess_conf
|
246 |
+
if args.preprocess_conf is None
|
247 |
+
else args.preprocess_conf,
|
248 |
+
preprocess_args={"train": False},
|
249 |
+
)
|
250 |
+
logging.info(f"Decoding device={device}")
|
251 |
+
# Warn for nets with high memory consumption on long audio files
|
252 |
+
if hasattr(model, "enc"):
|
253 |
+
encoder_module = model.enc.__class__.__module__
|
254 |
+
elif hasattr(model, "encoder"):
|
255 |
+
encoder_module = model.encoder.__class__.__module__
|
256 |
+
else:
|
257 |
+
encoder_module = "Unknown"
|
258 |
+
logging.info(f"Encoder module: {encoder_module}")
|
259 |
+
logging.info(f"CTC module: {model.ctc.__class__.__module__}")
|
260 |
+
if "rnn" not in encoder_module:
|
261 |
+
logging.warning("No BLSTM model detected; memory consumption may be high.")
|
262 |
+
model.to(device=device).eval()
|
263 |
+
# read audio and text json data
|
264 |
+
with open(args.data_json, "rb") as f:
|
265 |
+
js = json.load(f)["utts"]
|
266 |
+
with open(args.utt_text, "r", encoding="utf-8") as f:
|
267 |
+
lines = f.readlines()
|
268 |
+
i = 0
|
269 |
+
text = {}
|
270 |
+
segment_names = {}
|
271 |
+
for name in js.keys():
|
272 |
+
text_per_audio = []
|
273 |
+
segment_names_per_audio = []
|
274 |
+
while i < len(lines) and lines[i].startswith(name):
|
275 |
+
text_per_audio.append(lines[i][lines[i].find(" ") + 1 :])
|
276 |
+
segment_names_per_audio.append(lines[i][: lines[i].find(" ")])
|
277 |
+
i += 1
|
278 |
+
text[name] = text_per_audio
|
279 |
+
segment_names[name] = segment_names_per_audio
|
280 |
+
# apply configuration
|
281 |
+
config = CtcSegmentationParameters()
|
282 |
+
if args.subsampling_factor is not None:
|
283 |
+
config.subsampling_factor = args.subsampling_factor
|
284 |
+
if args.frame_duration is not None:
|
285 |
+
config.frame_duration_ms = args.frame_duration
|
286 |
+
if args.min_window_size is not None:
|
287 |
+
config.min_window_size = args.min_window_size
|
288 |
+
if args.max_window_size is not None:
|
289 |
+
config.max_window_size = args.max_window_size
|
290 |
+
config.char_list = train_args.char_list
|
291 |
+
if args.use_dict_blank is not None:
|
292 |
+
logging.warning(
|
293 |
+
"The option --use-dict-blank is deprecated. If needed,"
|
294 |
+
" use --set-blank instead."
|
295 |
+
)
|
296 |
+
if args.set_blank is not None:
|
297 |
+
config.blank = args.set_blank
|
298 |
+
if args.replace_spaces_with_blanks is not None:
|
299 |
+
if args.replace_spaces_with_blanks:
|
300 |
+
config.replace_spaces_with_blanks = True
|
301 |
+
else:
|
302 |
+
config.replace_spaces_with_blanks = False
|
303 |
+
if args.gratis_blank:
|
304 |
+
config.blank_transition_cost_zero = True
|
305 |
+
if config.blank_transition_cost_zero and args.replace_spaces_with_blanks:
|
306 |
+
logging.error(
|
307 |
+
"Blanks are inserted between words, and also the transition cost of blank"
|
308 |
+
" is zero. This configuration may lead to misalignments!"
|
309 |
+
)
|
310 |
+
if args.scoring_length is not None:
|
311 |
+
config.score_min_mean_over_L = args.scoring_length
|
312 |
+
logging.info(
|
313 |
+
f"Frame timings: {config.frame_duration_ms}ms * {config.subsampling_factor}"
|
314 |
+
)
|
315 |
+
# Iterate over audio files to decode and align
|
316 |
+
for idx, name in enumerate(js.keys(), 1):
|
317 |
+
logging.info("(%d/%d) Aligning " + name, idx, len(js.keys()))
|
318 |
+
batch = [(name, js[name])]
|
319 |
+
feat, label = load_inputs_and_targets(batch)
|
320 |
+
feat = feat[0]
|
321 |
+
with torch.no_grad():
|
322 |
+
# Encode input frames
|
323 |
+
enc_output = model.encode(torch.as_tensor(feat).to(device)).unsqueeze(0)
|
324 |
+
# Apply ctc layer to obtain log character probabilities
|
325 |
+
lpz = model.ctc.log_softmax(enc_output)[0].cpu().numpy()
|
326 |
+
# Prepare the text for aligning
|
327 |
+
ground_truth_mat, utt_begin_indices = prepare_text(config, text[name])
|
328 |
+
# Align using CTC segmentation
|
329 |
+
timings, char_probs, state_list = ctc_segmentation(
|
330 |
+
config, lpz, ground_truth_mat
|
331 |
+
)
|
332 |
+
logging.debug(f"state_list = {state_list}")
|
333 |
+
# Obtain list of utterances with time intervals and confidence score
|
334 |
+
segments = determine_utterance_segments(
|
335 |
+
config, utt_begin_indices, char_probs, timings, text[name]
|
336 |
+
)
|
337 |
+
# Write to "segments" file
|
338 |
+
for i, boundary in enumerate(segments):
|
339 |
+
utt_segment = (
|
340 |
+
f"{segment_names[name][i]} {name} {boundary[0]:.2f}"
|
341 |
+
f" {boundary[1]:.2f} {boundary[2]:.9f}\n"
|
342 |
+
)
|
343 |
+
args.output.write(utt_segment)
|
344 |
+
return 0
|
345 |
+
|
346 |
+
|
347 |
+
if __name__ == "__main__":
|
348 |
+
main(sys.argv[1:])
|
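The asr_align.py options above tie CTC output indices to wall-clock time through --frame-duration and --subsampling-factor. The minimal function below is an illustrative sketch of that mapping (the real conversion happens inside the ctc_segmentation package); the default values are assumptions for the example, not the package defaults.

# Illustrative only: how a CTC output frame index maps to seconds, given the
# --frame-duration (ms per encoder input frame) and --subsampling-factor
# options described above. ctc_segmentation performs this internally.
def ctc_index_to_seconds(index: int, frame_duration_ms: int = 10,
                         subsampling_factor: int = 4) -> float:
    return index * frame_duration_ms * subsampling_factor / 1000.0


# Example: CTC frame 250 with 10 ms frames and 4x subsampling -> 10 seconds.
print(ctc_index_to_seconds(250))  # 10.0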
espnet/bin/asr_enhance.py
ADDED
@@ -0,0 +1,191 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import configargparse
|
3 |
+
from distutils.util import strtobool
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import sys
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from espnet.asr.pytorch_backend.asr import enhance
|
12 |
+
|
13 |
+
|
14 |
+
# NOTE: you need this func to generate our sphinx doc
|
15 |
+
def get_parser():
|
16 |
+
parser = configargparse.ArgumentParser(
|
17 |
+
description="Enhance noisy speech for speech recognition",
|
18 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
19 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
20 |
+
)
|
21 |
+
# general configuration
|
22 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
23 |
+
parser.add(
|
24 |
+
"--config2",
|
25 |
+
is_config_file=True,
|
26 |
+
help="second config file path that overwrites the settings in `--config`.",
|
27 |
+
)
|
28 |
+
parser.add(
|
29 |
+
"--config3",
|
30 |
+
is_config_file=True,
|
31 |
+
help="third config file path that overwrites the settings "
|
32 |
+
"in `--config` and `--config2`.",
|
33 |
+
)
|
34 |
+
|
35 |
+
parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
|
36 |
+
parser.add_argument(
|
37 |
+
"--backend",
|
38 |
+
default="chainer",
|
39 |
+
type=str,
|
40 |
+
choices=["chainer", "pytorch"],
|
41 |
+
help="Backend library",
|
42 |
+
)
|
43 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
44 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
45 |
+
parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
|
46 |
+
parser.add_argument(
|
47 |
+
"--batchsize",
|
48 |
+
default=1,
|
49 |
+
type=int,
|
50 |
+
help="Batch size for beam search (0: means no batch processing)",
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
"--preprocess-conf",
|
54 |
+
type=str,
|
55 |
+
default=None,
|
56 |
+
help="The configuration file for the pre-processing",
|
57 |
+
)
|
58 |
+
# task related
|
59 |
+
parser.add_argument(
|
60 |
+
"--recog-json", type=str, help="Filename of recognition data (json)"
|
61 |
+
)
|
62 |
+
# model (parameter) related
|
63 |
+
parser.add_argument(
|
64 |
+
"--model", type=str, required=True, help="Model file parameters to read"
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--model-conf", type=str, default=None, help="Model config file"
|
68 |
+
)
|
69 |
+
|
70 |
+
# Outputs configuration
|
71 |
+
parser.add_argument(
|
72 |
+
"--enh-wspecifier",
|
73 |
+
type=str,
|
74 |
+
default=None,
|
75 |
+
help="Specify the output way for enhanced speech."
|
76 |
+
"e.g. ark,scp:outdir,wav.scp",
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--enh-filetype",
|
80 |
+
type=str,
|
81 |
+
default="sound",
|
82 |
+
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
83 |
+
help="Specify the file format for enhanced speech. "
|
84 |
+
'"mat" is the matrix format in kaldi',
|
85 |
+
)
|
86 |
+
parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
|
87 |
+
parser.add_argument(
|
88 |
+
"--keep-length",
|
89 |
+
type=strtobool,
|
90 |
+
default=True,
|
91 |
+
help="Adjust the output length to match " "with the input for enhanced speech",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--image-dir", type=str, default=None, help="The directory saving the images."
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--num-images",
|
98 |
+
type=int,
|
99 |
+
default=20,
|
100 |
+
help="The number of images files to be saved. "
|
101 |
+
"If negative, all samples are to be saved.",
|
102 |
+
)
|
103 |
+
|
104 |
+
# IStft
|
105 |
+
parser.add_argument(
|
106 |
+
"--apply-istft",
|
107 |
+
type=strtobool,
|
108 |
+
default=True,
|
109 |
+
help="Apply istft to the output from the network",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--istft-win-length",
|
113 |
+
type=int,
|
114 |
+
default=512,
|
115 |
+
help="The window length for istft. "
|
116 |
+
"This option is ignored "
|
117 |
+
"if stft is found in the preprocess-conf",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--istft-n-shift",
|
121 |
+
type=str,
|
122 |
+
default=256,
|
123 |
+
help="The window type for istft. "
|
124 |
+
"This option is ignored "
|
125 |
+
"if stft is found in the preprocess-conf",
|
126 |
+
)
|
127 |
+
parser.add_argument(
|
128 |
+
"--istft-window",
|
129 |
+
type=str,
|
130 |
+
default="hann",
|
131 |
+
help="The window type for istft. "
|
132 |
+
"This option is ignored "
|
133 |
+
"if stft is found in the preprocess-conf",
|
134 |
+
)
|
135 |
+
return parser
|
136 |
+
|
137 |
+
|
138 |
+
def main(args):
|
139 |
+
parser = get_parser()
|
140 |
+
args = parser.parse_args(args)
|
141 |
+
|
142 |
+
# logging info
|
143 |
+
if args.verbose == 1:
|
144 |
+
logging.basicConfig(
|
145 |
+
level=logging.INFO,
|
146 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
147 |
+
)
|
148 |
+
elif args.verbose == 2:
|
149 |
+
logging.basicConfig(
|
150 |
+
level=logging.DEBUG,
|
151 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
152 |
+
)
|
153 |
+
else:
|
154 |
+
logging.basicConfig(
|
155 |
+
level=logging.WARN,
|
156 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
157 |
+
)
|
158 |
+
logging.warning("Skip DEBUG/INFO messages")
|
159 |
+
|
160 |
+
# check CUDA_VISIBLE_DEVICES
|
161 |
+
if args.ngpu > 0:
|
162 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
163 |
+
if cvd is None:
|
164 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
165 |
+
elif args.ngpu != len(cvd.split(",")):
|
166 |
+
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
|
167 |
+
sys.exit(1)
|
168 |
+
|
169 |
+
# TODO(kamo): support of multiple GPUs
|
170 |
+
if args.ngpu > 1:
|
171 |
+
logging.error("The program only supports ngpu=1.")
|
172 |
+
sys.exit(1)
|
173 |
+
|
174 |
+
# display PYTHONPATH
|
175 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
176 |
+
|
177 |
+
# seed setting
|
178 |
+
random.seed(args.seed)
|
179 |
+
np.random.seed(args.seed)
|
180 |
+
logging.info("set random seed = %d" % args.seed)
|
181 |
+
|
182 |
+
# recog
|
183 |
+
logging.info("backend = " + args.backend)
|
184 |
+
if args.backend == "pytorch":
|
185 |
+
enhance(args)
|
186 |
+
else:
|
187 |
+
raise ValueError("Only pytorch is supported.")
|
188 |
+
|
189 |
+
|
190 |
+
if __name__ == "__main__":
|
191 |
+
main(sys.argv[1:])
|
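For orientation, the --istft-win-length, --istft-n-shift and --istft-window options of asr_enhance.py correspond to the usual inverse-STFT parameters. The sketch below shows that correspondence using torch.stft/torch.istft; this is an assumed illustration of what the parameters control, not the transform path the enhance backend actually uses.

# Illustrative mapping of the --istft-* options onto torch.istft arguments.
# The enhance backend applies its own iSTFT transform; this sketch only shows
# what win-length / n-shift / window control.
import torch

win_length = 512   # --istft-win-length
n_shift = 256      # --istft-n-shift (hop size)
window = torch.hann_window(win_length)  # --istft-window "hann"

wav = torch.randn(16000)  # 1 second of fake 16 kHz audio
spec = torch.stft(wav, n_fft=win_length, hop_length=n_shift,
                  win_length=win_length, window=window, return_complex=True)
recon = torch.istft(spec, n_fft=win_length, hop_length=n_shift,
                    win_length=win_length, window=window, length=wav.numel())
print(recon.shape)  # torch.Size([16000])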
espnet/bin/asr_recog.py
ADDED
@@ -0,0 +1,363 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""End-to-end speech recognition model decoding script."""
|
8 |
+
|
9 |
+
import configargparse
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import random
|
13 |
+
import sys
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
from espnet.utils.cli_utils import strtobool
|
18 |
+
|
19 |
+
# NOTE: you need this func to generate our sphinx doc
|
20 |
+
|
21 |
+
|
22 |
+
def get_parser():
|
23 |
+
"""Get default arguments."""
|
24 |
+
parser = configargparse.ArgumentParser(
|
25 |
+
description="Transcribe text from speech using "
|
26 |
+
"a speech recognition model on one CPU or GPU",
|
27 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
28 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
29 |
+
)
|
30 |
+
# general configuration
|
31 |
+
parser.add("--config", is_config_file=True, help="Config file path")
|
32 |
+
parser.add(
|
33 |
+
"--config2",
|
34 |
+
is_config_file=True,
|
35 |
+
help="Second config file path that overwrites the settings in `--config`",
|
36 |
+
)
|
37 |
+
parser.add(
|
38 |
+
"--config3",
|
39 |
+
is_config_file=True,
|
40 |
+
help="Third config file path that overwrites the settings "
|
41 |
+
"in `--config` and `--config2`",
|
42 |
+
)
|
43 |
+
|
44 |
+
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
|
45 |
+
parser.add_argument(
|
46 |
+
"--dtype",
|
47 |
+
choices=("float16", "float32", "float64"),
|
48 |
+
default="float32",
|
49 |
+
help="Float precision (only available in --api v2)",
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--backend",
|
53 |
+
type=str,
|
54 |
+
default="chainer",
|
55 |
+
choices=["chainer", "pytorch"],
|
56 |
+
help="Backend library",
|
57 |
+
)
|
58 |
+
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
|
59 |
+
parser.add_argument("--seed", type=int, default=1, help="Random seed")
|
60 |
+
parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
|
61 |
+
parser.add_argument(
|
62 |
+
"--batchsize",
|
63 |
+
type=int,
|
64 |
+
default=1,
|
65 |
+
help="Batch size for beam search (0: means no batch processing)",
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--preprocess-conf",
|
69 |
+
type=str,
|
70 |
+
default=None,
|
71 |
+
help="The configuration file for the pre-processing",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--api",
|
75 |
+
default="v1",
|
76 |
+
choices=["v1", "v2"],
|
77 |
+
help="Beam search APIs "
|
78 |
+
"v1: Default API. It only supports the ASRInterface.recognize method "
|
79 |
+
"and DefaultRNNLM. "
|
80 |
+
"v2: Experimental API. It supports any models that implements ScorerInterface.",
|
81 |
+
)
|
82 |
+
# task related
|
83 |
+
parser.add_argument(
|
84 |
+
"--recog-json", type=str, help="Filename of recognition data (json)"
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--result-label",
|
88 |
+
type=str,
|
89 |
+
required=True,
|
90 |
+
help="Filename of result label data (json)",
|
91 |
+
)
|
92 |
+
# model (parameter) related
|
93 |
+
parser.add_argument(
|
94 |
+
"--model", type=str, required=True, help="Model file parameters to read"
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--model-conf", type=str, default=None, help="Model config file"
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--num-spkrs",
|
101 |
+
type=int,
|
102 |
+
default=1,
|
103 |
+
choices=[1, 2],
|
104 |
+
help="Number of speakers in the speech",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--num-encs", default=1, type=int, help="Number of encoders in the model."
|
108 |
+
)
|
109 |
+
# search related
|
110 |
+
parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
111 |
+
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
|
112 |
+
parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty")
|
113 |
+
parser.add_argument(
|
114 |
+
"--maxlenratio",
|
115 |
+
type=float,
|
116 |
+
default=0.0,
|
117 |
+
help="""Input length ratio to obtain max output length.
|
118 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
119 |
+
to automatically find maximum hypothesis lengths""",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--minlenratio",
|
123 |
+
type=float,
|
124 |
+
default=0.0,
|
125 |
+
help="Input length ratio to obtain min output length",
|
126 |
+
)
|
127 |
+
parser.add_argument(
|
128 |
+
"--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--weights-ctc-dec",
|
132 |
+
type=float,
|
133 |
+
action="append",
|
134 |
+
help="ctc weight assigned to each encoder during decoding."
|
135 |
+
"[in multi-encoder mode only]",
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--ctc-window-margin",
|
139 |
+
type=int,
|
140 |
+
default=0,
|
141 |
+
help="""Use CTC window with margin parameter to accelerate
|
142 |
+
CTC/attention decoding, especially on GPU. Smaller margin
|
143 |
+
makes decoding faster, but may increase search errors.
|
144 |
+
If margin=0 (default), this function is disabled""",
|
145 |
+
)
|
146 |
+
# transducer related
|
147 |
+
parser.add_argument(
|
148 |
+
"--search-type",
|
149 |
+
type=str,
|
150 |
+
default="default",
|
151 |
+
choices=["default", "nsc", "tsd", "alsd"],
|
152 |
+
help="""Type of beam search implementation to use during inference.
|
153 |
+
Can be either: default beam search, n-step constrained beam search ("nsc"),
|
154 |
+
time-synchronous decoding ("tsd") or alignment-length synchronous decoding
|
155 |
+
("alsd").
|
156 |
+
Additional associated parameters: "nstep" + "prefix-alpha" (for nsc),
|
157 |
+
"max-sym-exp" (for tsd) and "u-max" (for alsd)""",
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--nstep",
|
161 |
+
type=int,
|
162 |
+
default=1,
|
163 |
+
help="Number of expansion steps allowed in NSC beam search.",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--prefix-alpha",
|
167 |
+
type=int,
|
168 |
+
default=2,
|
169 |
+
help="Length prefix difference allowed in NSC beam search.",
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--max-sym-exp",
|
173 |
+
type=int,
|
174 |
+
default=2,
|
175 |
+
help="Number of symbol expansions allowed in TSD decoding.",
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--u-max",
|
179 |
+
type=int,
|
180 |
+
default=400,
|
181 |
+
help="Length prefix difference allowed in ALSD beam search.",
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--score-norm",
|
185 |
+
type=strtobool,
|
186 |
+
nargs="?",
|
187 |
+
default=True,
|
188 |
+
help="Normalize transducer scores by length",
|
189 |
+
)
|
190 |
+
# rnnlm related
|
191 |
+
parser.add_argument(
|
192 |
+
"--rnnlm", type=str, default=None, help="RNNLM model file to read"
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
"--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
|
199 |
+
)
|
200 |
+
parser.add_argument(
|
201 |
+
"--word-rnnlm-conf",
|
202 |
+
type=str,
|
203 |
+
default=None,
|
204 |
+
help="Word RNNLM model config file to read",
|
205 |
+
)
|
206 |
+
parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
|
207 |
+
parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
|
208 |
+
# ngram related
|
209 |
+
parser.add_argument(
|
210 |
+
"--ngram-model", type=str, default=None, help="ngram model file to read"
|
211 |
+
)
|
212 |
+
parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
|
213 |
+
parser.add_argument(
|
214 |
+
"--ngram-scorer",
|
215 |
+
type=str,
|
216 |
+
default="part",
|
217 |
+
choices=("full", "part"),
|
218 |
+
help="""if the ngram is set as a part scorer, similar with CTC scorer,
|
219 |
+
the ngram scorer only scores the top-K hypotheses.
|
220 |
+
If the ngram is set as a full scorer, the ngram scorer scores all hypotheses;
|
221 |
+
the decoding speed of the part scorer is much faster than the full one.""",
|
222 |
+
)
|
223 |
+
# streaming related
|
224 |
+
parser.add_argument(
|
225 |
+
"--streaming-mode",
|
226 |
+
type=str,
|
227 |
+
default=None,
|
228 |
+
choices=["window", "segment"],
|
229 |
+
help="""Use streaming recognizer for inference.
|
230 |
+
`--batchsize` must be set to 0 to enable this mode""",
|
231 |
+
)
|
232 |
+
parser.add_argument("--streaming-window", type=int, default=10, help="Window size")
|
233 |
+
parser.add_argument(
|
234 |
+
"--streaming-min-blank-dur",
|
235 |
+
type=int,
|
236 |
+
default=10,
|
237 |
+
help="Minimum blank duration threshold",
|
238 |
+
)
|
239 |
+
parser.add_argument(
|
240 |
+
"--streaming-onset-margin", type=int, default=1, help="Onset margin"
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--streaming-offset-margin", type=int, default=1, help="Offset margin"
|
244 |
+
)
|
245 |
+
# non-autoregressive related
|
246 |
+
# Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
|
247 |
+
parser.add_argument(
|
248 |
+
"--maskctc-n-iterations",
|
249 |
+
type=int,
|
250 |
+
default=10,
|
251 |
+
help="Number of decoding iterations."
|
252 |
+
"For Mask CTC, set 0 to predict 1 mask/iter.",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--maskctc-probability-threshold",
|
256 |
+
type=float,
|
257 |
+
default=0.999,
|
258 |
+
help="Threshold probability for CTC output",
|
259 |
+
)
|
260 |
+
|
261 |
+
return parser
|
262 |
+
|
263 |
+
|
264 |
+
def main(args):
|
265 |
+
"""Run the main decoding function."""
|
266 |
+
parser = get_parser()
|
267 |
+
args = parser.parse_args(args)
|
268 |
+
|
269 |
+
if args.ngpu == 0 and args.dtype == "float16":
|
270 |
+
raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
|
271 |
+
|
272 |
+
# logging info
|
273 |
+
if args.verbose == 1:
|
274 |
+
logging.basicConfig(
|
275 |
+
level=logging.INFO,
|
276 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
277 |
+
)
|
278 |
+
elif args.verbose == 2:
|
279 |
+
logging.basicConfig(
|
280 |
+
level=logging.DEBUG,
|
281 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
logging.basicConfig(
|
285 |
+
level=logging.WARN,
|
286 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
287 |
+
)
|
288 |
+
logging.warning("Skip DEBUG/INFO messages")
|
289 |
+
|
290 |
+
# check CUDA_VISIBLE_DEVICES
|
291 |
+
if args.ngpu > 0:
|
292 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
293 |
+
if cvd is None:
|
294 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
295 |
+
elif args.ngpu != len(cvd.split(",")):
|
296 |
+
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
|
297 |
+
sys.exit(1)
|
298 |
+
|
299 |
+
# TODO(mn5k): support of multiple GPUs
|
300 |
+
if args.ngpu > 1:
|
301 |
+
logging.error("The program only supports ngpu=1.")
|
302 |
+
sys.exit(1)
|
303 |
+
|
304 |
+
# display PYTHONPATH
|
305 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
306 |
+
|
307 |
+
# seed setting
|
308 |
+
random.seed(args.seed)
|
309 |
+
np.random.seed(args.seed)
|
310 |
+
logging.info("set random seed = %d" % args.seed)
|
311 |
+
|
312 |
+
# validate rnn options
|
313 |
+
if args.rnnlm is not None and args.word_rnnlm is not None:
|
314 |
+
logging.error(
|
315 |
+
"It seems that both --rnnlm and --word-rnnlm are specified. "
|
316 |
+
"Please use either option."
|
317 |
+
)
|
318 |
+
sys.exit(1)
|
319 |
+
|
320 |
+
# recog
|
321 |
+
logging.info("backend = " + args.backend)
|
322 |
+
if args.num_spkrs == 1:
|
323 |
+
if args.backend == "chainer":
|
324 |
+
from espnet.asr.chainer_backend.asr import recog
|
325 |
+
|
326 |
+
recog(args)
|
327 |
+
elif args.backend == "pytorch":
|
328 |
+
if args.num_encs == 1:
|
329 |
+
# Experimental API that supports custom LMs
|
330 |
+
if args.api == "v2":
|
331 |
+
from espnet.asr.pytorch_backend.recog import recog_v2
|
332 |
+
|
333 |
+
recog_v2(args)
|
334 |
+
else:
|
335 |
+
from espnet.asr.pytorch_backend.asr import recog
|
336 |
+
|
337 |
+
if args.dtype != "float32":
|
338 |
+
raise NotImplementedError(
|
339 |
+
f"`--dtype {args.dtype}` is only available with `--api v2`"
|
340 |
+
)
|
341 |
+
recog(args)
|
342 |
+
else:
|
343 |
+
if args.api == "v2":
|
344 |
+
raise NotImplementedError(
|
345 |
+
f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
from espnet.asr.pytorch_backend.asr import recog
|
349 |
+
|
350 |
+
recog(args)
|
351 |
+
else:
|
352 |
+
raise ValueError("Only chainer and pytorch are supported.")
|
353 |
+
elif args.num_spkrs == 2:
|
354 |
+
if args.backend == "pytorch":
|
355 |
+
from espnet.asr.pytorch_backend.asr_mix import recog
|
356 |
+
|
357 |
+
recog(args)
|
358 |
+
else:
|
359 |
+
raise ValueError("Only pytorch is supported.")
|
360 |
+
|
361 |
+
|
362 |
+
if __name__ == "__main__":
|
363 |
+
main(sys.argv[1:])
|
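As a quick sanity check of the argument surface defined above, the parser can be built and exercised directly from Python. This is only a minimal sketch: it assumes the espnet package from this repo is importable, and the file paths passed to the required options are made-up placeholders, not files shipped with the repo.

# Hypothetical usage sketch for espnet/bin/asr_recog.py (all paths are placeholders).
from espnet.bin.asr_recog import get_parser

parser = get_parser()
args = parser.parse_args(
    [
        "--recog-json", "dump/test/data.json",          # assumed path
        "--result-label", "exp/decode/result.json",     # assumed path
        "--model", "exp/train/results/model.acc.best",  # assumed path
        "--backend", "pytorch",
        "--beam-size", "20",
        "--ctc-weight", "0.3",
    ]
)
print(args.beam_size, args.api)  # -> 20 v1 (v1 is the default beam search API)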
espnet/bin/asr_train.py
ADDED
@@ -0,0 +1,644 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2017 Tomoki Hayashi (Nagoya University)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Automatic speech recognition model training script."""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
import subprocess
|
13 |
+
import sys
|
14 |
+
|
15 |
+
from distutils.version import LooseVersion
|
16 |
+
|
17 |
+
import configargparse
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from espnet import __version__
|
22 |
+
from espnet.utils.cli_utils import strtobool
|
23 |
+
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
|
24 |
+
|
25 |
+
is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2")
|
26 |
+
|
27 |
+
|
28 |
+
# NOTE: you need this func to generate our sphinx doc
|
29 |
+
def get_parser(parser=None, required=True):
|
30 |
+
"""Get default arguments."""
|
31 |
+
if parser is None:
|
32 |
+
parser = configargparse.ArgumentParser(
|
33 |
+
description="Train an automatic speech recognition (ASR) model on one CPU, "
|
34 |
+
"one or multiple GPUs",
|
35 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
36 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
37 |
+
)
|
38 |
+
# general configuration
|
39 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
40 |
+
parser.add(
|
41 |
+
"--config2",
|
42 |
+
is_config_file=True,
|
43 |
+
help="second config file path that overwrites the settings in `--config`.",
|
44 |
+
)
|
45 |
+
parser.add(
|
46 |
+
"--config3",
|
47 |
+
is_config_file=True,
|
48 |
+
help="third config file path that overwrites the settings in "
|
49 |
+
"`--config` and `--config2`.",
|
50 |
+
)
|
51 |
+
|
52 |
+
parser.add_argument(
|
53 |
+
"--ngpu",
|
54 |
+
default=None,
|
55 |
+
type=int,
|
56 |
+
help="Number of GPUs. If not given, use all visible devices",
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--train-dtype",
|
60 |
+
default="float32",
|
61 |
+
choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
|
62 |
+
help="Data type for training (only pytorch backend). "
|
63 |
+
"O0,O1,.. flags require apex. "
|
64 |
+
"See https://nvidia.github.io/apex/amp.html#opt-levels",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--backend",
|
68 |
+
default="chainer",
|
69 |
+
type=str,
|
70 |
+
choices=["chainer", "pytorch"],
|
71 |
+
help="Backend library",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--outdir", type=str, required=required, help="Output directory"
|
75 |
+
)
|
76 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
77 |
+
parser.add_argument("--dict", required=required, help="Dictionary")
|
78 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
79 |
+
parser.add_argument("--debugdir", type=str, help="Output directory for debugging")
|
80 |
+
parser.add_argument(
|
81 |
+
"--resume",
|
82 |
+
"-r",
|
83 |
+
default="",
|
84 |
+
nargs="?",
|
85 |
+
help="Resume the training from snapshot",
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--minibatches",
|
89 |
+
"-N",
|
90 |
+
type=int,
|
91 |
+
default="-1",
|
92 |
+
help="Process only N minibatches (for debug)",
|
93 |
+
)
|
94 |
+
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
95 |
+
parser.add_argument(
|
96 |
+
"--tensorboard-dir",
|
97 |
+
default=None,
|
98 |
+
type=str,
|
99 |
+
nargs="?",
|
100 |
+
help="Tensorboard log dir path",
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--report-interval-iters",
|
104 |
+
default=100,
|
105 |
+
type=int,
|
106 |
+
help="Report interval iterations",
|
107 |
+
)
|
108 |
+
parser.add_argument(
|
109 |
+
"--save-interval-iters",
|
110 |
+
default=0,
|
111 |
+
type=int,
|
112 |
+
help="Save snapshot interval iterations",
|
113 |
+
)
|
114 |
+
# task related
|
115 |
+
parser.add_argument(
|
116 |
+
"--train-json",
|
117 |
+
type=str,
|
118 |
+
default=None,
|
119 |
+
help="Filename of train label data (json)",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--valid-json",
|
123 |
+
type=str,
|
124 |
+
default=None,
|
125 |
+
help="Filename of validation label data (json)",
|
126 |
+
)
|
127 |
+
# network architecture
|
128 |
+
parser.add_argument(
|
129 |
+
"--model-module",
|
130 |
+
type=str,
|
131 |
+
default=None,
|
132 |
+
help="model defined module (default: espnet.nets.xxx_backend.e2e_asr:E2E)",
|
133 |
+
)
|
134 |
+
# encoder
|
135 |
+
parser.add_argument(
|
136 |
+
"--num-encs", default=1, type=int, help="Number of encoders in the model."
|
137 |
+
)
|
138 |
+
# loss related
|
139 |
+
parser.add_argument(
|
140 |
+
"--ctc_type",
|
141 |
+
default="warpctc",
|
142 |
+
type=str,
|
143 |
+
choices=["builtin", "warpctc", "gtnctc", "cudnnctc"],
|
144 |
+
help="Type of CTC implementation to calculate loss.",
|
145 |
+
)
|
146 |
+
parser.add_argument(
|
147 |
+
"--mtlalpha",
|
148 |
+
default=0.5,
|
149 |
+
type=float,
|
150 |
+
help="Multitask learning coefficient, "
|
151 |
+
"alpha: alpha*ctc_loss + (1-alpha)*att_loss ",
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
"--lsm-weight", default=0.0, type=float, help="Label smoothing weight"
|
155 |
+
)
|
156 |
+
# recognition options to compute CER/WER
|
157 |
+
parser.add_argument(
|
158 |
+
"--report-cer",
|
159 |
+
default=False,
|
160 |
+
action="store_true",
|
161 |
+
help="Compute CER on development set",
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
"--report-wer",
|
165 |
+
default=False,
|
166 |
+
action="store_true",
|
167 |
+
help="Compute WER on development set",
|
168 |
+
)
|
169 |
+
parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
170 |
+
parser.add_argument("--beam-size", type=int, default=4, help="Beam size")
|
171 |
+
parser.add_argument("--penalty", default=0.0, type=float, help="Incertion penalty")
|
172 |
+
parser.add_argument(
|
173 |
+
"--maxlenratio",
|
174 |
+
default=0.0,
|
175 |
+
type=float,
|
176 |
+
help="""Input length ratio to obtain max output length.
|
177 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
178 |
+
to automatically find maximum hypothesis lengths""",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--minlenratio",
|
182 |
+
default=0.0,
|
183 |
+
type=float,
|
184 |
+
help="Input length ratio to obtain min output length",
|
185 |
+
)
|
186 |
+
parser.add_argument(
|
187 |
+
"--ctc-weight", default=0.3, type=float, help="CTC weight in joint decoding"
|
188 |
+
)
|
189 |
+
parser.add_argument(
|
190 |
+
"--rnnlm", type=str, default=None, help="RNNLM model file to read"
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
|
194 |
+
)
|
195 |
+
parser.add_argument("--lm-weight", default=0.1, type=float, help="RNNLM weight.")
|
196 |
+
parser.add_argument("--sym-space", default="<space>", type=str, help="Space symbol")
|
197 |
+
parser.add_argument("--sym-blank", default="<blank>", type=str, help="Blank symbol")
|
198 |
+
# minibatch related
|
199 |
+
parser.add_argument(
|
200 |
+
"--sortagrad",
|
201 |
+
default=0,
|
202 |
+
type=int,
|
203 |
+
nargs="?",
|
204 |
+
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
|
205 |
+
)
|
206 |
+
parser.add_argument(
|
207 |
+
"--batch-count",
|
208 |
+
default="auto",
|
209 |
+
choices=BATCH_COUNT_CHOICES,
|
210 |
+
help="How to count batch_size. "
|
211 |
+
"The default (auto) will find how to count by args.",
|
212 |
+
)
|
213 |
+
parser.add_argument(
|
214 |
+
"--batch-size",
|
215 |
+
"--batch-seqs",
|
216 |
+
"-b",
|
217 |
+
default=0,
|
218 |
+
type=int,
|
219 |
+
help="Maximum seqs in a minibatch (0 to disable)",
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--batch-bins",
|
223 |
+
default=0,
|
224 |
+
type=int,
|
225 |
+
help="Maximum bins in a minibatch (0 to disable)",
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
"--batch-frames-in",
|
229 |
+
default=0,
|
230 |
+
type=int,
|
231 |
+
help="Maximum input frames in a minibatch (0 to disable)",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--batch-frames-out",
|
235 |
+
default=0,
|
236 |
+
type=int,
|
237 |
+
help="Maximum output frames in a minibatch (0 to disable)",
|
238 |
+
)
|
239 |
+
parser.add_argument(
|
240 |
+
"--batch-frames-inout",
|
241 |
+
default=0,
|
242 |
+
type=int,
|
243 |
+
help="Maximum input+output frames in a minibatch (0 to disable)",
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--maxlen-in",
|
247 |
+
"--batch-seq-maxlen-in",
|
248 |
+
default=800,
|
249 |
+
type=int,
|
250 |
+
metavar="ML",
|
251 |
+
help="When --batch-count=seq, "
|
252 |
+
"batch size is reduced if the input sequence length > ML.",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--maxlen-out",
|
256 |
+
"--batch-seq-maxlen-out",
|
257 |
+
default=150,
|
258 |
+
type=int,
|
259 |
+
metavar="ML",
|
260 |
+
help="When --batch-count=seq, "
|
261 |
+
"batch size is reduced if the output sequence length > ML",
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--n-iter-processes",
|
265 |
+
default=0,
|
266 |
+
type=int,
|
267 |
+
help="Number of processes of iterator",
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--preprocess-conf",
|
271 |
+
type=str,
|
272 |
+
default=None,
|
273 |
+
nargs="?",
|
274 |
+
help="The configuration file for the pre-processing",
|
275 |
+
)
|
276 |
+
# optimization related
|
277 |
+
parser.add_argument(
|
278 |
+
"--opt",
|
279 |
+
default="adadelta",
|
280 |
+
type=str,
|
281 |
+
choices=["adadelta", "adam", "noam"],
|
282 |
+
help="Optimizer",
|
283 |
+
)
|
284 |
+
parser.add_argument(
|
285 |
+
"--accum-grad", default=1, type=int, help="Number of gradient accumuration"
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"--eps", default=1e-8, type=float, help="Epsilon constant for optimizer"
|
289 |
+
)
|
290 |
+
parser.add_argument(
|
291 |
+
"--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon"
|
292 |
+
)
|
293 |
+
parser.add_argument(
|
294 |
+
"--weight-decay", default=0.0, type=float, help="Weight decay ratio"
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--criterion",
|
298 |
+
default="acc",
|
299 |
+
type=str,
|
300 |
+
choices=["loss", "loss_eps_decay_only", "acc"],
|
301 |
+
help="Criterion to perform epsilon decay",
|
302 |
+
)
|
303 |
+
parser.add_argument(
|
304 |
+
"--threshold", default=1e-4, type=float, help="Threshold to stop iteration"
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--epochs", "-e", default=30, type=int, help="Maximum number of epochs"
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--early-stop-criterion",
|
311 |
+
default="validation/main/acc",
|
312 |
+
type=str,
|
313 |
+
nargs="?",
|
314 |
+
help="Value to monitor to trigger an early stopping of the training",
|
315 |
+
)
|
316 |
+
parser.add_argument(
|
317 |
+
"--patience",
|
318 |
+
default=3,
|
319 |
+
type=int,
|
320 |
+
nargs="?",
|
321 |
+
help="Number of epochs to wait without improvement "
|
322 |
+
"before stopping the training",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--grad-clip", default=5, type=float, help="Gradient norm threshold to clip"
|
326 |
+
)
|
327 |
+
parser.add_argument(
|
328 |
+
"--num-save-attention",
|
329 |
+
default=3,
|
330 |
+
type=int,
|
331 |
+
help="Number of samples of attention to be saved",
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
"--num-save-ctc",
|
335 |
+
default=3,
|
336 |
+
type=int,
|
337 |
+
help="Number of samples of CTC probability to be saved",
|
338 |
+
)
|
339 |
+
parser.add_argument(
|
340 |
+
"--grad-noise",
|
341 |
+
type=strtobool,
|
342 |
+
default=False,
|
343 |
+
help="The flag to switch to use noise injection to gradients during training",
|
344 |
+
)
|
345 |
+
# asr_mix related
|
346 |
+
parser.add_argument(
|
347 |
+
"--num-spkrs",
|
348 |
+
default=1,
|
349 |
+
type=int,
|
350 |
+
choices=[1, 2],
|
351 |
+
help="Number of speakers in the speech.",
|
352 |
+
)
|
353 |
+
# decoder related
|
354 |
+
parser.add_argument(
|
355 |
+
"--context-residual",
|
356 |
+
default=False,
|
357 |
+
type=strtobool,
|
358 |
+
nargs="?",
|
359 |
+
help="The flag to switch to use context vector residual in the decoder network",
|
360 |
+
)
|
361 |
+
# finetuning related
|
362 |
+
parser.add_argument(
|
363 |
+
"--enc-init",
|
364 |
+
default=None,
|
365 |
+
type=str,
|
366 |
+
help="Pre-trained ASR model to initialize encoder.",
|
367 |
+
)
|
368 |
+
parser.add_argument(
|
369 |
+
"--enc-init-mods",
|
370 |
+
default="enc.enc.",
|
371 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
372 |
+
help="List of encoder modules to initialize, separated by a comma.",
|
373 |
+
)
|
374 |
+
parser.add_argument(
|
375 |
+
"--dec-init",
|
376 |
+
default=None,
|
377 |
+
type=str,
|
378 |
+
help="Pre-trained ASR, MT or LM model to initialize decoder.",
|
379 |
+
)
|
380 |
+
parser.add_argument(
|
381 |
+
"--dec-init-mods",
|
382 |
+
default="att.,dec.",
|
383 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
384 |
+
help="List of decoder modules to initialize, separated by a comma.",
|
385 |
+
)
|
386 |
+
parser.add_argument(
|
387 |
+
"--freeze-mods",
|
388 |
+
default=None,
|
389 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
390 |
+
help="List of modules to freeze, separated by a comma.",
|
391 |
+
)
|
392 |
+
# front end related
|
393 |
+
parser.add_argument(
|
394 |
+
"--use-frontend",
|
395 |
+
type=strtobool,
|
396 |
+
default=False,
|
397 |
+
help="The flag to switch to use frontend system.",
|
398 |
+
)
|
399 |
+
|
400 |
+
# WPE related
|
401 |
+
parser.add_argument(
|
402 |
+
"--use-wpe",
|
403 |
+
type=strtobool,
|
404 |
+
default=False,
|
405 |
+
help="Apply Weighted Prediction Error",
|
406 |
+
)
|
407 |
+
parser.add_argument(
|
408 |
+
"--wtype",
|
409 |
+
default="blstmp",
|
410 |
+
type=str,
|
411 |
+
choices=[
|
412 |
+
"lstm",
|
413 |
+
"blstm",
|
414 |
+
"lstmp",
|
415 |
+
"blstmp",
|
416 |
+
"vgglstmp",
|
417 |
+
"vggblstmp",
|
418 |
+
"vgglstm",
|
419 |
+
"vggblstm",
|
420 |
+
"gru",
|
421 |
+
"bgru",
|
422 |
+
"grup",
|
423 |
+
"bgrup",
|
424 |
+
"vgggrup",
|
425 |
+
"vggbgrup",
|
426 |
+
"vgggru",
|
427 |
+
"vggbgru",
|
428 |
+
],
|
429 |
+
help="Type of encoder network architecture "
|
430 |
+
"of the mask estimator for WPE. "
|
431 |
+
"",
|
432 |
+
)
|
433 |
+
parser.add_argument("--wlayers", type=int, default=2, help="")
|
434 |
+
parser.add_argument("--wunits", type=int, default=300, help="")
|
435 |
+
parser.add_argument("--wprojs", type=int, default=300, help="")
|
436 |
+
parser.add_argument("--wdropout-rate", type=float, default=0.0, help="")
|
437 |
+
parser.add_argument("--wpe-taps", type=int, default=5, help="")
|
438 |
+
parser.add_argument("--wpe-delay", type=int, default=3, help="")
|
439 |
+
parser.add_argument(
|
440 |
+
"--use-dnn-mask-for-wpe",
|
441 |
+
type=strtobool,
|
442 |
+
default=False,
|
443 |
+
help="Use DNN to estimate the power spectrogram. "
|
444 |
+
"This option is experimental.",
|
445 |
+
)
|
446 |
+
# Beamformer related
|
447 |
+
parser.add_argument("--use-beamformer", type=strtobool, default=True, help="")
|
448 |
+
parser.add_argument(
|
449 |
+
"--btype",
|
450 |
+
default="blstmp",
|
451 |
+
type=str,
|
452 |
+
choices=[
|
453 |
+
"lstm",
|
454 |
+
"blstm",
|
455 |
+
"lstmp",
|
456 |
+
"blstmp",
|
457 |
+
"vgglstmp",
|
458 |
+
"vggblstmp",
|
459 |
+
"vgglstm",
|
460 |
+
"vggblstm",
|
461 |
+
"gru",
|
462 |
+
"bgru",
|
463 |
+
"grup",
|
464 |
+
"bgrup",
|
465 |
+
"vgggrup",
|
466 |
+
"vggbgrup",
|
467 |
+
"vgggru",
|
468 |
+
"vggbgru",
|
469 |
+
],
|
470 |
+
help="Type of encoder network architecture "
|
471 |
+
"of the mask estimator for Beamformer.",
|
472 |
+
)
|
473 |
+
parser.add_argument("--blayers", type=int, default=2, help="")
|
474 |
+
parser.add_argument("--bunits", type=int, default=300, help="")
|
475 |
+
parser.add_argument("--bprojs", type=int, default=300, help="")
|
476 |
+
parser.add_argument("--badim", type=int, default=320, help="")
|
477 |
+
parser.add_argument(
|
478 |
+
"--bnmask",
|
479 |
+
type=int,
|
480 |
+
default=2,
|
481 |
+
help="Number of beamforming masks, " "default is 2 for [speech, noise].",
|
482 |
+
)
|
483 |
+
parser.add_argument(
|
484 |
+
"--ref-channel",
|
485 |
+
type=int,
|
486 |
+
default=-1,
|
487 |
+
help="The reference channel used for beamformer. "
|
488 |
+
"By default, the channel is estimated by DNN.",
|
489 |
+
)
|
490 |
+
parser.add_argument("--bdropout-rate", type=float, default=0.0, help="")
|
491 |
+
# Feature transform: Normalization
|
492 |
+
parser.add_argument(
|
493 |
+
"--stats-file",
|
494 |
+
type=str,
|
495 |
+
default=None,
|
496 |
+
help="The stats file for the feature normalization",
|
497 |
+
)
|
498 |
+
parser.add_argument(
|
499 |
+
"--apply-uttmvn",
|
500 |
+
type=strtobool,
|
501 |
+
default=True,
|
502 |
+
help="Apply utterance level mean " "variance normalization.",
|
503 |
+
)
|
504 |
+
parser.add_argument("--uttmvn-norm-means", type=strtobool, default=True, help="")
|
505 |
+
parser.add_argument("--uttmvn-norm-vars", type=strtobool, default=False, help="")
|
506 |
+
# Feature transform: Fbank
|
507 |
+
parser.add_argument(
|
508 |
+
"--fbank-fs",
|
509 |
+
type=int,
|
510 |
+
default=16000,
|
511 |
+
help="The sample frequency used for " "the mel-fbank creation.",
|
512 |
+
)
|
513 |
+
parser.add_argument(
|
514 |
+
"--n-mels", type=int, default=80, help="The number of mel-frequency bins."
|
515 |
+
)
|
516 |
+
parser.add_argument("--fbank-fmin", type=float, default=0.0, help="")
|
517 |
+
parser.add_argument("--fbank-fmax", type=float, default=None, help="")
|
518 |
+
return parser
|
519 |
+
|
520 |
+
|
521 |
+
def main(cmd_args):
|
522 |
+
"""Run the main training function."""
|
523 |
+
parser = get_parser()
|
524 |
+
args, _ = parser.parse_known_args(cmd_args)
|
525 |
+
if args.backend == "chainer" and args.train_dtype != "float32":
|
526 |
+
raise NotImplementedError(
|
527 |
+
f"chainer backend does not support --train-dtype {args.train_dtype}."
|
528 |
+
"Use --dtype float32."
|
529 |
+
)
|
530 |
+
if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
|
531 |
+
raise ValueError(
|
532 |
+
f"--train-dtype {args.train_dtype} does not support the CPU backend."
|
533 |
+
)
|
534 |
+
|
535 |
+
from espnet.utils.dynamic_import import dynamic_import
|
536 |
+
|
537 |
+
if args.model_module is None:
|
538 |
+
if args.num_spkrs == 1:
|
539 |
+
model_module = "espnet.nets." + args.backend + "_backend.e2e_asr:E2E"
|
540 |
+
else:
|
541 |
+
model_module = "espnet.nets." + args.backend + "_backend.e2e_asr_mix:E2E"
|
542 |
+
else:
|
543 |
+
model_module = args.model_module
|
544 |
+
model_class = dynamic_import(model_module)
|
545 |
+
model_class.add_arguments(parser)
|
546 |
+
|
547 |
+
args = parser.parse_args(cmd_args)
|
548 |
+
args.model_module = model_module
|
549 |
+
if "chainer_backend" in args.model_module:
|
550 |
+
args.backend = "chainer"
|
551 |
+
if "pytorch_backend" in args.model_module:
|
552 |
+
args.backend = "pytorch"
|
553 |
+
|
554 |
+
# add version info in args
|
555 |
+
args.version = __version__
|
556 |
+
|
557 |
+
# logging info
|
558 |
+
if args.verbose > 0:
|
559 |
+
logging.basicConfig(
|
560 |
+
level=logging.INFO,
|
561 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
562 |
+
)
|
563 |
+
else:
|
564 |
+
logging.basicConfig(
|
565 |
+
level=logging.WARN,
|
566 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
567 |
+
)
|
568 |
+
logging.warning("Skip DEBUG/INFO messages")
|
569 |
+
|
570 |
+
# If --ngpu is not given,
|
571 |
+
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
|
572 |
+
# 2. if nvidia-smi exists, use all devices
|
573 |
+
# 3. else ngpu=0
|
574 |
+
if args.ngpu is None:
|
575 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
576 |
+
if cvd is not None:
|
577 |
+
ngpu = len(cvd.split(","))
|
578 |
+
else:
|
579 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
580 |
+
try:
|
581 |
+
p = subprocess.run(
|
582 |
+
["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
583 |
+
)
|
584 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
585 |
+
ngpu = 0
|
586 |
+
else:
|
587 |
+
ngpu = len(p.stderr.decode().split("\n")) - 1
|
588 |
+
else:
|
589 |
+
if is_torch_1_2_plus and args.ngpu != 1:
|
590 |
+
logging.debug(
|
591 |
+
"There are some bugs with multi-GPU processing in PyTorch 1.2+"
|
592 |
+
+ " (see https://github.com/pytorch/pytorch/issues/21108)"
|
593 |
+
)
|
594 |
+
ngpu = args.ngpu
|
595 |
+
logging.info(f"ngpu: {ngpu}")
|
596 |
+
|
597 |
+
# display PYTHONPATH
|
598 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
599 |
+
|
600 |
+
# set random seed
|
601 |
+
logging.info("random seed = %d" % args.seed)
|
602 |
+
random.seed(args.seed)
|
603 |
+
np.random.seed(args.seed)
|
604 |
+
|
605 |
+
# load dictionary for debug log
|
606 |
+
if args.dict is not None:
|
607 |
+
with open(args.dict, "rb") as f:
|
608 |
+
dictionary = f.readlines()
|
609 |
+
char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
|
610 |
+
char_list.insert(0, "<blank>")
|
611 |
+
char_list.append("<eos>")
|
612 |
+
# for non-autoregressive maskctc model
|
613 |
+
if "maskctc" in args.model_module:
|
614 |
+
char_list.append("<mask>")
|
615 |
+
args.char_list = char_list
|
616 |
+
else:
|
617 |
+
args.char_list = None
|
618 |
+
|
619 |
+
# train
|
620 |
+
logging.info("backend = " + args.backend)
|
621 |
+
|
622 |
+
if args.num_spkrs == 1:
|
623 |
+
if args.backend == "chainer":
|
624 |
+
from espnet.asr.chainer_backend.asr import train
|
625 |
+
|
626 |
+
train(args)
|
627 |
+
elif args.backend == "pytorch":
|
628 |
+
from espnet.asr.pytorch_backend.asr import train
|
629 |
+
|
630 |
+
train(args)
|
631 |
+
else:
|
632 |
+
raise ValueError("Only chainer and pytorch are supported.")
|
633 |
+
else:
|
634 |
+
# FIXME(kamo): Support --model-module
|
635 |
+
if args.backend == "pytorch":
|
636 |
+
from espnet.asr.pytorch_backend.asr_mix import train
|
637 |
+
|
638 |
+
train(args)
|
639 |
+
else:
|
640 |
+
raise ValueError("Only pytorch is supported.")
|
641 |
+
|
642 |
+
|
643 |
+
if __name__ == "__main__":
|
644 |
+
main(sys.argv[1:])
|
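The --ngpu handling near the end of main() above falls back from CUDA_VISIBLE_DEVICES to nvidia-smi before defaulting to CPU. The following standalone sketch mirrors that decision order; it is not part of the repo, and unlike the script above (which counts lines from the captured stderr) it counts the "GPU n: ..." lines that nvidia-smi -L prints on stdout.

# Standalone sketch of the GPU auto-detection fallback (illustration only).
import os
import subprocess


def guess_ngpu():
    """Return a GPU count: CUDA_VISIBLE_DEVICES first, then nvidia-smi, else 0."""
    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cvd is not None:
        return len(cvd.split(","))
    try:
        p = subprocess.run(
            ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return 0
    # nvidia-smi -L prints one "GPU n: <name> (UUID: ...)" line per device
    return len([ln for ln in p.stdout.decode().splitlines() if ln.startswith("GPU")])


print(guess_ngpu())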
espnet/bin/lm_train.py
ADDED
@@ -0,0 +1,288 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
# This code is ported from the following implementation written in Torch.
|
7 |
+
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
|
8 |
+
|
9 |
+
"""Language model training script."""
|
10 |
+
|
11 |
+
import logging
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import subprocess
|
15 |
+
import sys
|
16 |
+
|
17 |
+
import configargparse
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from espnet import __version__
|
21 |
+
from espnet.nets.lm_interface import dynamic_import_lm
|
22 |
+
from espnet.optimizer.factory import dynamic_import_optimizer
|
23 |
+
from espnet.scheduler.scheduler import dynamic_import_scheduler
|
24 |
+
|
25 |
+
|
26 |
+
# NOTE: you need this func to generate our sphinx doc
|
27 |
+
def get_parser(parser=None, required=True):
|
28 |
+
"""Get parser."""
|
29 |
+
if parser is None:
|
30 |
+
parser = configargparse.ArgumentParser(
|
31 |
+
description="Train a new language model on one CPU or one GPU",
|
32 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
33 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
34 |
+
)
|
35 |
+
# general configuration
|
36 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
37 |
+
parser.add(
|
38 |
+
"--config2",
|
39 |
+
is_config_file=True,
|
40 |
+
help="second config file path that overwrites the settings in `--config`.",
|
41 |
+
)
|
42 |
+
parser.add(
|
43 |
+
"--config3",
|
44 |
+
is_config_file=True,
|
45 |
+
help="third config file path that overwrites the settings "
|
46 |
+
"in `--config` and `--config2`.",
|
47 |
+
)
|
48 |
+
|
49 |
+
parser.add_argument(
|
50 |
+
"--ngpu",
|
51 |
+
default=None,
|
52 |
+
type=int,
|
53 |
+
help="Number of GPUs. If not given, use all visible devices",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--train-dtype",
|
57 |
+
default="float32",
|
58 |
+
choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
|
59 |
+
help="Data type for training (only pytorch backend). "
|
60 |
+
"O0,O1,.. flags require apex. "
|
61 |
+
"See https://nvidia.github.io/apex/amp.html#opt-levels",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--backend",
|
65 |
+
default="chainer",
|
66 |
+
type=str,
|
67 |
+
choices=["chainer", "pytorch"],
|
68 |
+
help="Backend library",
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--outdir", type=str, required=required, help="Output directory"
|
72 |
+
)
|
73 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
74 |
+
parser.add_argument("--dict", type=str, required=required, help="Dictionary")
|
75 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
76 |
+
parser.add_argument(
|
77 |
+
"--resume",
|
78 |
+
"-r",
|
79 |
+
default="",
|
80 |
+
nargs="?",
|
81 |
+
help="Resume the training from snapshot",
|
82 |
+
)
|
83 |
+
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
84 |
+
parser.add_argument(
|
85 |
+
"--tensorboard-dir",
|
86 |
+
default=None,
|
87 |
+
type=str,
|
88 |
+
nargs="?",
|
89 |
+
help="Tensorboard log dir path",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--report-interval-iters",
|
93 |
+
default=100,
|
94 |
+
type=int,
|
95 |
+
help="Report interval iterations",
|
96 |
+
)
|
97 |
+
# task related
|
98 |
+
parser.add_argument(
|
99 |
+
"--train-label",
|
100 |
+
type=str,
|
101 |
+
required=required,
|
102 |
+
help="Filename of train label data",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--valid-label",
|
106 |
+
type=str,
|
107 |
+
required=required,
|
108 |
+
help="Filename of validation label data",
|
109 |
+
)
|
110 |
+
parser.add_argument("--test-label", type=str, help="Filename of test label data")
|
111 |
+
parser.add_argument(
|
112 |
+
"--dump-hdf5-path",
|
113 |
+
type=str,
|
114 |
+
default=None,
|
115 |
+
help="Path to dump a preprocessed dataset as hdf5",
|
116 |
+
)
|
117 |
+
# training configuration
|
118 |
+
parser.add_argument("--opt", default="sgd", type=str, help="Optimizer")
|
119 |
+
parser.add_argument(
|
120 |
+
"--sortagrad",
|
121 |
+
default=0,
|
122 |
+
type=int,
|
123 |
+
nargs="?",
|
124 |
+
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--batchsize",
|
128 |
+
"-b",
|
129 |
+
type=int,
|
130 |
+
default=300,
|
131 |
+
help="Number of examples in each mini-batch",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--accum-grad", type=int, default=1, help="Number of gradient accumueration"
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--epoch",
|
138 |
+
"-e",
|
139 |
+
type=int,
|
140 |
+
default=20,
|
141 |
+
help="Number of sweeps over the dataset to train",
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--early-stop-criterion",
|
145 |
+
default="validation/main/loss",
|
146 |
+
type=str,
|
147 |
+
nargs="?",
|
148 |
+
help="Value to monitor to trigger an early stopping of the training",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--patience",
|
152 |
+
default=3,
|
153 |
+
type=int,
|
154 |
+
nargs="?",
|
155 |
+
help="Number of epochs "
|
156 |
+
"to wait without improvement before stopping the training",
|
157 |
+
)
|
158 |
+
parser.add_argument(
|
159 |
+
"--schedulers",
|
160 |
+
default=None,
|
161 |
+
action="append",
|
162 |
+
type=lambda kv: kv.split("="),
|
163 |
+
help="optimizer schedulers, you can configure params like:"
|
164 |
+
" <optimizer-param>-<scheduler-name>-<schduler-param>"
|
165 |
+
' e.g., "--schedulers lr=noam --lr-noam-warmup 1000".',
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--gradclip",
|
169 |
+
"-c",
|
170 |
+
type=float,
|
171 |
+
default=5,
|
172 |
+
help="Gradient norm threshold to clip",
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--maxlen",
|
176 |
+
type=int,
|
177 |
+
default=40,
|
178 |
+
help="Batch size is reduced if the input sequence > ML",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--model-module",
|
182 |
+
type=str,
|
183 |
+
default="default",
|
184 |
+
help="model defined module "
|
185 |
+
"(default: espnet.nets.xxx_backend.lm.default:DefaultRNNLM)",
|
186 |
+
)
|
187 |
+
return parser
|
188 |
+
|
189 |
+
|
190 |
+
def main(cmd_args):
|
191 |
+
"""Train LM."""
|
192 |
+
parser = get_parser()
|
193 |
+
args, _ = parser.parse_known_args(cmd_args)
|
194 |
+
if args.backend == "chainer" and args.train_dtype != "float32":
|
195 |
+
raise NotImplementedError(
|
196 |
+
f"chainer backend does not support --train-dtype {args.train_dtype}."
|
197 |
+
"Use --dtype float32."
|
198 |
+
)
|
199 |
+
if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
|
200 |
+
raise ValueError(
|
201 |
+
f"--train-dtype {args.train_dtype} does not support the CPU backend."
|
202 |
+
)
|
203 |
+
|
204 |
+
# parse arguments dynamically
|
205 |
+
model_class = dynamic_import_lm(args.model_module, args.backend)
|
206 |
+
model_class.add_arguments(parser)
|
207 |
+
if args.schedulers is not None:
|
208 |
+
for k, v in args.schedulers:
|
209 |
+
scheduler_class = dynamic_import_scheduler(v)
|
210 |
+
scheduler_class.add_arguments(k, parser)
|
211 |
+
|
212 |
+
opt_class = dynamic_import_optimizer(args.opt, args.backend)
|
213 |
+
opt_class.add_arguments(parser)
|
214 |
+
|
215 |
+
args = parser.parse_args(cmd_args)
|
216 |
+
|
217 |
+
# add version info in args
|
218 |
+
args.version = __version__
|
219 |
+
|
220 |
+
# logging info
|
221 |
+
if args.verbose > 0:
|
222 |
+
logging.basicConfig(
|
223 |
+
level=logging.INFO,
|
224 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
logging.basicConfig(
|
228 |
+
level=logging.WARN,
|
229 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
230 |
+
)
|
231 |
+
logging.warning("Skip DEBUG/INFO messages")
|
232 |
+
|
233 |
+
# If --ngpu is not given,
|
234 |
+
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
|
235 |
+
# 2. if nvidia-smi exists, use all devices
|
236 |
+
# 3. else ngpu=0
|
237 |
+
if args.ngpu is None:
|
238 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
239 |
+
if cvd is not None:
|
240 |
+
ngpu = len(cvd.split(","))
|
241 |
+
else:
|
242 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
243 |
+
try:
|
244 |
+
p = subprocess.run(
|
245 |
+
["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
246 |
+
)
|
247 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
248 |
+
ngpu = 0
|
249 |
+
else:
|
250 |
+
ngpu = len(p.stderr.decode().split("\n")) - 1
|
251 |
+
args.ngpu = ngpu
|
252 |
+
else:
|
253 |
+
ngpu = args.ngpu
|
254 |
+
logging.info(f"ngpu: {ngpu}")
|
255 |
+
|
256 |
+
# display PYTHONPATH
|
257 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
258 |
+
|
259 |
+
# seed setting
|
260 |
+
nseed = args.seed
|
261 |
+
random.seed(nseed)
|
262 |
+
np.random.seed(nseed)
|
263 |
+
|
264 |
+
# load dictionary
|
265 |
+
with open(args.dict, "rb") as f:
|
266 |
+
dictionary = f.readlines()
|
267 |
+
char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
|
268 |
+
char_list.insert(0, "<blank>")
|
269 |
+
char_list.append("<eos>")
|
270 |
+
args.char_list_dict = {x: i for i, x in enumerate(char_list)}
|
271 |
+
args.n_vocab = len(char_list)
|
272 |
+
|
273 |
+
# train
|
274 |
+
logging.info("backend = " + args.backend)
|
275 |
+
if args.backend == "chainer":
|
276 |
+
from espnet.lm.chainer_backend.lm import train
|
277 |
+
|
278 |
+
train(args)
|
279 |
+
elif args.backend == "pytorch":
|
280 |
+
from espnet.lm.pytorch_backend.lm import train
|
281 |
+
|
282 |
+
train(args)
|
283 |
+
else:
|
284 |
+
raise ValueError("Only chainer and pytorch are supported.")
|
285 |
+
|
286 |
+
|
287 |
+
if __name__ == "__main__":
|
288 |
+
main(sys.argv[1:])
|
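The dictionary handling in main() above (char_list, char_list_dict, n_vocab) is easy to reproduce in isolation. The snippet below is a minimal sketch with made-up tokens fed in as byte strings, standing in for the lines read from the --dict file; it is not part of the repo.

# Minimal sketch of the --dict post-processing above (tokens are made up).
dictionary = [b"a 1\n", b"b 2\n", b"<space> 3\n"]  # stand-in for open(args.dict, "rb").readlines()
char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
char_list.insert(0, "<blank>")
char_list.append("<eos>")
char_list_dict = {x: i for i, x in enumerate(char_list)}
n_vocab = len(char_list)
print(char_list_dict)  # {'<blank>': 0, 'a': 1, 'b': 2, '<space>': 3, '<eos>': 4}
print(n_vocab)         # 5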
espnet/bin/mt_train.py
ADDED
@@ -0,0 +1,480 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Neural machine translation model training script."""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
import subprocess
|
13 |
+
import sys
|
14 |
+
|
15 |
+
from distutils.version import LooseVersion
|
16 |
+
|
17 |
+
import configargparse
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from espnet import __version__
|
22 |
+
from espnet.utils.cli_utils import strtobool
|
23 |
+
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
|
24 |
+
|
25 |
+
is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2")
|
26 |
+
|
27 |
+
|
28 |
+
# NOTE: you need this func to generate our sphinx doc
|
29 |
+
def get_parser(parser=None, required=True):
|
30 |
+
"""Get default arguments."""
|
31 |
+
if parser is None:
|
32 |
+
parser = configargparse.ArgumentParser(
|
33 |
+
description="Train a neural machine translation (NMT) model on one CPU, "
|
34 |
+
"one or multiple GPUs",
|
35 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
36 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
37 |
+
)
|
38 |
+
# general configuration
|
39 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
40 |
+
parser.add(
|
41 |
+
"--config2",
|
42 |
+
is_config_file=True,
|
43 |
+
help="second config file path that overwrites the settings in `--config`.",
|
44 |
+
)
|
45 |
+
parser.add(
|
46 |
+
"--config3",
|
47 |
+
is_config_file=True,
|
48 |
+
help="third config file path that overwrites the settings "
|
49 |
+
"in `--config` and `--config2`.",
|
50 |
+
)
|
51 |
+
|
52 |
+
parser.add_argument(
|
53 |
+
"--ngpu",
|
54 |
+
default=None,
|
55 |
+
type=int,
|
56 |
+
help="Number of GPUs. If not given, use all visible devices",
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--train-dtype",
|
60 |
+
default="float32",
|
61 |
+
choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
|
62 |
+
help="Data type for training (only pytorch backend). "
|
63 |
+
"O0,O1,.. flags require apex. "
|
64 |
+
"See https://nvidia.github.io/apex/amp.html#opt-levels",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--backend",
|
68 |
+
default="chainer",
|
69 |
+
type=str,
|
70 |
+
choices=["chainer", "pytorch"],
|
71 |
+
help="Backend library",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--outdir", type=str, required=required, help="Output directory"
|
75 |
+
)
|
76 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
77 |
+
parser.add_argument(
|
78 |
+
"--dict", required=required, help="Dictionary for source/target languages"
|
79 |
+
)
|
80 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
81 |
+
parser.add_argument("--debugdir", type=str, help="Output directory for debugging")
|
82 |
+
parser.add_argument(
|
83 |
+
"--resume",
|
84 |
+
"-r",
|
85 |
+
default="",
|
86 |
+
nargs="?",
|
87 |
+
help="Resume the training from snapshot",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"--minibatches",
|
91 |
+
"-N",
|
92 |
+
type=int,
|
93 |
+
default="-1",
|
94 |
+
help="Process only N minibatches (for debug)",
|
95 |
+
)
|
96 |
+
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
97 |
+
parser.add_argument(
|
98 |
+
"--tensorboard-dir",
|
99 |
+
default=None,
|
100 |
+
type=str,
|
101 |
+
nargs="?",
|
102 |
+
help="Tensorboard log dir path",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--report-interval-iters",
|
106 |
+
default=100,
|
107 |
+
type=int,
|
108 |
+
help="Report interval iterations",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--save-interval-iters",
|
112 |
+
default=0,
|
113 |
+
type=int,
|
114 |
+
help="Save snapshot interval iterations",
|
115 |
+
)
|
116 |
+
# task related
|
117 |
+
parser.add_argument(
|
118 |
+
"--train-json",
|
119 |
+
type=str,
|
120 |
+
default=None,
|
121 |
+
help="Filename of train label data (json)",
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--valid-json",
|
125 |
+
type=str,
|
126 |
+
default=None,
|
127 |
+
help="Filename of validation label data (json)",
|
128 |
+
)
|
129 |
+
# network architecture
|
130 |
+
parser.add_argument(
|
131 |
+
"--model-module",
|
132 |
+
type=str,
|
133 |
+
default=None,
|
134 |
+
help="model defined module (default: espnet.nets.xxx_backend.e2e_mt:E2E)",
|
135 |
+
)
|
136 |
+
# loss related
|
137 |
+
parser.add_argument(
|
138 |
+
"--lsm-weight", default=0.0, type=float, help="Label smoothing weight"
|
139 |
+
)
|
140 |
+
# translations options to compute BLEU
|
141 |
+
parser.add_argument(
|
142 |
+
"--report-bleu",
|
143 |
+
default=True,
|
144 |
+
action="store_true",
|
145 |
+
help="Compute BLEU on development set",
|
146 |
+
)
|
147 |
+
parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
148 |
+
parser.add_argument("--beam-size", type=int, default=4, help="Beam size")
|
149 |
+
parser.add_argument("--penalty", default=0.0, type=float, help="Incertion penalty")
|
150 |
+
parser.add_argument(
|
151 |
+
"--maxlenratio",
|
152 |
+
default=0.0,
|
153 |
+
type=float,
|
154 |
+
help="""Input length ratio to obtain max output length.
|
155 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
156 |
+
to automatically find maximum hypothesis lengths""",
|
157 |
+
)
|
158 |
+
parser.add_argument(
|
159 |
+
"--minlenratio",
|
160 |
+
default=0.0,
|
161 |
+
type=float,
|
162 |
+
help="Input length ratio to obtain min output length",
|
163 |
+
)
|
164 |
+
parser.add_argument(
|
165 |
+
"--rnnlm", type=str, default=None, help="RNNLM model file to read"
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
|
169 |
+
)
|
170 |
+
parser.add_argument("--lm-weight", default=0.0, type=float, help="RNNLM weight.")
|
171 |
+
parser.add_argument("--sym-space", default="<space>", type=str, help="Space symbol")
|
172 |
+
parser.add_argument("--sym-blank", default="<blank>", type=str, help="Blank symbol")
|
173 |
+
# minibatch related
|
174 |
+
parser.add_argument(
|
175 |
+
"--sortagrad",
|
176 |
+
default=0,
|
177 |
+
type=int,
|
178 |
+
nargs="?",
|
179 |
+
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--batch-count",
|
183 |
+
default="auto",
|
184 |
+
choices=BATCH_COUNT_CHOICES,
|
185 |
+
help="How to count batch_size. "
|
186 |
+
"The default (auto) will find how to count by args.",
|
187 |
+
)
|
188 |
+
parser.add_argument(
|
189 |
+
"--batch-size",
|
190 |
+
"--batch-seqs",
|
191 |
+
"-b",
|
192 |
+
default=0,
|
193 |
+
type=int,
|
194 |
+
help="Maximum seqs in a minibatch (0 to disable)",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--batch-bins",
|
198 |
+
default=0,
|
199 |
+
type=int,
|
200 |
+
help="Maximum bins in a minibatch (0 to disable)",
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
"--batch-frames-in",
|
204 |
+
default=0,
|
205 |
+
type=int,
|
206 |
+
help="Maximum input frames in a minibatch (0 to disable)",
|
207 |
+
)
|
208 |
+
parser.add_argument(
|
209 |
+
"--batch-frames-out",
|
210 |
+
default=0,
|
211 |
+
type=int,
|
212 |
+
help="Maximum output frames in a minibatch (0 to disable)",
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"--batch-frames-inout",
|
216 |
+
default=0,
|
217 |
+
type=int,
|
218 |
+
help="Maximum input+output frames in a minibatch (0 to disable)",
|
219 |
+
)
|
220 |
+
parser.add_argument(
|
221 |
+
"--maxlen-in",
|
222 |
+
"--batch-seq-maxlen-in",
|
223 |
+
default=100,
|
224 |
+
type=int,
|
225 |
+
metavar="ML",
|
226 |
+
help="When --batch-count=seq, "
|
227 |
+
"batch size is reduced if the input sequence length > ML.",
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--maxlen-out",
|
231 |
+
"--batch-seq-maxlen-out",
|
232 |
+
default=100,
|
233 |
+
type=int,
|
234 |
+
metavar="ML",
|
235 |
+
help="When --batch-count=seq, "
|
236 |
+
"batch size is reduced if the output sequence length > ML",
|
237 |
+
)
|
238 |
+
parser.add_argument(
|
239 |
+
"--n-iter-processes",
|
240 |
+
default=0,
|
241 |
+
type=int,
|
242 |
+
help="Number of processes of iterator",
|
243 |
+
)
|
244 |
+
# optimization related
|
245 |
+
parser.add_argument(
|
246 |
+
"--opt",
|
247 |
+
default="adadelta",
|
248 |
+
type=str,
|
249 |
+
choices=["adadelta", "adam", "noam"],
|
250 |
+
help="Optimizer",
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--accum-grad", default=1, type=int, help="Number of gradient accumuration"
|
254 |
+
)
|
255 |
+
parser.add_argument(
|
256 |
+
"--eps", default=1e-8, type=float, help="Epsilon constant for optimizer"
|
257 |
+
)
|
258 |
+
parser.add_argument(
|
259 |
+
"--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon"
|
260 |
+
)
|
261 |
+
parser.add_argument(
|
262 |
+
"--lr", default=1e-3, type=float, help="Learning rate for optimizer"
|
263 |
+
)
|
264 |
+
parser.add_argument(
|
265 |
+
"--lr-decay", default=1.0, type=float, help="Decaying ratio of learning rate"
|
266 |
+
)
|
267 |
+
parser.add_argument(
|
268 |
+
"--weight-decay", default=0.0, type=float, help="Weight decay ratio"
|
269 |
+
)
|
270 |
+
parser.add_argument(
|
271 |
+
"--criterion",
|
272 |
+
default="acc",
|
273 |
+
type=str,
|
274 |
+
choices=["loss", "acc"],
|
275 |
+
help="Criterion to perform epsilon decay",
|
276 |
+
)
|
277 |
+
parser.add_argument(
|
278 |
+
"--threshold", default=1e-4, type=float, help="Threshold to stop iteration"
|
279 |
+
)
|
280 |
+
parser.add_argument(
|
281 |
+
"--epochs", "-e", default=30, type=int, help="Maximum number of epochs"
|
282 |
+
)
|
283 |
+
parser.add_argument(
|
284 |
+
"--early-stop-criterion",
|
285 |
+
default="validation/main/acc",
|
286 |
+
type=str,
|
287 |
+
nargs="?",
|
288 |
+
help="Value to monitor to trigger an early stopping of the training",
|
289 |
+
)
|
290 |
+
parser.add_argument(
|
291 |
+
"--patience",
|
292 |
+
default=3,
|
293 |
+
type=int,
|
294 |
+
nargs="?",
|
295 |
+
help="Number of epochs to wait "
|
296 |
+
"without improvement before stopping the training",
|
297 |
+
)
|
298 |
+
parser.add_argument(
|
299 |
+
"--grad-clip", default=5, type=float, help="Gradient norm threshold to clip"
|
300 |
+
)
|
301 |
+
parser.add_argument(
|
302 |
+
"--num-save-attention",
|
303 |
+
default=3,
|
304 |
+
type=int,
|
305 |
+
help="Number of samples of attention to be saved",
|
306 |
+
)
|
307 |
+
# decoder related
|
308 |
+
parser.add_argument(
|
309 |
+
"--context-residual",
|
310 |
+
default=False,
|
311 |
+
type=strtobool,
|
312 |
+
nargs="?",
|
313 |
+
help="The flag to switch to use context vector residual in the decoder network",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--tie-src-tgt-embedding",
|
317 |
+
default=False,
|
318 |
+
type=strtobool,
|
319 |
+
nargs="?",
|
320 |
+
help="Tie parameters of source embedding and target embedding.",
|
321 |
+
)
|
322 |
+
parser.add_argument(
|
323 |
+
"--tie-classifier",
|
324 |
+
default=False,
|
325 |
+
type=strtobool,
|
326 |
+
nargs="?",
|
327 |
+
help="Tie parameters of target embedding and output projection layer.",
|
328 |
+
)
|
329 |
+
# finetuning related
|
330 |
+
parser.add_argument(
|
331 |
+
"--enc-init",
|
332 |
+
default=None,
|
333 |
+
type=str,
|
334 |
+
nargs="?",
|
335 |
+
help="Pre-trained ASR model to initialize encoder.",
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--enc-init-mods",
|
339 |
+
default="enc.enc.",
|
340 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
341 |
+
help="List of encoder modules to initialize, separated by a comma.",
|
342 |
+
)
|
343 |
+
parser.add_argument(
|
344 |
+
"--dec-init",
|
345 |
+
default=None,
|
346 |
+
type=str,
|
347 |
+
nargs="?",
|
348 |
+
help="Pre-trained ASR, MT or LM model to initialize decoder.",
|
349 |
+
)
|
350 |
+
parser.add_argument(
|
351 |
+
"--dec-init-mods",
|
352 |
+
default="att., dec.",
|
353 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
354 |
+
help="List of decoder modules to initialize, separated by a comma.",
|
355 |
+
)
|
356 |
+
# multilingual related
|
357 |
+
parser.add_argument(
|
358 |
+
"--multilingual",
|
359 |
+
default=False,
|
360 |
+
type=strtobool,
|
361 |
+
help="Prepend target language ID to the source sentence. "
|
362 |
+
"Both source/target language IDs must be prepend in the pre-processing stage.",
|
363 |
+
)
|
364 |
+
parser.add_argument(
|
365 |
+
"--replace-sos",
|
366 |
+
default=False,
|
367 |
+
type=strtobool,
|
368 |
+
help="Replace <sos> in the decoder with a target language ID "
|
369 |
+
"(the first token in the target sequence)",
|
370 |
+
)
|
371 |
+
|
372 |
+
return parser
|
373 |
+
|
374 |
+
|
375 |
+
def main(cmd_args):
|
376 |
+
"""Run the main training function."""
|
377 |
+
parser = get_parser()
|
378 |
+
args, _ = parser.parse_known_args(cmd_args)
|
379 |
+
if args.backend == "chainer" and args.train_dtype != "float32":
|
380 |
+
raise NotImplementedError(
|
381 |
+
f"chainer backend does not support --train-dtype {args.train_dtype}."
|
382 |
+
"Use --dtype float32."
|
383 |
+
)
|
384 |
+
if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
|
385 |
+
raise ValueError(
|
386 |
+
f"--train-dtype {args.train_dtype} does not support the CPU backend."
|
387 |
+
)
|
388 |
+
|
389 |
+
from espnet.utils.dynamic_import import dynamic_import
|
390 |
+
|
391 |
+
if args.model_module is None:
|
392 |
+
model_module = "espnet.nets." + args.backend + "_backend.e2e_mt:E2E"
|
393 |
+
else:
|
394 |
+
model_module = args.model_module
|
395 |
+
model_class = dynamic_import(model_module)
|
396 |
+
model_class.add_arguments(parser)
|
397 |
+
|
398 |
+
args = parser.parse_args(cmd_args)
|
399 |
+
args.model_module = model_module
|
400 |
+
if "chainer_backend" in args.model_module:
|
401 |
+
args.backend = "chainer"
|
402 |
+
if "pytorch_backend" in args.model_module:
|
403 |
+
args.backend = "pytorch"
|
404 |
+
|
405 |
+
# add version info in args
|
406 |
+
args.version = __version__
|
407 |
+
|
408 |
+
# logging info
|
409 |
+
if args.verbose > 0:
|
410 |
+
logging.basicConfig(
|
411 |
+
level=logging.INFO,
|
412 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
413 |
+
)
|
414 |
+
else:
|
415 |
+
logging.basicConfig(
|
416 |
+
level=logging.WARN,
|
417 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
418 |
+
)
|
419 |
+
logging.warning("Skip DEBUG/INFO messages")
|
420 |
+
|
421 |
+
# If --ngpu is not given,
|
422 |
+
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
|
423 |
+
# 2. if nvidia-smi exists, use all devices
|
424 |
+
# 3. else ngpu=0
|
425 |
+
if args.ngpu is None:
|
426 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
427 |
+
if cvd is not None:
|
428 |
+
ngpu = len(cvd.split(","))
|
429 |
+
else:
|
430 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
431 |
+
try:
|
432 |
+
p = subprocess.run(
|
433 |
+
["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
434 |
+
)
|
435 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
436 |
+
ngpu = 0
|
437 |
+
else:
|
438 |
+
ngpu = len(p.stderr.decode().split("\n")) - 1
|
439 |
+
args.ngpu = ngpu
|
440 |
+
else:
|
441 |
+
if is_torch_1_2_plus and args.ngpu != 1:
|
442 |
+
logging.debug(
|
443 |
+
"There are some bugs with multi-GPU processing in PyTorch 1.2+"
|
444 |
+
+ " (see https://github.com/pytorch/pytorch/issues/21108)"
|
445 |
+
)
|
446 |
+
ngpu = args.ngpu
|
447 |
+
logging.info(f"ngpu: {ngpu}")
|
448 |
+
|
449 |
+
# display PYTHONPATH
|
450 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
451 |
+
|
452 |
+
# set random seed
|
453 |
+
logging.info("random seed = %d" % args.seed)
|
454 |
+
random.seed(args.seed)
|
455 |
+
np.random.seed(args.seed)
|
456 |
+
|
457 |
+
# load dictionary for debug log
|
458 |
+
if args.dict is not None:
|
459 |
+
with open(args.dict, "rb") as f:
|
460 |
+
dictionary = f.readlines()
|
461 |
+
char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
|
462 |
+
char_list.insert(0, "<blank>")
|
463 |
+
char_list.append("<eos>")
|
464 |
+
args.char_list = char_list
|
465 |
+
else:
|
466 |
+
args.char_list = None
|
467 |
+
|
468 |
+
# train
|
469 |
+
logging.info("backend = " + args.backend)
|
470 |
+
|
471 |
+
if args.backend == "pytorch":
|
472 |
+
from espnet.mt.pytorch_backend.mt import train
|
473 |
+
|
474 |
+
train(args)
|
475 |
+
else:
|
476 |
+
raise ValueError("Only pytorch are supported.")
|
477 |
+
|
478 |
+
|
479 |
+
if __name__ == "__main__":
|
480 |
+
main(sys.argv[1:])
|
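The main() above resolves the model class from --model-module via dynamic_import, re-parses the full argument set, and, when --ngpu is not given, auto-detects the GPU count (CUDA_VISIBLE_DEVICES first, then `nvidia-smi -L`, otherwise CPU only). A minimal standalone sketch of that detection policy, not lifted verbatim from the script:

import os
import subprocess

def detect_ngpu() -> int:
    # Mirror of the auto-detection order used in main() above (sketch only).
    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cvd is not None:
        return len(cvd.split(","))
    try:
        p = subprocess.run(
            ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return 0
    # one "GPU N: ..." line per visible device in the listing
    return len([line for line in p.stdout.decode().split("\n") if line.strip()])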
espnet/bin/mt_trans.py
ADDED
@@ -0,0 +1,186 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Neural machine translation model decoding script."""
|
8 |
+
|
9 |
+
import configargparse
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import random
|
13 |
+
import sys
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
# NOTE: you need this func to generate our sphinx doc
|
19 |
+
def get_parser():
|
20 |
+
"""Get default arguments."""
|
21 |
+
parser = configargparse.ArgumentParser(
|
22 |
+
description="Translate text from speech "
|
23 |
+
"using a speech translation model on one CPU or GPU",
|
24 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
25 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
26 |
+
)
|
27 |
+
# general configuration
|
28 |
+
parser.add("--config", is_config_file=True, help="Config file path")
|
29 |
+
parser.add(
|
30 |
+
"--config2",
|
31 |
+
is_config_file=True,
|
32 |
+
help="Second config file path that overwrites the settings in `--config`",
|
33 |
+
)
|
34 |
+
parser.add(
|
35 |
+
"--config3",
|
36 |
+
is_config_file=True,
|
37 |
+
help="Third config file path "
|
38 |
+
"that overwrites the settings in `--config` and `--config2`",
|
39 |
+
)
|
40 |
+
|
41 |
+
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
|
42 |
+
parser.add_argument(
|
43 |
+
"--dtype",
|
44 |
+
choices=("float16", "float32", "float64"),
|
45 |
+
default="float32",
|
46 |
+
help="Float precision (only available in --api v2)",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--backend",
|
50 |
+
type=str,
|
51 |
+
default="chainer",
|
52 |
+
choices=["chainer", "pytorch"],
|
53 |
+
help="Backend library",
|
54 |
+
)
|
55 |
+
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
|
56 |
+
parser.add_argument("--seed", type=int, default=1, help="Random seed")
|
57 |
+
parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
|
58 |
+
parser.add_argument(
|
59 |
+
"--batchsize",
|
60 |
+
type=int,
|
61 |
+
default=1,
|
62 |
+
help="Batch size for beam search (0: means no batch processing)",
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"--preprocess-conf",
|
66 |
+
type=str,
|
67 |
+
default=None,
|
68 |
+
help="The configuration file for the pre-processing",
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--api",
|
72 |
+
default="v1",
|
73 |
+
choices=["v1", "v2"],
|
74 |
+
help="Beam search APIs "
|
75 |
+
"v1: Default API. It only supports "
|
76 |
+
"the ASRInterface.recognize method and DefaultRNNLM. "
|
77 |
+
"v2: Experimental API. "
|
78 |
+
"It supports any models that implements ScorerInterface.",
|
79 |
+
)
|
80 |
+
# task related
|
81 |
+
parser.add_argument(
|
82 |
+
"--trans-json", type=str, help="Filename of translation data (json)"
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--result-label",
|
86 |
+
type=str,
|
87 |
+
required=True,
|
88 |
+
help="Filename of result label data (json)",
|
89 |
+
)
|
90 |
+
# model (parameter) related
|
91 |
+
parser.add_argument(
|
92 |
+
"--model", type=str, required=True, help="Model file parameters to read"
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"--model-conf", type=str, default=None, help="Model config file"
|
96 |
+
)
|
97 |
+
# search related
|
98 |
+
parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
99 |
+
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
|
100 |
+
parser.add_argument("--penalty", type=float, default=0.1, help="Incertion penalty")
|
101 |
+
parser.add_argument(
|
102 |
+
"--maxlenratio",
|
103 |
+
type=float,
|
104 |
+
default=3.0,
|
105 |
+
help="""Input length ratio to obtain max output length.
|
106 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
107 |
+
to automatically find maximum hypothesis lengths""",
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--minlenratio",
|
111 |
+
type=float,
|
112 |
+
default=0.0,
|
113 |
+
help="Input length ratio to obtain min output length",
|
114 |
+
)
|
115 |
+
# multilingual related
|
116 |
+
parser.add_argument(
|
117 |
+
"--tgt-lang",
|
118 |
+
default=False,
|
119 |
+
type=str,
|
120 |
+
help="target language ID (e.g., <en>, <de>, and <fr> etc.)",
|
121 |
+
)
|
122 |
+
return parser
|
123 |
+
|
124 |
+
|
125 |
+
def main(args):
|
126 |
+
"""Run the main decoding function."""
|
127 |
+
parser = get_parser()
|
128 |
+
args = parser.parse_args(args)
|
129 |
+
|
130 |
+
# logging info
|
131 |
+
if args.verbose == 1:
|
132 |
+
logging.basicConfig(
|
133 |
+
level=logging.INFO,
|
134 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
135 |
+
)
|
136 |
+
elif args.verbose == 2:
|
137 |
+
logging.basicConfig(
|
138 |
+
level=logging.DEBUG,
|
139 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
logging.basicConfig(
|
143 |
+
level=logging.WARN,
|
144 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
145 |
+
)
|
146 |
+
logging.warning("Skip DEBUG/INFO messages")
|
147 |
+
|
148 |
+
# check CUDA_VISIBLE_DEVICES
|
149 |
+
if args.ngpu > 0:
|
150 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
151 |
+
if cvd is None:
|
152 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
153 |
+
elif args.ngpu != len(cvd.split(",")):
|
154 |
+
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
|
155 |
+
sys.exit(1)
|
156 |
+
|
157 |
+
# TODO(mn5k): support of multiple GPUs
|
158 |
+
if args.ngpu > 1:
|
159 |
+
logging.error("The program only supports ngpu=1.")
|
160 |
+
sys.exit(1)
|
161 |
+
|
162 |
+
# display PYTHONPATH
|
163 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
164 |
+
|
165 |
+
# seed setting
|
166 |
+
random.seed(args.seed)
|
167 |
+
np.random.seed(args.seed)
|
168 |
+
logging.info("set random seed = %d" % args.seed)
|
169 |
+
|
170 |
+
# trans
|
171 |
+
logging.info("backend = " + args.backend)
|
172 |
+
if args.backend == "pytorch":
|
173 |
+
# Experimental API that supports custom LMs
|
174 |
+
from espnet.mt.pytorch_backend.mt import trans
|
175 |
+
|
176 |
+
if args.dtype != "float32":
|
177 |
+
raise NotImplementedError(
|
178 |
+
f"`--dtype {args.dtype}` is only available with `--api v2`"
|
179 |
+
)
|
180 |
+
trans(args)
|
181 |
+
else:
|
182 |
+
raise ValueError("Only pytorch are supported.")
|
183 |
+
|
184 |
+
|
185 |
+
if __name__ == "__main__":
|
186 |
+
main(sys.argv[1:])
|
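For orientation, a hypothetical way to drive the MT decoding entry point above from Python; every path below is a placeholder and a trained model plus data json are assumed to already exist:

from espnet.bin.mt_trans import main

main([
    "--backend", "pytorch",
    "--trans-json", "dump/test/data.json",        # placeholder path
    "--result-label", "exp/decode/result.json",   # placeholder path
    "--model", "exp/mt_model/model.acc.best",     # placeholder path
    "--beam-size", "5",
    "--verbose", "1",
])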
espnet/bin/st_train.py
ADDED
@@ -0,0 +1,550 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""End-to-end speech translation model training script."""
|
8 |
+
|
9 |
+
from distutils.version import LooseVersion
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import random
|
13 |
+
import subprocess
|
14 |
+
import sys
|
15 |
+
|
16 |
+
import configargparse
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from espnet import __version__
|
21 |
+
from espnet.utils.cli_utils import strtobool
|
22 |
+
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
|
23 |
+
|
24 |
+
is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2")
|
25 |
+
|
26 |
+
|
27 |
+
# NOTE: you need this func to generate our sphinx doc
|
28 |
+
def get_parser(parser=None, required=True):
|
29 |
+
"""Get default arguments."""
|
30 |
+
if parser is None:
|
31 |
+
parser = configargparse.ArgumentParser(
|
32 |
+
description="Train a speech translation (ST) model on one CPU, "
|
33 |
+
"one or multiple GPUs",
|
34 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
35 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
36 |
+
)
|
37 |
+
# general configuration
|
38 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
39 |
+
parser.add(
|
40 |
+
"--config2",
|
41 |
+
is_config_file=True,
|
42 |
+
help="second config file path that overwrites the settings in `--config`.",
|
43 |
+
)
|
44 |
+
parser.add(
|
45 |
+
"--config3",
|
46 |
+
is_config_file=True,
|
47 |
+
help="third config file path that overwrites the settings "
|
48 |
+
"in `--config` and `--config2`.",
|
49 |
+
)
|
50 |
+
|
51 |
+
parser.add_argument(
|
52 |
+
"--ngpu",
|
53 |
+
default=None,
|
54 |
+
type=int,
|
55 |
+
help="Number of GPUs. If not given, use all visible devices",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"--train-dtype",
|
59 |
+
default="float32",
|
60 |
+
choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
|
61 |
+
help="Data type for training (only pytorch backend). "
|
62 |
+
"O0,O1,.. flags require apex. "
|
63 |
+
"See https://nvidia.github.io/apex/amp.html#opt-levels",
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--backend",
|
67 |
+
default="chainer",
|
68 |
+
type=str,
|
69 |
+
choices=["chainer", "pytorch"],
|
70 |
+
help="Backend library",
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--outdir", type=str, required=required, help="Output directory"
|
74 |
+
)
|
75 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
76 |
+
parser.add_argument("--dict", required=required, help="Dictionary")
|
77 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
78 |
+
parser.add_argument("--debugdir", type=str, help="Output directory for debugging")
|
79 |
+
parser.add_argument(
|
80 |
+
"--resume",
|
81 |
+
"-r",
|
82 |
+
default="",
|
83 |
+
nargs="?",
|
84 |
+
help="Resume the training from snapshot",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--minibatches",
|
88 |
+
"-N",
|
89 |
+
type=int,
|
90 |
+
default="-1",
|
91 |
+
help="Process only N minibatches (for debug)",
|
92 |
+
)
|
93 |
+
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
94 |
+
parser.add_argument(
|
95 |
+
"--tensorboard-dir",
|
96 |
+
default=None,
|
97 |
+
type=str,
|
98 |
+
nargs="?",
|
99 |
+
help="Tensorboard log dir path",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--report-interval-iters",
|
103 |
+
default=100,
|
104 |
+
type=int,
|
105 |
+
help="Report interval iterations",
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--save-interval-iters",
|
109 |
+
default=0,
|
110 |
+
type=int,
|
111 |
+
help="Save snapshot interval iterations",
|
112 |
+
)
|
113 |
+
# task related
|
114 |
+
parser.add_argument(
|
115 |
+
"--train-json",
|
116 |
+
type=str,
|
117 |
+
default=None,
|
118 |
+
help="Filename of train label data (json)",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--valid-json",
|
122 |
+
type=str,
|
123 |
+
default=None,
|
124 |
+
help="Filename of validation label data (json)",
|
125 |
+
)
|
126 |
+
# network architecture
|
127 |
+
parser.add_argument(
|
128 |
+
"--model-module",
|
129 |
+
type=str,
|
130 |
+
default=None,
|
131 |
+
help="model defined module (default: espnet.nets.xxx_backend.e2e_st:E2E)",
|
132 |
+
)
|
133 |
+
# loss related
|
134 |
+
parser.add_argument(
|
135 |
+
"--ctc_type",
|
136 |
+
default="warpctc",
|
137 |
+
type=str,
|
138 |
+
choices=["builtin", "warpctc", "gtnctc", "cudnnctc"],
|
139 |
+
help="Type of CTC implementation to calculate loss.",
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--mtlalpha",
|
143 |
+
default=0.0,
|
144 |
+
type=float,
|
145 |
+
help="Multitask learning coefficient, alpha: \
|
146 |
+
alpha*ctc_loss + (1-alpha)*att_loss",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--asr-weight",
|
150 |
+
default=0.0,
|
151 |
+
type=float,
|
152 |
+
help="Multitask learning coefficient for ASR task, weight: "
|
153 |
+
" asr_weight*(alpha*ctc_loss + (1-alpha)*att_loss)"
|
154 |
+
" + (1-asr_weight-mt_weight)*st_loss",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--mt-weight",
|
158 |
+
default=0.0,
|
159 |
+
type=float,
|
160 |
+
help="Multitask learning coefficient for MT task, weight: \
|
161 |
+
mt_weight*mt_loss + (1-mt_weight-asr_weight)*st_loss",
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
"--lsm-weight", default=0.0, type=float, help="Label smoothing weight"
|
165 |
+
)
|
166 |
+
# recognition options to compute CER/WER
|
167 |
+
parser.add_argument(
|
168 |
+
"--report-cer",
|
169 |
+
default=False,
|
170 |
+
action="store_true",
|
171 |
+
help="Compute CER on development set",
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--report-wer",
|
175 |
+
default=False,
|
176 |
+
action="store_true",
|
177 |
+
help="Compute WER on development set",
|
178 |
+
)
|
179 |
+
# translations options to compute BLEU
|
180 |
+
parser.add_argument(
|
181 |
+
"--report-bleu",
|
182 |
+
default=True,
|
183 |
+
action="store_true",
|
184 |
+
help="Compute BLEU on development set",
|
185 |
+
)
|
186 |
+
parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
187 |
+
parser.add_argument("--beam-size", type=int, default=4, help="Beam size")
|
188 |
+
parser.add_argument("--penalty", default=0.0, type=float, help="Incertion penalty")
|
189 |
+
parser.add_argument(
|
190 |
+
"--maxlenratio",
|
191 |
+
default=0.0,
|
192 |
+
type=float,
|
193 |
+
help="""Input length ratio to obtain max output length.
|
194 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
195 |
+
to automatically find maximum hypothesis lengths""",
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--minlenratio",
|
199 |
+
default=0.0,
|
200 |
+
type=float,
|
201 |
+
help="Input length ratio to obtain min output length",
|
202 |
+
)
|
203 |
+
parser.add_argument(
|
204 |
+
"--rnnlm", type=str, default=None, help="RNNLM model file to read"
|
205 |
+
)
|
206 |
+
parser.add_argument(
|
207 |
+
"--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
|
208 |
+
)
|
209 |
+
parser.add_argument("--lm-weight", default=0.0, type=float, help="RNNLM weight.")
|
210 |
+
parser.add_argument("--sym-space", default="<space>", type=str, help="Space symbol")
|
211 |
+
parser.add_argument("--sym-blank", default="<blank>", type=str, help="Blank symbol")
|
212 |
+
# minibatch related
|
213 |
+
parser.add_argument(
|
214 |
+
"--sortagrad",
|
215 |
+
default=0,
|
216 |
+
type=int,
|
217 |
+
nargs="?",
|
218 |
+
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
|
219 |
+
)
|
220 |
+
parser.add_argument(
|
221 |
+
"--batch-count",
|
222 |
+
default="auto",
|
223 |
+
choices=BATCH_COUNT_CHOICES,
|
224 |
+
help="How to count batch_size. "
|
225 |
+
"The default (auto) will find how to count by args.",
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
"--batch-size",
|
229 |
+
"--batch-seqs",
|
230 |
+
"-b",
|
231 |
+
default=0,
|
232 |
+
type=int,
|
233 |
+
help="Maximum seqs in a minibatch (0 to disable)",
|
234 |
+
)
|
235 |
+
parser.add_argument(
|
236 |
+
"--batch-bins",
|
237 |
+
default=0,
|
238 |
+
type=int,
|
239 |
+
help="Maximum bins in a minibatch (0 to disable)",
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--batch-frames-in",
|
243 |
+
default=0,
|
244 |
+
type=int,
|
245 |
+
help="Maximum input frames in a minibatch (0 to disable)",
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--batch-frames-out",
|
249 |
+
default=0,
|
250 |
+
type=int,
|
251 |
+
help="Maximum output frames in a minibatch (0 to disable)",
|
252 |
+
)
|
253 |
+
parser.add_argument(
|
254 |
+
"--batch-frames-inout",
|
255 |
+
default=0,
|
256 |
+
type=int,
|
257 |
+
help="Maximum input+output frames in a minibatch (0 to disable)",
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
"--maxlen-in",
|
261 |
+
"--batch-seq-maxlen-in",
|
262 |
+
default=800,
|
263 |
+
type=int,
|
264 |
+
metavar="ML",
|
265 |
+
help="When --batch-count=seq, batch size is reduced "
|
266 |
+
"if the input sequence length > ML.",
|
267 |
+
)
|
268 |
+
parser.add_argument(
|
269 |
+
"--maxlen-out",
|
270 |
+
"--batch-seq-maxlen-out",
|
271 |
+
default=150,
|
272 |
+
type=int,
|
273 |
+
metavar="ML",
|
274 |
+
help="When --batch-count=seq, "
|
275 |
+
"batch size is reduced if the output sequence length > ML",
|
276 |
+
)
|
277 |
+
parser.add_argument(
|
278 |
+
"--n-iter-processes",
|
279 |
+
default=0,
|
280 |
+
type=int,
|
281 |
+
help="Number of processes of iterator",
|
282 |
+
)
|
283 |
+
parser.add_argument(
|
284 |
+
"--preprocess-conf",
|
285 |
+
type=str,
|
286 |
+
default=None,
|
287 |
+
nargs="?",
|
288 |
+
help="The configuration file for the pre-processing",
|
289 |
+
)
|
290 |
+
# optimization related
|
291 |
+
parser.add_argument(
|
292 |
+
"--opt",
|
293 |
+
default="adadelta",
|
294 |
+
type=str,
|
295 |
+
choices=["adadelta", "adam", "noam"],
|
296 |
+
help="Optimizer",
|
297 |
+
)
|
298 |
+
parser.add_argument(
|
299 |
+
"--accum-grad", default=1, type=int, help="Number of gradient accumuration"
|
300 |
+
)
|
301 |
+
parser.add_argument(
|
302 |
+
"--eps", default=1e-8, type=float, help="Epsilon constant for optimizer"
|
303 |
+
)
|
304 |
+
parser.add_argument(
|
305 |
+
"--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon"
|
306 |
+
)
|
307 |
+
parser.add_argument(
|
308 |
+
"--lr", default=1e-3, type=float, help="Learning rate for optimizer"
|
309 |
+
)
|
310 |
+
parser.add_argument(
|
311 |
+
"--lr-decay", default=1.0, type=float, help="Decaying ratio of learning rate"
|
312 |
+
)
|
313 |
+
parser.add_argument(
|
314 |
+
"--weight-decay", default=0.0, type=float, help="Weight decay ratio"
|
315 |
+
)
|
316 |
+
parser.add_argument(
|
317 |
+
"--criterion",
|
318 |
+
default="acc",
|
319 |
+
type=str,
|
320 |
+
choices=["loss", "acc"],
|
321 |
+
help="Criterion to perform epsilon decay",
|
322 |
+
)
|
323 |
+
parser.add_argument(
|
324 |
+
"--threshold", default=1e-4, type=float, help="Threshold to stop iteration"
|
325 |
+
)
|
326 |
+
parser.add_argument(
|
327 |
+
"--epochs", "-e", default=30, type=int, help="Maximum number of epochs"
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--early-stop-criterion",
|
331 |
+
default="validation/main/acc",
|
332 |
+
type=str,
|
333 |
+
nargs="?",
|
334 |
+
help="Value to monitor to trigger an early stopping of the training",
|
335 |
+
)
|
336 |
+
parser.add_argument(
|
337 |
+
"--patience",
|
338 |
+
default=3,
|
339 |
+
type=int,
|
340 |
+
nargs="?",
|
341 |
+
help="Number of epochs to wait "
|
342 |
+
"without improvement before stopping the training",
|
343 |
+
)
|
344 |
+
parser.add_argument(
|
345 |
+
"--grad-clip", default=5, type=float, help="Gradient norm threshold to clip"
|
346 |
+
)
|
347 |
+
parser.add_argument(
|
348 |
+
"--num-save-attention",
|
349 |
+
default=3,
|
350 |
+
type=int,
|
351 |
+
help="Number of samples of attention to be saved",
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--num-save-ctc",
|
355 |
+
default=3,
|
356 |
+
type=int,
|
357 |
+
help="Number of samples of CTC probability to be saved",
|
358 |
+
)
|
359 |
+
parser.add_argument(
|
360 |
+
"--grad-noise",
|
361 |
+
type=strtobool,
|
362 |
+
default=False,
|
363 |
+
help="The flag to switch to use noise injection to gradients during training",
|
364 |
+
)
|
365 |
+
# speech translation related
|
366 |
+
parser.add_argument(
|
367 |
+
"--context-residual",
|
368 |
+
default=False,
|
369 |
+
type=strtobool,
|
370 |
+
nargs="?",
|
371 |
+
help="The flag to switch to use context vector residual in the decoder network",
|
372 |
+
)
|
373 |
+
# finetuning related
|
374 |
+
parser.add_argument(
|
375 |
+
"--enc-init",
|
376 |
+
default=None,
|
377 |
+
type=str,
|
378 |
+
nargs="?",
|
379 |
+
help="Pre-trained ASR model to initialize encoder.",
|
380 |
+
)
|
381 |
+
parser.add_argument(
|
382 |
+
"--enc-init-mods",
|
383 |
+
default="enc.enc.",
|
384 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
385 |
+
help="List of encoder modules to initialize, separated by a comma.",
|
386 |
+
)
|
387 |
+
parser.add_argument(
|
388 |
+
"--dec-init",
|
389 |
+
default=None,
|
390 |
+
type=str,
|
391 |
+
nargs="?",
|
392 |
+
help="Pre-trained ASR, MT or LM model to initialize decoder.",
|
393 |
+
)
|
394 |
+
parser.add_argument(
|
395 |
+
"--dec-init-mods",
|
396 |
+
default="att., dec.",
|
397 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
398 |
+
help="List of decoder modules to initialize, separated by a comma.",
|
399 |
+
)
|
400 |
+
# multilingual related
|
401 |
+
parser.add_argument(
|
402 |
+
"--multilingual",
|
403 |
+
default=False,
|
404 |
+
type=strtobool,
|
405 |
+
help="Prepend target language ID to the source sentence. "
|
406 |
+
" Both source/target language IDs must be prepend in the pre-processing stage.",
|
407 |
+
)
|
408 |
+
parser.add_argument(
|
409 |
+
"--replace-sos",
|
410 |
+
default=False,
|
411 |
+
type=strtobool,
|
412 |
+
help="Replace <sos> in the decoder with a target language ID \
|
413 |
+
(the first token in the target sequence)",
|
414 |
+
)
|
415 |
+
# Feature transform: Normalization
|
416 |
+
parser.add_argument(
|
417 |
+
"--stats-file",
|
418 |
+
type=str,
|
419 |
+
default=None,
|
420 |
+
help="The stats file for the feature normalization",
|
421 |
+
)
|
422 |
+
parser.add_argument(
|
423 |
+
"--apply-uttmvn",
|
424 |
+
type=strtobool,
|
425 |
+
default=True,
|
426 |
+
help="Apply utterance level mean " "variance normalization.",
|
427 |
+
)
|
428 |
+
parser.add_argument("--uttmvn-norm-means", type=strtobool, default=True, help="")
|
429 |
+
parser.add_argument("--uttmvn-norm-vars", type=strtobool, default=False, help="")
|
430 |
+
# Feature transform: Fbank
|
431 |
+
parser.add_argument(
|
432 |
+
"--fbank-fs",
|
433 |
+
type=int,
|
434 |
+
default=16000,
|
435 |
+
help="The sample frequency used for " "the mel-fbank creation.",
|
436 |
+
)
|
437 |
+
parser.add_argument(
|
438 |
+
"--n-mels", type=int, default=80, help="The number of mel-frequency bins."
|
439 |
+
)
|
440 |
+
parser.add_argument("--fbank-fmin", type=float, default=0.0, help="")
|
441 |
+
parser.add_argument("--fbank-fmax", type=float, default=None, help="")
|
442 |
+
return parser
|
443 |
+
|
444 |
+
|
445 |
+
def main(cmd_args):
|
446 |
+
"""Run the main training function."""
|
447 |
+
parser = get_parser()
|
448 |
+
args, _ = parser.parse_known_args(cmd_args)
|
449 |
+
if args.backend == "chainer" and args.train_dtype != "float32":
|
450 |
+
raise NotImplementedError(
|
451 |
+
f"chainer backend does not support --train-dtype {args.train_dtype}."
|
452 |
+
"Use --dtype float32."
|
453 |
+
)
|
454 |
+
if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
|
455 |
+
raise ValueError(
|
456 |
+
f"--train-dtype {args.train_dtype} does not support the CPU backend."
|
457 |
+
)
|
458 |
+
|
459 |
+
from espnet.utils.dynamic_import import dynamic_import
|
460 |
+
|
461 |
+
if args.model_module is None:
|
462 |
+
model_module = "espnet.nets." + args.backend + "_backend.e2e_st:E2E"
|
463 |
+
else:
|
464 |
+
model_module = args.model_module
|
465 |
+
model_class = dynamic_import(model_module)
|
466 |
+
model_class.add_arguments(parser)
|
467 |
+
|
468 |
+
args = parser.parse_args(cmd_args)
|
469 |
+
args.model_module = model_module
|
470 |
+
if "chainer_backend" in args.model_module:
|
471 |
+
args.backend = "chainer"
|
472 |
+
if "pytorch_backend" in args.model_module:
|
473 |
+
args.backend = "pytorch"
|
474 |
+
|
475 |
+
# add version info in args
|
476 |
+
args.version = __version__
|
477 |
+
|
478 |
+
# logging info
|
479 |
+
if args.verbose > 0:
|
480 |
+
logging.basicConfig(
|
481 |
+
level=logging.INFO,
|
482 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
483 |
+
)
|
484 |
+
else:
|
485 |
+
logging.basicConfig(
|
486 |
+
level=logging.WARN,
|
487 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
488 |
+
)
|
489 |
+
logging.warning("Skip DEBUG/INFO messages")
|
490 |
+
|
491 |
+
# If --ngpu is not given,
|
492 |
+
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
|
493 |
+
# 2. if nvidia-smi exists, use all devices
|
494 |
+
# 3. else ngpu=0
|
495 |
+
if args.ngpu is None:
|
496 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
497 |
+
if cvd is not None:
|
498 |
+
ngpu = len(cvd.split(","))
|
499 |
+
else:
|
500 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
501 |
+
try:
|
502 |
+
p = subprocess.run(
|
503 |
+
["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
504 |
+
)
|
505 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
506 |
+
ngpu = 0
|
507 |
+
else:
|
508 |
+
ngpu = len(p.stderr.decode().split("\n")) - 1
|
509 |
+
args.ngpu = ngpu
|
510 |
+
else:
|
511 |
+
if is_torch_1_2_plus and args.ngpu != 1:
|
512 |
+
logging.debug(
|
513 |
+
"There are some bugs with multi-GPU processing in PyTorch 1.2+"
|
514 |
+
+ " (see https://github.com/pytorch/pytorch/issues/21108)"
|
515 |
+
)
|
516 |
+
ngpu = args.ngpu
|
517 |
+
logging.info(f"ngpu: {ngpu}")
|
518 |
+
|
519 |
+
# display PYTHONPATH
|
520 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
521 |
+
|
522 |
+
# set random seed
|
523 |
+
logging.info("random seed = %d" % args.seed)
|
524 |
+
random.seed(args.seed)
|
525 |
+
np.random.seed(args.seed)
|
526 |
+
|
527 |
+
# load dictionary for debug log
|
528 |
+
if args.dict is not None:
|
529 |
+
with open(args.dict, "rb") as f:
|
530 |
+
dictionary = f.readlines()
|
531 |
+
char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
|
532 |
+
char_list.insert(0, "<blank>")
|
533 |
+
char_list.append("<eos>")
|
534 |
+
args.char_list = char_list
|
535 |
+
else:
|
536 |
+
args.char_list = None
|
537 |
+
|
538 |
+
# train
|
539 |
+
logging.info("backend = " + args.backend)
|
540 |
+
|
541 |
+
if args.backend == "pytorch":
|
542 |
+
from espnet.st.pytorch_backend.st import train
|
543 |
+
|
544 |
+
train(args)
|
545 |
+
else:
|
546 |
+
raise ValueError("Only pytorch are supported.")
|
547 |
+
|
548 |
+
|
549 |
+
if __name__ == "__main__":
|
550 |
+
main(sys.argv[1:])
|
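The --mtlalpha, --asr-weight, and --mt-weight options above control how the auxiliary ASR and MT objectives are mixed with the main ST objective. A small sketch of that combination, written directly from the formulas given in the help strings (not code taken from the trainer itself):

def combined_st_loss(ctc_loss, att_loss, mt_loss, st_loss,
                     mtlalpha=0.0, asr_weight=0.0, mt_weight=0.0):
    # asr_weight * (alpha*ctc + (1-alpha)*att) + mt_weight*mt + remainder*st
    asr_loss = mtlalpha * ctc_loss + (1.0 - mtlalpha) * att_loss
    return (
        asr_weight * asr_loss
        + mt_weight * mt_loss
        + (1.0 - asr_weight - mt_weight) * st_loss
    )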
espnet/bin/st_trans.py
ADDED
@@ -0,0 +1,183 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""End-to-end speech translation model decoding script."""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
import sys
|
13 |
+
|
14 |
+
import configargparse
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
# NOTE: you need this func to generate our sphinx doc
|
19 |
+
def get_parser():
|
20 |
+
"""Get default arguments."""
|
21 |
+
parser = configargparse.ArgumentParser(
|
22 |
+
description="Translate text from speech using a speech translation "
|
23 |
+
"model on one CPU or GPU",
|
24 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
25 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
26 |
+
)
|
27 |
+
# general configuration
|
28 |
+
parser.add("--config", is_config_file=True, help="Config file path")
|
29 |
+
parser.add(
|
30 |
+
"--config2",
|
31 |
+
is_config_file=True,
|
32 |
+
help="Second config file path that overwrites the settings in `--config`",
|
33 |
+
)
|
34 |
+
parser.add(
|
35 |
+
"--config3",
|
36 |
+
is_config_file=True,
|
37 |
+
help="Third config file path that overwrites "
|
38 |
+
"the settings in `--config` and `--config2`",
|
39 |
+
)
|
40 |
+
|
41 |
+
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
|
42 |
+
parser.add_argument(
|
43 |
+
"--dtype",
|
44 |
+
choices=("float16", "float32", "float64"),
|
45 |
+
default="float32",
|
46 |
+
help="Float precision (only available in --api v2)",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--backend",
|
50 |
+
type=str,
|
51 |
+
default="chainer",
|
52 |
+
choices=["chainer", "pytorch"],
|
53 |
+
help="Backend library",
|
54 |
+
)
|
55 |
+
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
|
56 |
+
parser.add_argument("--seed", type=int, default=1, help="Random seed")
|
57 |
+
parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
|
58 |
+
parser.add_argument(
|
59 |
+
"--batchsize",
|
60 |
+
type=int,
|
61 |
+
default=1,
|
62 |
+
help="Batch size for beam search (0: means no batch processing)",
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"--preprocess-conf",
|
66 |
+
type=str,
|
67 |
+
default=None,
|
68 |
+
help="The configuration file for the pre-processing",
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--api",
|
72 |
+
default="v1",
|
73 |
+
choices=["v1", "v2"],
|
74 |
+
help="Beam search APIs "
|
75 |
+
"v1: Default API. "
|
76 |
+
"It only supports the ASRInterface.recognize method and DefaultRNNLM. "
|
77 |
+
"v2: Experimental API. "
|
78 |
+
"It supports any models that implements ScorerInterface.",
|
79 |
+
)
|
80 |
+
# task related
|
81 |
+
parser.add_argument(
|
82 |
+
"--trans-json", type=str, help="Filename of translation data (json)"
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--result-label",
|
86 |
+
type=str,
|
87 |
+
required=True,
|
88 |
+
help="Filename of result label data (json)",
|
89 |
+
)
|
90 |
+
# model (parameter) related
|
91 |
+
parser.add_argument(
|
92 |
+
"--model", type=str, required=True, help="Model file parameters to read"
|
93 |
+
)
|
94 |
+
# search related
|
95 |
+
parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
96 |
+
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
|
97 |
+
parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty")
|
98 |
+
parser.add_argument(
|
99 |
+
"--maxlenratio",
|
100 |
+
type=float,
|
101 |
+
default=0.0,
|
102 |
+
help="""Input length ratio to obtain max output length.
|
103 |
+
If maxlenratio=0.0 (default), it uses a end-detect function
|
104 |
+
to automatically find maximum hypothesis lengths""",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--minlenratio",
|
108 |
+
type=float,
|
109 |
+
default=0.0,
|
110 |
+
help="Input length ratio to obtain min output length",
|
111 |
+
)
|
112 |
+
# multilingual related
|
113 |
+
parser.add_argument(
|
114 |
+
"--tgt-lang",
|
115 |
+
default=False,
|
116 |
+
type=str,
|
117 |
+
help="target language ID (e.g., <en>, <de>, and <fr> etc.)",
|
118 |
+
)
|
119 |
+
return parser
|
120 |
+
|
121 |
+
|
122 |
+
def main(args):
|
123 |
+
"""Run the main decoding function."""
|
124 |
+
parser = get_parser()
|
125 |
+
args = parser.parse_args(args)
|
126 |
+
|
127 |
+
# logging info
|
128 |
+
if args.verbose == 1:
|
129 |
+
logging.basicConfig(
|
130 |
+
level=logging.INFO,
|
131 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
132 |
+
)
|
133 |
+
elif args.verbose == 2:
|
134 |
+
logging.basicConfig(
|
135 |
+
level=logging.DEBUG,
|
136 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
logging.basicConfig(
|
140 |
+
level=logging.WARN,
|
141 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
142 |
+
)
|
143 |
+
logging.warning("Skip DEBUG/INFO messages")
|
144 |
+
|
145 |
+
# check CUDA_VISIBLE_DEVICES
|
146 |
+
if args.ngpu > 0:
|
147 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
148 |
+
if cvd is None:
|
149 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
150 |
+
elif args.ngpu != len(cvd.split(",")):
|
151 |
+
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
|
152 |
+
sys.exit(1)
|
153 |
+
|
154 |
+
# TODO(mn5k): support of multiple GPUs
|
155 |
+
if args.ngpu > 1:
|
156 |
+
logging.error("The program only supports ngpu=1.")
|
157 |
+
sys.exit(1)
|
158 |
+
|
159 |
+
# display PYTHONPATH
|
160 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
161 |
+
|
162 |
+
# seed setting
|
163 |
+
random.seed(args.seed)
|
164 |
+
np.random.seed(args.seed)
|
165 |
+
logging.info("set random seed = %d" % args.seed)
|
166 |
+
|
167 |
+
# trans
|
168 |
+
logging.info("backend = " + args.backend)
|
169 |
+
if args.backend == "pytorch":
|
170 |
+
# Experimental API that supports custom LMs
|
171 |
+
from espnet.st.pytorch_backend.st import trans
|
172 |
+
|
173 |
+
if args.dtype != "float32":
|
174 |
+
raise NotImplementedError(
|
175 |
+
f"`--dtype {args.dtype}` is only available with `--api v2`"
|
176 |
+
)
|
177 |
+
trans(args)
|
178 |
+
else:
|
179 |
+
raise ValueError("Only pytorch are supported.")
|
180 |
+
|
181 |
+
|
182 |
+
if __name__ == "__main__":
|
183 |
+
main(sys.argv[1:])
|
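Because the parser above is built with configargparse and a YAML config-file parser, the decoding flags can also be supplied through --config; keys are assumed to be the long option names without the leading dashes. A sketch with illustrative values only:

# decode.yaml (illustrative contents)
#   backend: pytorch
#   beam-size: 10
#   maxlenratio: 0.0
#   minlenratio: 0.0
from espnet.bin.st_trans import main

main([
    "--config", "decode.yaml",                    # the YAML sketched above
    "--trans-json", "dump/test/data.json",        # placeholder path
    "--result-label", "exp/decode/result.json",   # placeholder path
    "--model", "exp/st_model/model.acc.best",     # placeholder path
])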
espnet/bin/tts_decode.py
ADDED
@@ -0,0 +1,180 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2018 Nagoya University (Tomoki Hayashi)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
"""TTS decoding script."""
|
7 |
+
|
8 |
+
import configargparse
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import platform
|
12 |
+
import subprocess
|
13 |
+
import sys
|
14 |
+
|
15 |
+
from espnet.utils.cli_utils import strtobool
|
16 |
+
|
17 |
+
|
18 |
+
# NOTE: you need this func to generate our sphinx doc
|
19 |
+
def get_parser():
|
20 |
+
"""Get parser of decoding arguments."""
|
21 |
+
parser = configargparse.ArgumentParser(
|
22 |
+
description="Synthesize speech from text using a TTS model on one CPU",
|
23 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
24 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
25 |
+
)
|
26 |
+
# general configuration
|
27 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
28 |
+
parser.add(
|
29 |
+
"--config2",
|
30 |
+
is_config_file=True,
|
31 |
+
help="second config file path that overwrites the settings in `--config`.",
|
32 |
+
)
|
33 |
+
parser.add(
|
34 |
+
"--config3",
|
35 |
+
is_config_file=True,
|
36 |
+
help="third config file path that overwrites "
|
37 |
+
"the settings in `--config` and `--config2`.",
|
38 |
+
)
|
39 |
+
|
40 |
+
parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
|
41 |
+
parser.add_argument(
|
42 |
+
"--backend",
|
43 |
+
default="pytorch",
|
44 |
+
type=str,
|
45 |
+
choices=["chainer", "pytorch"],
|
46 |
+
help="Backend library",
|
47 |
+
)
|
48 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
49 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
50 |
+
parser.add_argument("--out", type=str, required=True, help="Output filename")
|
51 |
+
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
52 |
+
parser.add_argument(
|
53 |
+
"--preprocess-conf",
|
54 |
+
type=str,
|
55 |
+
default=None,
|
56 |
+
help="The configuration file for the pre-processing",
|
57 |
+
)
|
58 |
+
# task related
|
59 |
+
parser.add_argument(
|
60 |
+
"--json", type=str, required=True, help="Filename of train label data (json)"
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--model", type=str, required=True, help="Model file parameters to read"
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--model-conf", type=str, default=None, help="Model config file"
|
67 |
+
)
|
68 |
+
# decoding related
|
69 |
+
parser.add_argument(
|
70 |
+
"--maxlenratio", type=float, default=5, help="Maximum length ratio in decoding"
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--minlenratio", type=float, default=0, help="Minimum length ratio in decoding"
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--threshold", type=float, default=0.5, help="Threshold value in decoding"
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--use-att-constraint",
|
80 |
+
type=strtobool,
|
81 |
+
default=False,
|
82 |
+
help="Whether to use the attention constraint",
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--backward-window",
|
86 |
+
type=int,
|
87 |
+
default=1,
|
88 |
+
help="Backward window size in the attention constraint",
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--forward-window",
|
92 |
+
type=int,
|
93 |
+
default=3,
|
94 |
+
help="Forward window size in the attention constraint",
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--fastspeech-alpha",
|
98 |
+
type=float,
|
99 |
+
default=1.0,
|
100 |
+
help="Alpha to change the speed for FastSpeech",
|
101 |
+
)
|
102 |
+
# save related
|
103 |
+
parser.add_argument(
|
104 |
+
"--save-durations",
|
105 |
+
default=False,
|
106 |
+
type=strtobool,
|
107 |
+
help="Whether to save durations converted from attentions",
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--save-focus-rates",
|
111 |
+
default=False,
|
112 |
+
type=strtobool,
|
113 |
+
help="Whether to save focus rates of attentions",
|
114 |
+
)
|
115 |
+
return parser
|
116 |
+
|
117 |
+
|
118 |
+
def main(args):
|
119 |
+
"""Run deocding."""
|
120 |
+
parser = get_parser()
|
121 |
+
args = parser.parse_args(args)
|
122 |
+
|
123 |
+
# logging info
|
124 |
+
if args.verbose > 0:
|
125 |
+
logging.basicConfig(
|
126 |
+
level=logging.INFO,
|
127 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
logging.basicConfig(
|
131 |
+
level=logging.WARN,
|
132 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
133 |
+
)
|
134 |
+
logging.warning("Skip DEBUG/INFO messages")
|
135 |
+
|
136 |
+
# check CUDA_VISIBLE_DEVICES
|
137 |
+
if args.ngpu > 0:
|
138 |
+
# python 2 case
|
139 |
+
if platform.python_version_tuple()[0] == "2":
|
140 |
+
if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]):
|
141 |
+
cvd = subprocess.check_output(
|
142 |
+
["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
|
143 |
+
).strip()
|
144 |
+
logging.info("CLSP: use gpu" + cvd)
|
145 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = cvd
|
146 |
+
# python 3 case
|
147 |
+
else:
|
148 |
+
if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]).decode():
|
149 |
+
cvd = (
|
150 |
+
subprocess.check_output(
|
151 |
+
["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
|
152 |
+
)
|
153 |
+
.decode()
|
154 |
+
.strip()
|
155 |
+
)
|
156 |
+
logging.info("CLSP: use gpu" + cvd)
|
157 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = cvd
|
158 |
+
|
159 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
160 |
+
if cvd is None:
|
161 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
162 |
+
elif args.ngpu != len(cvd.split(",")):
|
163 |
+
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
|
164 |
+
sys.exit(1)
|
165 |
+
|
166 |
+
# display PYTHONPATH
|
167 |
+
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
168 |
+
|
169 |
+
# extract
|
170 |
+
logging.info("backend = " + args.backend)
|
171 |
+
if args.backend == "pytorch":
|
172 |
+
from espnet.tts.pytorch_backend.tts import decode
|
173 |
+
|
174 |
+
decode(args)
|
175 |
+
else:
|
176 |
+
raise NotImplementedError("Only pytorch is supported.")
|
177 |
+
|
178 |
+
|
179 |
+
if __name__ == "__main__":
|
180 |
+
main(sys.argv[1:])
|
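Similarly, a hypothetical call into the TTS decoding entry point above; paths are placeholders and a trained model compatible with the pytorch backend is assumed:

from espnet.bin.tts_decode import main

main([
    "--backend", "pytorch",
    "--json", "dump/test/data.json",              # placeholder path
    "--model", "exp/tts_model/model.loss.best",   # placeholder path
    "--out", "exp/decode/feats",                  # placeholder output name
    "--threshold", "0.5",
    "--maxlenratio", "5",
])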
espnet/bin/tts_train.py
ADDED
@@ -0,0 +1,359 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2018 Nagoya University (Tomoki Hayashi)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
"""Text-to-speech model training script."""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import subprocess
|
12 |
+
import sys
|
13 |
+
|
14 |
+
import configargparse
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
from espnet import __version__
|
18 |
+
from espnet.nets.tts_interface import TTSInterface
|
19 |
+
from espnet.utils.cli_utils import strtobool
|
20 |
+
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
|
21 |
+
|
22 |
+
|
23 |
+
# NOTE: you need this func to generate our sphinx doc
|
24 |
+
def get_parser():
|
25 |
+
"""Get parser of training arguments."""
|
26 |
+
parser = configargparse.ArgumentParser(
|
27 |
+
description="Train a new text-to-speech (TTS) model on one CPU, "
|
28 |
+
"one or multiple GPUs",
|
29 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
30 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
31 |
+
)
|
32 |
+
|
33 |
+
# general configuration
|
34 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
35 |
+
parser.add(
|
36 |
+
"--config2",
|
37 |
+
is_config_file=True,
|
38 |
+
help="second config file path that overwrites the settings in `--config`.",
|
39 |
+
)
|
40 |
+
parser.add(
|
41 |
+
"--config3",
|
42 |
+
is_config_file=True,
|
43 |
+
help="third config file path that overwrites "
|
44 |
+
"the settings in `--config` and `--config2`.",
|
45 |
+
)
|
46 |
+
|
47 |
+
parser.add_argument(
|
48 |
+
"--ngpu",
|
49 |
+
default=None,
|
50 |
+
type=int,
|
51 |
+
help="Number of GPUs. If not given, use all visible devices",
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--backend",
|
55 |
+
default="pytorch",
|
56 |
+
type=str,
|
57 |
+
choices=["chainer", "pytorch"],
|
58 |
+
help="Backend library",
|
59 |
+
)
|
60 |
+
parser.add_argument("--outdir", type=str, required=True, help="Output directory")
|
61 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
62 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
63 |
+
parser.add_argument(
|
64 |
+
"--resume",
|
65 |
+
"-r",
|
66 |
+
default="",
|
67 |
+
type=str,
|
68 |
+
nargs="?",
|
69 |
+
help="Resume the training from snapshot",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--minibatches",
|
73 |
+
"-N",
|
74 |
+
type=int,
|
75 |
+
default="-1",
|
76 |
+
help="Process only N minibatches (for debug)",
|
77 |
+
)
|
78 |
+
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
79 |
+
parser.add_argument(
|
80 |
+
"--tensorboard-dir",
|
81 |
+
default=None,
|
82 |
+
type=str,
|
83 |
+
nargs="?",
|
84 |
+
help="Tensorboard log directory path",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--eval-interval-epochs", default=1, type=int, help="Evaluation interval epochs"
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"--save-interval-epochs", default=1, type=int, help="Save interval epochs"
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--report-interval-iters",
|
94 |
+
default=100,
|
95 |
+
type=int,
|
96 |
+
help="Report interval iterations",
|
97 |
+
)
|
98 |
+
# task related
|
99 |
+
parser.add_argument(
|
100 |
+
"--train-json", type=str, required=True, help="Filename of training json"
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--valid-json", type=str, required=True, help="Filename of validation json"
|
104 |
+
)
|
105 |
+
# network architecture
|
106 |
+
parser.add_argument(
|
107 |
+
"--model-module",
|
108 |
+
type=str,
|
109 |
+
default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2",
|
110 |
+
help="model defined module",
|
111 |
+
)
|
112 |
+
# minibatch related
|
113 |
+
parser.add_argument(
|
114 |
+
"--sortagrad",
|
115 |
+
default=0,
|
116 |
+
type=int,
|
117 |
+
nargs="?",
|
118 |
+
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--batch-sort-key",
|
122 |
+
default="shuffle",
|
123 |
+
type=str,
|
124 |
+
choices=["shuffle", "output", "input"],
|
125 |
+
nargs="?",
|
126 |
+
help='Batch sorting key. "shuffle" only work with --batch-count "seq".',
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--batch-count",
|
130 |
+
default="auto",
|
131 |
+
choices=BATCH_COUNT_CHOICES,
|
132 |
+
help="How to count batch_size. "
|
133 |
+
"The default (auto) will find how to count by args.",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--batch-size",
|
137 |
+
"--batch-seqs",
|
138 |
+
"-b",
|
139 |
+
default=0,
|
140 |
+
type=int,
|
141 |
+
help="Maximum seqs in a minibatch (0 to disable)",
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--batch-bins",
|
145 |
+
default=0,
|
146 |
+
type=int,
|
147 |
+
help="Maximum bins in a minibatch (0 to disable)",
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"--batch-frames-in",
|
151 |
+
default=0,
|
152 |
+
type=int,
|
153 |
+
help="Maximum input frames in a minibatch (0 to disable)",
|
154 |
+
)
|
155 |
+
parser.add_argument(
|
156 |
+
"--batch-frames-out",
|
157 |
+
default=0,
|
158 |
+
type=int,
|
159 |
+
help="Maximum output frames in a minibatch (0 to disable)",
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--batch-frames-inout",
|
163 |
+
default=0,
|
164 |
+
type=int,
|
165 |
+
help="Maximum input+output frames in a minibatch (0 to disable)",
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--maxlen-in",
|
169 |
+
"--batch-seq-maxlen-in",
|
170 |
+
default=100,
|
171 |
+
type=int,
|
172 |
+
metavar="ML",
|
173 |
+
help="When --batch-count=seq, "
|
174 |
+
"batch size is reduced if the input sequence length > ML.",
|
175 |
+
)
|
176 |
+
parser.add_argument(
|
177 |
+
"--maxlen-out",
|
178 |
+
"--batch-seq-maxlen-out",
|
179 |
+
default=200,
|
180 |
+
type=int,
|
181 |
+
metavar="ML",
|
182 |
+
help="When --batch-count=seq, "
|
183 |
+
"batch size is reduced if the output sequence length > ML",
|
184 |
+
)
|
185 |
+
parser.add_argument(
|
186 |
+
"--num-iter-processes",
|
187 |
+
default=0,
|
188 |
+
type=int,
|
189 |
+
help="Number of processes of iterator",
|
190 |
+
)
|
191 |
+
parser.add_argument(
|
192 |
+
"--preprocess-conf",
|
193 |
+
type=str,
|
194 |
+
default=None,
|
195 |
+
help="The configuration file for the pre-processing",
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--use-speaker-embedding",
|
199 |
+
default=False,
|
200 |
+
type=strtobool,
|
201 |
+
help="Whether to use speaker embedding",
|
202 |
+
)
|
203 |
+
parser.add_argument(
|
204 |
+
"--use-second-target",
|
205 |
+
default=False,
|
206 |
+
type=strtobool,
|
207 |
+
help="Whether to use second target",
|
208 |
+
)
|
209 |
+
# optimization related
|
210 |
+
parser.add_argument(
|
211 |
+
"--opt", default="adam", type=str, choices=["adam", "noam"], help="Optimizer"
|
212 |
+
)
|
213 |
+
parser.add_argument(
|
214 |
+
"--accum-grad", default=1, type=int, help="Number of gradient accumuration"
|
215 |
+
)
|
216 |
+
parser.add_argument(
|
217 |
+
"--lr", default=1e-3, type=float, help="Learning rate for optimizer"
|
218 |
+
)
|
219 |
+
parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer")
|
220 |
+
parser.add_argument(
|
221 |
+
"--weight-decay",
|
222 |
+
default=1e-6,
|
223 |
+
type=float,
|
224 |
+
help="Weight decay coefficient for optimizer",
|
225 |
+
)
|
226 |
+
parser.add_argument(
|
227 |
+
"--epochs", "-e", default=30, type=int, help="Number of maximum epochs"
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--early-stop-criterion",
|
231 |
+
default="validation/main/loss",
|
232 |
+
type=str,
|
233 |
+
nargs="?",
|
234 |
+
help="Value to monitor to trigger an early stopping of the training",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--patience",
|
238 |
+
default=3,
|
239 |
+
type=int,
|
240 |
+
nargs="?",
|
241 |
+
help="Number of epochs to wait "
|
242 |
+
"without improvement before stopping the training",
|
243 |
+
)
|
244 |
+
parser.add_argument(
|
245 |
+
"--grad-clip", default=1, type=float, help="Gradient norm threshold to clip"
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--num-save-attention",
|
249 |
+
default=5,
|
250 |
+
type=int,
|
251 |
+
help="Number of samples of attention to be saved",
|
252 |
+
)
|
253 |
+
parser.add_argument(
|
254 |
+
"--keep-all-data-on-mem",
|
255 |
+
default=False,
|
256 |
+
type=strtobool,
|
257 |
+
help="Whether to keep all data on memory",
|
258 |
+
)
|
259 |
+
# finetuning related
|
260 |
+
parser.add_argument(
|
261 |
+
"--enc-init",
|
262 |
+
default=None,
|
263 |
+
type=str,
|
264 |
+
help="Pre-trained TTS model path to initialize encoder.",
|
265 |
+
)
|
266 |
+
parser.add_argument(
|
267 |
+
"--enc-init-mods",
|
268 |
+
default="enc.",
|
269 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
270 |
+
help="List of encoder modules to initialize, separated by a comma.",
|
271 |
+
)
|
272 |
+
parser.add_argument(
|
273 |
+
"--dec-init",
|
274 |
+
default=None,
|
275 |
+
type=str,
|
276 |
+
help="Pre-trained TTS model path to initialize decoder.",
|
277 |
+
)
|
278 |
+
parser.add_argument(
|
279 |
+
"--dec-init-mods",
|
280 |
+
default="dec.",
|
281 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
282 |
+
help="List of decoder modules to initialize, separated by a comma.",
|
283 |
+
)
|
284 |
+
parser.add_argument(
|
285 |
+
"--freeze-mods",
|
286 |
+
default=None,
|
287 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
288 |
+
help="List of modules to freeze (not to train), separated by a comma.",
|
289 |
+
)
|
290 |
+
|
291 |
+
return parser
|
292 |
+
|
293 |
+
|
294 |
+
def main(cmd_args):
|
295 |
+
"""Run training."""
|
296 |
+
parser = get_parser()
|
297 |
+
args, _ = parser.parse_known_args(cmd_args)
|
298 |
+
|
299 |
+
from espnet.utils.dynamic_import import dynamic_import
|
300 |
+
|
301 |
+
model_class = dynamic_import(args.model_module)
|
302 |
+
assert issubclass(model_class, TTSInterface)
|
303 |
+
model_class.add_arguments(parser)
|
304 |
+
args = parser.parse_args(cmd_args)
|
305 |
+
|
306 |
+
# add version info in args
|
307 |
+
args.version = __version__
|
308 |
+
|
309 |
+
# logging info
|
310 |
+
if args.verbose > 0:
|
311 |
+
logging.basicConfig(
|
312 |
+
level=logging.INFO,
|
313 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
logging.basicConfig(
|
317 |
+
level=logging.WARN,
|
318 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
319 |
+
)
|
320 |
+
logging.warning("Skip DEBUG/INFO messages")
|
321 |
+
|
322 |
+
# If --ngpu is not given,
|
323 |
+
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
|
324 |
+
# 2. if nvidia-smi exists, use all devices
|
325 |
+
# 3. else ngpu=0
|
326 |
+
if args.ngpu is None:
|
327 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
328 |
+
if cvd is not None:
|
329 |
+
ngpu = len(cvd.split(","))
|
330 |
+
else:
|
331 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
332 |
+
try:
|
333 |
+
p = subprocess.run(
|
334 |
+
["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
335 |
+
)
|
336 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
337 |
+
ngpu = 0
|
338 |
+
else:
|
339 |
+
ngpu = len(p.stderr.decode().split("\n")) - 1
|
340 |
+
args.ngpu = ngpu
|
341 |
+
else:
|
342 |
+
ngpu = args.ngpu
|
343 |
+
logging.info(f"ngpu: {ngpu}")
|
344 |
+
|
345 |
+
# set random seed
|
346 |
+
logging.info("random seed = %d" % args.seed)
|
347 |
+
random.seed(args.seed)
|
348 |
+
np.random.seed(args.seed)
|
349 |
+
|
350 |
+
if args.backend == "pytorch":
|
351 |
+
from espnet.tts.pytorch_backend.tts import train
|
352 |
+
|
353 |
+
train(args)
|
354 |
+
else:
|
355 |
+
raise NotImplementedError("Only pytorch is supported.")
|
356 |
+
|
357 |
+
|
358 |
+
if __name__ == "__main__":
|
359 |
+
main(sys.argv[1:])
|
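The training entry point above is normally launched by the recipe shell scripts, but get_parser()/main() can also be driven directly from Python. A minimal sketch, assuming an installed espnet package; the output directory and json paths are placeholders, not files included in this commit:

    from espnet.bin.tts_train import main

    # Equivalent to invoking tts_train.py on the command line; training starts
    # immediately, so the placeholder jsons must point at real prepared data.
    main([
        "--outdir", "exp/tts_demo",          # placeholder output directory
        "--train-json", "dump/train.json",   # placeholder training json
        "--valid-json", "dump/valid.json",   # placeholder validation json
        "--ngpu", "0",
    ])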
espnet/bin/vc_decode.py
ADDED
@@ -0,0 +1,174 @@
#!/usr/bin/env python3

# Copyright 2020 Nagoya University (Wen-Chin Huang)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""VC decoding script."""

import configargparse
import logging
import os
import platform
import subprocess
import sys

from espnet.utils.cli_utils import strtobool


# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Get parser of decoding arguments."""
    parser = configargparse.ArgumentParser(
        description="Converting speech using a VC model on one CPU",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites the settings "
        "in `--config` and `--config2`.",
    )

    parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
    parser.add_argument(
        "--backend",
        default="pytorch",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument("--out", type=str, required=True, help="Output filename")
    parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    # task related
    parser.add_argument(
        "--json", type=str, required=True, help="Filename of train label data (json)"
    )
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )
    # decoding related
    parser.add_argument(
        "--maxlenratio", type=float, default=5, help="Maximum length ratio in decoding"
    )
    parser.add_argument(
        "--minlenratio", type=float, default=0, help="Minimum length ratio in decoding"
    )
    parser.add_argument(
        "--threshold", type=float, default=0.5, help="Threshold value in decoding"
    )
    parser.add_argument(
        "--use-att-constraint",
        type=strtobool,
        default=False,
        help="Whether to use the attention constraint",
    )
    parser.add_argument(
        "--backward-window",
        type=int,
        default=1,
        help="Backward window size in the attention constraint",
    )
    parser.add_argument(
        "--forward-window",
        type=int,
        default=3,
        help="Forward window size in the attention constraint",
    )
    # save related
    parser.add_argument(
        "--save-durations",
        default=False,
        type=strtobool,
        help="Whether to save durations converted from attentions",
    )
    parser.add_argument(
        "--save-focus-rates",
        default=False,
        type=strtobool,
        help="Whether to save focus rates of attentions",
    )
    return parser


def main(args):
    """Run deocding."""
    parser = get_parser()
    args = parser.parse_args(args)

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        # python 2 case
        if platform.python_version_tuple()[0] == "2":
            if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]):
                cvd = subprocess.check_output(
                    ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
                ).strip()
                logging.info("CLSP: use gpu" + cvd)
                os.environ["CUDA_VISIBLE_DEVICES"] = cvd
        # python 3 case
        else:
            if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]).decode():
                cvd = (
                    subprocess.check_output(
                        ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
                    )
                    .decode()
                    .strip()
                )
                logging.info("CLSP: use gpu" + cvd)
                os.environ["CUDA_VISIBLE_DEVICES"] = cvd

        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

    # display PYTHONPATH
    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    # extract
    logging.info("backend = " + args.backend)
    if args.backend == "pytorch":
        from espnet.vc.pytorch_backend.vc import decode

        decode(args)
    else:
        raise NotImplementedError("Only pytorch is supported.")


if __name__ == "__main__":
    main(sys.argv[1:])
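As with the training scripts, get_parser() here makes it possible to inspect the decoding defaults without launching a job. A small sketch, assuming espnet is importable; the required --json, --model, and --out values below are placeholders supplied only to build the namespace:

    from espnet.bin.vc_decode import get_parser

    # Parse a hypothetical command line and read back the decoding defaults.
    args = get_parser().parse_args(
        ["--json", "dump/test.json", "--model", "exp/vc/model.best", "--out", "exp/vc/feats"]
    )
    print(args.maxlenratio, args.threshold)  # 5.0 0.5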
espnet/bin/vc_train.py
ADDED
@@ -0,0 +1,368 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2020 Nagoya University (Wen-Chin Huang)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
"""Voice conversion model training script."""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import subprocess
|
12 |
+
import sys
|
13 |
+
|
14 |
+
import configargparse
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
from espnet import __version__
|
18 |
+
from espnet.nets.tts_interface import TTSInterface
|
19 |
+
from espnet.utils.cli_utils import strtobool
|
20 |
+
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
|
21 |
+
|
22 |
+
|
23 |
+
# NOTE: you need this func to generate our sphinx doc
|
24 |
+
def get_parser():
|
25 |
+
"""Get parser of training arguments."""
|
26 |
+
parser = configargparse.ArgumentParser(
|
27 |
+
description="Train a new voice conversion (VC) model on one CPU, "
|
28 |
+
"one or multiple GPUs",
|
29 |
+
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
30 |
+
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
|
31 |
+
)
|
32 |
+
|
33 |
+
# general configuration
|
34 |
+
parser.add("--config", is_config_file=True, help="config file path")
|
35 |
+
parser.add(
|
36 |
+
"--config2",
|
37 |
+
is_config_file=True,
|
38 |
+
help="second config file path that overwrites the settings in `--config`.",
|
39 |
+
)
|
40 |
+
parser.add(
|
41 |
+
"--config3",
|
42 |
+
is_config_file=True,
|
43 |
+
help="third config file path that overwrites the settings "
|
44 |
+
"in `--config` and `--config2`.",
|
45 |
+
)
|
46 |
+
|
47 |
+
parser.add_argument(
|
48 |
+
"--ngpu",
|
49 |
+
default=None,
|
50 |
+
type=int,
|
51 |
+
help="Number of GPUs. If not given, use all visible devices",
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--backend",
|
55 |
+
default="pytorch",
|
56 |
+
type=str,
|
57 |
+
choices=["chainer", "pytorch"],
|
58 |
+
help="Backend library",
|
59 |
+
)
|
60 |
+
parser.add_argument("--outdir", type=str, required=True, help="Output directory")
|
61 |
+
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
|
62 |
+
parser.add_argument("--seed", default=1, type=int, help="Random seed")
|
63 |
+
parser.add_argument(
|
64 |
+
"--resume",
|
65 |
+
"-r",
|
66 |
+
default="",
|
67 |
+
type=str,
|
68 |
+
nargs="?",
|
69 |
+
help="Resume the training from snapshot",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--minibatches",
|
73 |
+
"-N",
|
74 |
+
type=int,
|
75 |
+
default="-1",
|
76 |
+
help="Process only N minibatches (for debug)",
|
77 |
+
)
|
78 |
+
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
79 |
+
parser.add_argument(
|
80 |
+
"--tensorboard-dir",
|
81 |
+
default=None,
|
82 |
+
type=str,
|
83 |
+
nargs="?",
|
84 |
+
help="Tensorboard log directory path",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--eval-interval-epochs",
|
88 |
+
default=100,
|
89 |
+
type=int,
|
90 |
+
help="Evaluation interval epochs",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--save-interval-epochs", default=1, type=int, help="Save interval epochs"
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
"--report-interval-iters",
|
97 |
+
default=10,
|
98 |
+
type=int,
|
99 |
+
help="Report interval iterations",
|
100 |
+
)
|
101 |
+
# task related
|
102 |
+
parser.add_argument("--srcspk", type=str, help="Source speaker")
|
103 |
+
parser.add_argument("--trgspk", type=str, help="Target speaker")
|
104 |
+
parser.add_argument(
|
105 |
+
"--train-json", type=str, required=True, help="Filename of training json"
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--valid-json", type=str, required=True, help="Filename of validation json"
|
109 |
+
)
|
110 |
+
|
111 |
+
# network architecture
|
112 |
+
parser.add_argument(
|
113 |
+
"--model-module",
|
114 |
+
type=str,
|
115 |
+
default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2",
|
116 |
+
help="model defined module",
|
117 |
+
)
|
118 |
+
# minibatch related
|
119 |
+
parser.add_argument(
|
120 |
+
"--sortagrad",
|
121 |
+
default=0,
|
122 |
+
type=int,
|
123 |
+
nargs="?",
|
124 |
+
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--batch-sort-key",
|
128 |
+
default="shuffle",
|
129 |
+
type=str,
|
130 |
+
choices=["shuffle", "output", "input"],
|
131 |
+
nargs="?",
|
132 |
+
help='Batch sorting key. "shuffle" only work with --batch-count "seq".',
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"--batch-count",
|
136 |
+
default="auto",
|
137 |
+
choices=BATCH_COUNT_CHOICES,
|
138 |
+
help="How to count batch_size. "
|
139 |
+
"The default (auto) will find how to count by args.",
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--batch-size",
|
143 |
+
"--batch-seqs",
|
144 |
+
"-b",
|
145 |
+
default=0,
|
146 |
+
type=int,
|
147 |
+
help="Maximum seqs in a minibatch (0 to disable)",
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"--batch-bins",
|
151 |
+
default=0,
|
152 |
+
type=int,
|
153 |
+
help="Maximum bins in a minibatch (0 to disable)",
|
154 |
+
)
|
155 |
+
parser.add_argument(
|
156 |
+
"--batch-frames-in",
|
157 |
+
default=0,
|
158 |
+
type=int,
|
159 |
+
help="Maximum input frames in a minibatch (0 to disable)",
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--batch-frames-out",
|
163 |
+
default=0,
|
164 |
+
type=int,
|
165 |
+
help="Maximum output frames in a minibatch (0 to disable)",
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--batch-frames-inout",
|
169 |
+
default=0,
|
170 |
+
type=int,
|
171 |
+
help="Maximum input+output frames in a minibatch (0 to disable)",
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--maxlen-in",
|
175 |
+
"--batch-seq-maxlen-in",
|
176 |
+
default=100,
|
177 |
+
type=int,
|
178 |
+
metavar="ML",
|
179 |
+
help="When --batch-count=seq, "
|
180 |
+
"batch size is reduced if the input sequence length > ML.",
|
181 |
+
)
|
182 |
+
parser.add_argument(
|
183 |
+
"--maxlen-out",
|
184 |
+
"--batch-seq-maxlen-out",
|
185 |
+
default=200,
|
186 |
+
type=int,
|
187 |
+
metavar="ML",
|
188 |
+
help="When --batch-count=seq, "
|
189 |
+
"batch size is reduced if the output sequence length > ML",
|
190 |
+
)
|
191 |
+
parser.add_argument(
|
192 |
+
"--num-iter-processes",
|
193 |
+
default=0,
|
194 |
+
type=int,
|
195 |
+
help="Number of processes of iterator",
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--preprocess-conf",
|
199 |
+
type=str,
|
200 |
+
default=None,
|
201 |
+
help="The configuration file for the pre-processing",
|
202 |
+
)
|
203 |
+
parser.add_argument(
|
204 |
+
"--use-speaker-embedding",
|
205 |
+
default=False,
|
206 |
+
type=strtobool,
|
207 |
+
help="Whether to use speaker embedding",
|
208 |
+
)
|
209 |
+
parser.add_argument(
|
210 |
+
"--use-second-target",
|
211 |
+
default=False,
|
212 |
+
type=strtobool,
|
213 |
+
help="Whether to use second target",
|
214 |
+
)
|
215 |
+
# optimization related
|
216 |
+
parser.add_argument(
|
217 |
+
"--opt",
|
218 |
+
default="adam",
|
219 |
+
type=str,
|
220 |
+
choices=["adam", "noam", "lamb"],
|
221 |
+
help="Optimizer",
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"--accum-grad", default=1, type=int, help="Number of gradient accumuration"
|
225 |
+
)
|
226 |
+
parser.add_argument(
|
227 |
+
"--lr", default=1e-3, type=float, help="Learning rate for optimizer"
|
228 |
+
)
|
229 |
+
parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer")
|
230 |
+
parser.add_argument(
|
231 |
+
"--weight-decay",
|
232 |
+
default=1e-6,
|
233 |
+
type=float,
|
234 |
+
help="Weight decay coefficient for optimizer",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--epochs", "-e", default=30, type=int, help="Number of maximum epochs"
|
238 |
+
)
|
239 |
+
parser.add_argument(
|
240 |
+
"--early-stop-criterion",
|
241 |
+
default="validation/main/loss",
|
242 |
+
type=str,
|
243 |
+
nargs="?",
|
244 |
+
help="Value to monitor to trigger an early stopping of the training",
|
245 |
+
)
|
246 |
+
parser.add_argument(
|
247 |
+
"--patience",
|
248 |
+
default=3,
|
249 |
+
type=int,
|
250 |
+
nargs="?",
|
251 |
+
help="Number of epochs to wait without improvement "
|
252 |
+
"before stopping the training",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--grad-clip", default=1, type=float, help="Gradient norm threshold to clip"
|
256 |
+
)
|
257 |
+
parser.add_argument(
|
258 |
+
"--num-save-attention",
|
259 |
+
default=5,
|
260 |
+
type=int,
|
261 |
+
help="Number of samples of attention to be saved",
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--keep-all-data-on-mem",
|
265 |
+
default=False,
|
266 |
+
type=strtobool,
|
267 |
+
help="Whether to keep all data on memory",
|
268 |
+
)
|
269 |
+
|
270 |
+
parser.add_argument(
|
271 |
+
"--enc-init",
|
272 |
+
default=None,
|
273 |
+
type=str,
|
274 |
+
help="Pre-trained model path to initialize encoder.",
|
275 |
+
)
|
276 |
+
parser.add_argument(
|
277 |
+
"--enc-init-mods",
|
278 |
+
default="enc.",
|
279 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
280 |
+
help="List of encoder modules to initialize, separated by a comma.",
|
281 |
+
)
|
282 |
+
parser.add_argument(
|
283 |
+
"--dec-init",
|
284 |
+
default=None,
|
285 |
+
type=str,
|
286 |
+
help="Pre-trained model path to initialize decoder.",
|
287 |
+
)
|
288 |
+
parser.add_argument(
|
289 |
+
"--dec-init-mods",
|
290 |
+
default="dec.",
|
291 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
292 |
+
help="List of decoder modules to initialize, separated by a comma.",
|
293 |
+
)
|
294 |
+
parser.add_argument(
|
295 |
+
"--freeze-mods",
|
296 |
+
default=None,
|
297 |
+
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
|
298 |
+
help="List of modules to freeze (not to train), separated by a comma.",
|
299 |
+
)
|
300 |
+
|
301 |
+
return parser
|
302 |
+
|
303 |
+
|
304 |
+
def main(cmd_args):
|
305 |
+
"""Run training."""
|
306 |
+
parser = get_parser()
|
307 |
+
args, _ = parser.parse_known_args(cmd_args)
|
308 |
+
|
309 |
+
from espnet.utils.dynamic_import import dynamic_import
|
310 |
+
|
311 |
+
model_class = dynamic_import(args.model_module)
|
312 |
+
assert issubclass(model_class, TTSInterface)
|
313 |
+
model_class.add_arguments(parser)
|
314 |
+
args = parser.parse_args(cmd_args)
|
315 |
+
|
316 |
+
# add version info in args
|
317 |
+
args.version = __version__
|
318 |
+
|
319 |
+
# logging info
|
320 |
+
if args.verbose > 0:
|
321 |
+
logging.basicConfig(
|
322 |
+
level=logging.INFO,
|
323 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
324 |
+
)
|
325 |
+
else:
|
326 |
+
logging.basicConfig(
|
327 |
+
level=logging.WARN,
|
328 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
329 |
+
)
|
330 |
+
logging.warning("Skip DEBUG/INFO messages")
|
331 |
+
|
332 |
+
# If --ngpu is not given,
|
333 |
+
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
|
334 |
+
# 2. if nvidia-smi exists, use all devices
|
335 |
+
# 3. else ngpu=0
|
336 |
+
if args.ngpu is None:
|
337 |
+
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
338 |
+
if cvd is not None:
|
339 |
+
ngpu = len(cvd.split(","))
|
340 |
+
else:
|
341 |
+
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
342 |
+
try:
|
343 |
+
p = subprocess.run(
|
344 |
+
["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
345 |
+
)
|
346 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
347 |
+
ngpu = 0
|
348 |
+
else:
|
349 |
+
ngpu = len(p.stderr.decode().split("\n")) - 1
|
350 |
+
else:
|
351 |
+
ngpu = args.ngpu
|
352 |
+
logging.info(f"ngpu: {ngpu}")
|
353 |
+
|
354 |
+
# set random seed
|
355 |
+
logging.info("random seed = %d" % args.seed)
|
356 |
+
random.seed(args.seed)
|
357 |
+
np.random.seed(args.seed)
|
358 |
+
|
359 |
+
if args.backend == "pytorch":
|
360 |
+
from espnet.vc.pytorch_backend.vc import train
|
361 |
+
|
362 |
+
train(args)
|
363 |
+
else:
|
364 |
+
raise NotImplementedError("Only pytorch is supported.")
|
365 |
+
|
366 |
+
|
367 |
+
if __name__ == "__main__":
|
368 |
+
main(sys.argv[1:])
|
espnet/lm/__init__.py
ADDED
@@ -0,0 +1 @@
"""Initialize sub package."""
espnet/lm/chainer_backend/__init__.py
ADDED
@@ -0,0 +1 @@
"""Initialize sub package."""
espnet/lm/chainer_backend/extlm.py
ADDED
@@ -0,0 +1,199 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import chainer
|
10 |
+
import chainer.functions as F
|
11 |
+
from espnet.lm.lm_utils import make_lexical_tree
|
12 |
+
|
13 |
+
|
14 |
+
# Definition of a multi-level (subword/word) language model
|
15 |
+
class MultiLevelLM(chainer.Chain):
|
16 |
+
logzero = -10000000000.0
|
17 |
+
zero = 1.0e-10
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
wordlm,
|
22 |
+
subwordlm,
|
23 |
+
word_dict,
|
24 |
+
subword_dict,
|
25 |
+
subwordlm_weight=0.8,
|
26 |
+
oov_penalty=1.0,
|
27 |
+
open_vocab=True,
|
28 |
+
):
|
29 |
+
super(MultiLevelLM, self).__init__()
|
30 |
+
self.wordlm = wordlm
|
31 |
+
self.subwordlm = subwordlm
|
32 |
+
self.word_eos = word_dict["<eos>"]
|
33 |
+
self.word_unk = word_dict["<unk>"]
|
34 |
+
self.xp_word_eos = self.xp.full(1, self.word_eos, "i")
|
35 |
+
self.xp_word_unk = self.xp.full(1, self.word_unk, "i")
|
36 |
+
self.space = subword_dict["<space>"]
|
37 |
+
self.eos = subword_dict["<eos>"]
|
38 |
+
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
|
39 |
+
self.log_oov_penalty = math.log(oov_penalty)
|
40 |
+
self.open_vocab = open_vocab
|
41 |
+
self.subword_dict_size = len(subword_dict)
|
42 |
+
self.subwordlm_weight = subwordlm_weight
|
43 |
+
self.normalized = True
|
44 |
+
|
45 |
+
def __call__(self, state, x):
|
46 |
+
# update state with input label x
|
47 |
+
if state is None: # make initial states and log-prob vectors
|
48 |
+
wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos)
|
49 |
+
wlm_logprobs = F.log_softmax(z_wlm).data
|
50 |
+
clm_state, z_clm = self.subwordlm(None, x)
|
51 |
+
log_y = F.log_softmax(z_clm).data * self.subwordlm_weight
|
52 |
+
new_node = self.lexroot
|
53 |
+
clm_logprob = 0.0
|
54 |
+
xi = self.space
|
55 |
+
else:
|
56 |
+
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
|
57 |
+
xi = int(x)
|
58 |
+
if xi == self.space: # inter-word transition
|
59 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
60 |
+
w = self.xp.full(1, node[1], "i")
|
61 |
+
else: # this node is not a word end, which means <unk>
|
62 |
+
w = self.xp_word_unk
|
63 |
+
# update wordlm state and log-prob vector
|
64 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
65 |
+
wlm_logprobs = F.log_softmax(z_wlm).data
|
66 |
+
new_node = self.lexroot # move to the tree root
|
67 |
+
clm_logprob = 0.0
|
68 |
+
elif node is not None and xi in node[0]: # intra-word transition
|
69 |
+
new_node = node[0][xi]
|
70 |
+
clm_logprob += log_y[0, xi]
|
71 |
+
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
|
72 |
+
new_node = None
|
73 |
+
clm_logprob += log_y[0, xi]
|
74 |
+
else: # if open_vocab flag is disabled, return 0 probabilities
|
75 |
+
log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
|
76 |
+
return (clm_state, wlm_state, None, log_y, 0.0), log_y
|
77 |
+
|
78 |
+
clm_state, z_clm = self.subwordlm(clm_state, x)
|
79 |
+
log_y = F.log_softmax(z_clm).data * self.subwordlm_weight
|
80 |
+
|
81 |
+
# apply word-level probabilies for <space> and <eos> labels
|
82 |
+
if xi != self.space:
|
83 |
+
if new_node is not None and new_node[1] >= 0: # if new node is word end
|
84 |
+
wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
|
85 |
+
else:
|
86 |
+
wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
|
87 |
+
log_y[:, self.space] = wlm_logprob
|
88 |
+
log_y[:, self.eos] = wlm_logprob
|
89 |
+
else:
|
90 |
+
log_y[:, self.space] = self.logzero
|
91 |
+
log_y[:, self.eos] = self.logzero
|
92 |
+
|
93 |
+
return (clm_state, wlm_state, wlm_logprobs, new_node, log_y, clm_logprob), log_y
|
94 |
+
|
95 |
+
def final(self, state):
|
96 |
+
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
|
97 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
98 |
+
w = self.xp.full(1, node[1], "i")
|
99 |
+
else: # this node is not a word end, which means <unk>
|
100 |
+
w = self.xp_word_unk
|
101 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
102 |
+
return F.log_softmax(z_wlm).data[:, self.word_eos]
|
103 |
+
|
104 |
+
|
105 |
+
# Definition of a look-ahead word language model
|
106 |
+
class LookAheadWordLM(chainer.Chain):
|
107 |
+
logzero = -10000000000.0
|
108 |
+
zero = 1.0e-10
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
|
112 |
+
):
|
113 |
+
super(LookAheadWordLM, self).__init__()
|
114 |
+
self.wordlm = wordlm
|
115 |
+
self.word_eos = word_dict["<eos>"]
|
116 |
+
self.word_unk = word_dict["<unk>"]
|
117 |
+
self.xp_word_eos = self.xp.full(1, self.word_eos, "i")
|
118 |
+
self.xp_word_unk = self.xp.full(1, self.word_unk, "i")
|
119 |
+
self.space = subword_dict["<space>"]
|
120 |
+
self.eos = subword_dict["<eos>"]
|
121 |
+
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
|
122 |
+
self.oov_penalty = oov_penalty
|
123 |
+
self.open_vocab = open_vocab
|
124 |
+
self.subword_dict_size = len(subword_dict)
|
125 |
+
self.normalized = True
|
126 |
+
|
127 |
+
def __call__(self, state, x):
|
128 |
+
# update state with input label x
|
129 |
+
if state is None: # make initial states and cumlative probability vector
|
130 |
+
wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos)
|
131 |
+
cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1)
|
132 |
+
new_node = self.lexroot
|
133 |
+
xi = self.space
|
134 |
+
else:
|
135 |
+
wlm_state, cumsum_probs, node = state
|
136 |
+
xi = int(x)
|
137 |
+
if xi == self.space: # inter-word transition
|
138 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
139 |
+
w = self.xp.full(1, node[1], "i")
|
140 |
+
else: # this node is not a word end, which means <unk>
|
141 |
+
w = self.xp_word_unk
|
142 |
+
# update wordlm state and cumlative probability vector
|
143 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
144 |
+
cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1)
|
145 |
+
new_node = self.lexroot # move to the tree root
|
146 |
+
elif node is not None and xi in node[0]: # intra-word transition
|
147 |
+
new_node = node[0][xi]
|
148 |
+
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
|
149 |
+
new_node = None
|
150 |
+
else: # if open_vocab flag is disabled, return 0 probabilities
|
151 |
+
log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
|
152 |
+
return (wlm_state, None, None), log_y
|
153 |
+
|
154 |
+
if new_node is not None:
|
155 |
+
succ, wid, wids = new_node
|
156 |
+
# compute parent node probability
|
157 |
+
sum_prob = (
|
158 |
+
(cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
|
159 |
+
if wids is not None
|
160 |
+
else 1.0
|
161 |
+
)
|
162 |
+
if sum_prob < self.zero:
|
163 |
+
log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
|
164 |
+
return (wlm_state, cumsum_probs, new_node), log_y
|
165 |
+
# set <unk> probability as a default value
|
166 |
+
unk_prob = (
|
167 |
+
cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
|
168 |
+
)
|
169 |
+
y = self.xp.full(
|
170 |
+
(1, self.subword_dict_size), unk_prob * self.oov_penalty, "f"
|
171 |
+
)
|
172 |
+
# compute transition probabilities to child nodes
|
173 |
+
for cid, nd in succ.items():
|
174 |
+
y[:, cid] = (
|
175 |
+
cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
|
176 |
+
) / sum_prob
|
177 |
+
# apply word-level probabilies for <space> and <eos> labels
|
178 |
+
if wid >= 0:
|
179 |
+
wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
|
180 |
+
y[:, self.space] = wlm_prob
|
181 |
+
y[:, self.eos] = wlm_prob
|
182 |
+
elif xi == self.space:
|
183 |
+
y[:, self.space] = self.zero
|
184 |
+
y[:, self.eos] = self.zero
|
185 |
+
log_y = self.xp.log(
|
186 |
+
self.xp.clip(y, self.zero, None)
|
187 |
+
) # clip to avoid log(0)
|
188 |
+
else: # if no path in the tree, transition probability is one
|
189 |
+
log_y = self.xp.zeros((1, self.subword_dict_size), "f")
|
190 |
+
return (wlm_state, cumsum_probs, new_node), log_y
|
191 |
+
|
192 |
+
def final(self, state):
|
193 |
+
wlm_state, cumsum_probs, node = state
|
194 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
195 |
+
w = self.xp.full(1, node[1], "i")
|
196 |
+
else: # this node is not a word end, which means <unk>
|
197 |
+
w = self.xp_word_unk
|
198 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
199 |
+
return F.log_softmax(z_wlm).data[:, self.word_eos]
|
espnet/lm/chainer_backend/lm.py
ADDED
@@ -0,0 +1,484 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
# This code is ported from the following implementation written in Torch.
|
7 |
+
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
|
8 |
+
|
9 |
+
|
10 |
+
import copy
|
11 |
+
import json
|
12 |
+
import logging
|
13 |
+
import numpy as np
|
14 |
+
import six
|
15 |
+
|
16 |
+
import chainer
|
17 |
+
from chainer.dataset import convert
|
18 |
+
import chainer.functions as F
|
19 |
+
import chainer.links as L
|
20 |
+
|
21 |
+
# for classifier link
|
22 |
+
from chainer.functions.loss import softmax_cross_entropy
|
23 |
+
from chainer import link
|
24 |
+
from chainer import reporter
|
25 |
+
from chainer import training
|
26 |
+
from chainer.training import extensions
|
27 |
+
|
28 |
+
from espnet.lm.lm_utils import compute_perplexity
|
29 |
+
from espnet.lm.lm_utils import count_tokens
|
30 |
+
from espnet.lm.lm_utils import MakeSymlinkToBestModel
|
31 |
+
from espnet.lm.lm_utils import ParallelSentenceIterator
|
32 |
+
from espnet.lm.lm_utils import read_tokens
|
33 |
+
|
34 |
+
import espnet.nets.chainer_backend.deterministic_embed_id as DL
|
35 |
+
from espnet.nets.lm_interface import LMInterface
|
36 |
+
from espnet.optimizer.factory import dynamic_import_optimizer
|
37 |
+
from espnet.scheduler.chainer import ChainerScheduler
|
38 |
+
from espnet.scheduler.scheduler import dynamic_import_scheduler
|
39 |
+
|
40 |
+
from espnet.utils.training.tensorboard_logger import TensorboardLogger
|
41 |
+
from tensorboardX import SummaryWriter
|
42 |
+
|
43 |
+
from espnet.utils.deterministic_utils import set_deterministic_chainer
|
44 |
+
from espnet.utils.training.evaluator import BaseEvaluator
|
45 |
+
from espnet.utils.training.iterators import ShufflingEnabler
|
46 |
+
from espnet.utils.training.train_utils import check_early_stop
|
47 |
+
from espnet.utils.training.train_utils import set_early_stop
|
48 |
+
|
49 |
+
|
50 |
+
# TODO(karita): reimplement RNNLM with new interface
|
51 |
+
class DefaultRNNLM(LMInterface, link.Chain):
|
52 |
+
"""Default RNNLM wrapper to compute reduce framewise loss values.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
n_vocab (int): The size of the vocabulary
|
56 |
+
args (argparse.Namespace): configurations. see `add_arguments`
|
57 |
+
"""
|
58 |
+
|
59 |
+
@staticmethod
|
60 |
+
def add_arguments(parser):
|
61 |
+
parser.add_argument(
|
62 |
+
"--type",
|
63 |
+
type=str,
|
64 |
+
default="lstm",
|
65 |
+
nargs="?",
|
66 |
+
choices=["lstm", "gru"],
|
67 |
+
help="Which type of RNN to use",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--layer", "-l", type=int, default=2, help="Number of hidden layers"
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--unit", "-u", type=int, default=650, help="Number of hidden units"
|
74 |
+
)
|
75 |
+
return parser
|
76 |
+
|
77 |
+
|
78 |
+
class ClassifierWithState(link.Chain):
|
79 |
+
"""A wrapper for a chainer RNNLM
|
80 |
+
|
81 |
+
:param link.Chain predictor : The RNNLM
|
82 |
+
:param function lossfun: The loss function to use
|
83 |
+
:param int/str label_key:
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
predictor,
|
89 |
+
lossfun=softmax_cross_entropy.softmax_cross_entropy,
|
90 |
+
label_key=-1,
|
91 |
+
):
|
92 |
+
if not (isinstance(label_key, (int, str))):
|
93 |
+
raise TypeError("label_key must be int or str, but is %s" % type(label_key))
|
94 |
+
|
95 |
+
super(ClassifierWithState, self).__init__()
|
96 |
+
self.lossfun = lossfun
|
97 |
+
self.y = None
|
98 |
+
self.loss = None
|
99 |
+
self.label_key = label_key
|
100 |
+
|
101 |
+
with self.init_scope():
|
102 |
+
self.predictor = predictor
|
103 |
+
|
104 |
+
def __call__(self, state, *args, **kwargs):
|
105 |
+
"""Computes the loss value for an input and label pair.
|
106 |
+
|
107 |
+
It also computes accuracy and stores it to the attribute.
|
108 |
+
When ``label_key`` is ``int``, the corresponding element in ``args``
|
109 |
+
is treated as ground truth labels. And when it is ``str``, the
|
110 |
+
element in ``kwargs`` is used.
|
111 |
+
The all elements of ``args`` and ``kwargs`` except the groundtruth
|
112 |
+
labels are features.
|
113 |
+
It feeds features to the predictor and compare the result
|
114 |
+
with ground truth labels.
|
115 |
+
|
116 |
+
:param state : The LM state
|
117 |
+
:param list[chainer.Variable] args : Input minibatch
|
118 |
+
:param dict[chainer.Variable] kwargs : Input minibatch
|
119 |
+
:return loss value
|
120 |
+
:rtype chainer.Variable
|
121 |
+
"""
|
122 |
+
|
123 |
+
if isinstance(self.label_key, int):
|
124 |
+
if not (-len(args) <= self.label_key < len(args)):
|
125 |
+
msg = "Label key %d is out of bounds" % self.label_key
|
126 |
+
raise ValueError(msg)
|
127 |
+
t = args[self.label_key]
|
128 |
+
if self.label_key == -1:
|
129 |
+
args = args[:-1]
|
130 |
+
else:
|
131 |
+
args = args[: self.label_key] + args[self.label_key + 1 :]
|
132 |
+
elif isinstance(self.label_key, str):
|
133 |
+
if self.label_key not in kwargs:
|
134 |
+
msg = 'Label key "%s" is not found' % self.label_key
|
135 |
+
raise ValueError(msg)
|
136 |
+
t = kwargs[self.label_key]
|
137 |
+
del kwargs[self.label_key]
|
138 |
+
|
139 |
+
self.y = None
|
140 |
+
self.loss = None
|
141 |
+
state, self.y = self.predictor(state, *args, **kwargs)
|
142 |
+
self.loss = self.lossfun(self.y, t)
|
143 |
+
return state, self.loss
|
144 |
+
|
145 |
+
def predict(self, state, x):
|
146 |
+
"""Predict log probabilities for given state and input x using the predictor
|
147 |
+
|
148 |
+
:param state : the state
|
149 |
+
:param x : the input
|
150 |
+
:return a tuple (state, log prob vector)
|
151 |
+
:rtype cupy/numpy array
|
152 |
+
"""
|
153 |
+
if hasattr(self.predictor, "normalized") and self.predictor.normalized:
|
154 |
+
return self.predictor(state, x)
|
155 |
+
else:
|
156 |
+
state, z = self.predictor(state, x)
|
157 |
+
return state, F.log_softmax(z).data
|
158 |
+
|
159 |
+
def final(self, state):
|
160 |
+
"""Predict final log probabilities for given state using the predictor
|
161 |
+
|
162 |
+
:param state : the state
|
163 |
+
:return log probability vector
|
164 |
+
:rtype cupy/numpy array
|
165 |
+
|
166 |
+
"""
|
167 |
+
if hasattr(self.predictor, "final"):
|
168 |
+
return self.predictor.final(state)
|
169 |
+
else:
|
170 |
+
return 0.0
|
171 |
+
|
172 |
+
|
173 |
+
# Definition of a recurrent net for language modeling
|
174 |
+
class RNNLM(chainer.Chain):
|
175 |
+
"""A chainer RNNLM
|
176 |
+
|
177 |
+
:param int n_vocab: The size of the vocabulary
|
178 |
+
:param int n_layers: The number of layers to create
|
179 |
+
:param int n_units: The number of units per layer
|
180 |
+
:param str type: The RNN type
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, n_vocab, n_layers, n_units, typ="lstm"):
|
184 |
+
super(RNNLM, self).__init__()
|
185 |
+
with self.init_scope():
|
186 |
+
self.embed = DL.EmbedID(n_vocab, n_units)
|
187 |
+
self.rnn = (
|
188 |
+
chainer.ChainList(
|
189 |
+
*[L.StatelessLSTM(n_units, n_units) for _ in range(n_layers)]
|
190 |
+
)
|
191 |
+
if typ == "lstm"
|
192 |
+
else chainer.ChainList(
|
193 |
+
*[L.StatelessGRU(n_units, n_units) for _ in range(n_layers)]
|
194 |
+
)
|
195 |
+
)
|
196 |
+
self.lo = L.Linear(n_units, n_vocab)
|
197 |
+
|
198 |
+
for param in self.params():
|
199 |
+
param.data[...] = np.random.uniform(-0.1, 0.1, param.data.shape)
|
200 |
+
self.n_layers = n_layers
|
201 |
+
self.n_units = n_units
|
202 |
+
self.typ = typ
|
203 |
+
|
204 |
+
def __call__(self, state, x):
|
205 |
+
if state is None:
|
206 |
+
if self.typ == "lstm":
|
207 |
+
state = {"c": [None] * self.n_layers, "h": [None] * self.n_layers}
|
208 |
+
else:
|
209 |
+
state = {"h": [None] * self.n_layers}
|
210 |
+
|
211 |
+
h = [None] * self.n_layers
|
212 |
+
emb = self.embed(x)
|
213 |
+
if self.typ == "lstm":
|
214 |
+
c = [None] * self.n_layers
|
215 |
+
c[0], h[0] = self.rnn[0](state["c"][0], state["h"][0], F.dropout(emb))
|
216 |
+
for n in six.moves.range(1, self.n_layers):
|
217 |
+
c[n], h[n] = self.rnn[n](
|
218 |
+
state["c"][n], state["h"][n], F.dropout(h[n - 1])
|
219 |
+
)
|
220 |
+
state = {"c": c, "h": h}
|
221 |
+
else:
|
222 |
+
if state["h"][0] is None:
|
223 |
+
xp = self.xp
|
224 |
+
with chainer.backends.cuda.get_device_from_id(self._device_id):
|
225 |
+
state["h"][0] = chainer.Variable(
|
226 |
+
xp.zeros((emb.shape[0], self.n_units), dtype=emb.dtype)
|
227 |
+
)
|
228 |
+
h[0] = self.rnn[0](state["h"][0], F.dropout(emb))
|
229 |
+
for n in six.moves.range(1, self.n_layers):
|
230 |
+
if state["h"][n] is None:
|
231 |
+
xp = self.xp
|
232 |
+
with chainer.backends.cuda.get_device_from_id(self._device_id):
|
233 |
+
state["h"][n] = chainer.Variable(
|
234 |
+
xp.zeros(
|
235 |
+
(h[n - 1].shape[0], self.n_units), dtype=h[n - 1].dtype
|
236 |
+
)
|
237 |
+
)
|
238 |
+
h[n] = self.rnn[n](state["h"][n], F.dropout(h[n - 1]))
|
239 |
+
state = {"h": h}
|
240 |
+
y = self.lo(F.dropout(h[-1]))
|
241 |
+
return state, y
|
242 |
+
|
243 |
+
|
244 |
+
class BPTTUpdater(training.updaters.StandardUpdater):
|
245 |
+
"""An updater for a chainer LM
|
246 |
+
|
247 |
+
:param chainer.dataset.Iterator train_iter : The train iterator
|
248 |
+
:param optimizer:
|
249 |
+
:param schedulers:
|
250 |
+
:param int device : The device id
|
251 |
+
:param int accum_grad :
|
252 |
+
"""
|
253 |
+
|
254 |
+
def __init__(self, train_iter, optimizer, schedulers, device, accum_grad):
|
255 |
+
super(BPTTUpdater, self).__init__(train_iter, optimizer, device=device)
|
256 |
+
self.scheduler = ChainerScheduler(schedulers, optimizer)
|
257 |
+
self.accum_grad = accum_grad
|
258 |
+
|
259 |
+
# The core part of the update routine can be customized by overriding.
|
260 |
+
def update_core(self):
|
261 |
+
# When we pass one iterator and optimizer to StandardUpdater.__init__,
|
262 |
+
# they are automatically named 'main'.
|
263 |
+
train_iter = self.get_iterator("main")
|
264 |
+
optimizer = self.get_optimizer("main")
|
265 |
+
|
266 |
+
count = 0
|
267 |
+
sum_loss = 0
|
268 |
+
optimizer.target.cleargrads() # Clear the parameter gradients
|
269 |
+
for _ in range(self.accum_grad):
|
270 |
+
# Progress the dataset iterator for sentences at each iteration.
|
271 |
+
batch = train_iter.__next__()
|
272 |
+
x, t = convert.concat_examples(batch, device=self.device, padding=(0, -1))
|
273 |
+
# Concatenate the token IDs to matrices and send them to the device
|
274 |
+
# self.converter does this job
|
275 |
+
# (it is chainer.dataset.concat_examples by default)
|
276 |
+
xp = chainer.backends.cuda.get_array_module(x)
|
277 |
+
loss = 0
|
278 |
+
state = None
|
279 |
+
batch_size, sequence_length = x.shape
|
280 |
+
for i in six.moves.range(sequence_length):
|
281 |
+
# Compute the loss at this time step and accumulate it
|
282 |
+
state, loss_batch = optimizer.target(
|
283 |
+
state, chainer.Variable(x[:, i]), chainer.Variable(t[:, i])
|
284 |
+
)
|
285 |
+
non_zeros = xp.count_nonzero(x[:, i])
|
286 |
+
loss += loss_batch * non_zeros
|
287 |
+
count += int(non_zeros)
|
288 |
+
# backward
|
289 |
+
loss /= batch_size * self.accum_grad # normalized by batch size
|
290 |
+
sum_loss += float(loss.data)
|
291 |
+
loss.backward() # Backprop
|
292 |
+
loss.unchain_backward() # Truncate the graph
|
293 |
+
|
294 |
+
reporter.report({"loss": sum_loss}, optimizer.target)
|
295 |
+
reporter.report({"count": count}, optimizer.target)
|
296 |
+
# update
|
297 |
+
optimizer.update() # Update the parameters
|
298 |
+
self.scheduler.step(self.iteration)
|
299 |
+
|
300 |
+
|
301 |
+
class LMEvaluator(BaseEvaluator):
|
302 |
+
"""A custom evaluator for a chainer LM
|
303 |
+
|
304 |
+
:param chainer.dataset.Iterator val_iter : The validation iterator
|
305 |
+
:param eval_model : The model to evaluate
|
306 |
+
:param int device : The device id to use
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(self, val_iter, eval_model, device):
|
310 |
+
super(LMEvaluator, self).__init__(val_iter, eval_model, device=device)
|
311 |
+
|
312 |
+
def evaluate(self):
|
313 |
+
val_iter = self.get_iterator("main")
|
314 |
+
target = self.get_target("main")
|
315 |
+
loss = 0
|
316 |
+
count = 0
|
317 |
+
for batch in copy.copy(val_iter):
|
318 |
+
x, t = convert.concat_examples(batch, device=self.device, padding=(0, -1))
|
319 |
+
xp = chainer.backends.cuda.get_array_module(x)
|
320 |
+
state = None
|
321 |
+
for i in six.moves.range(len(x[0])):
|
322 |
+
state, loss_batch = target(state, x[:, i], t[:, i])
|
323 |
+
non_zeros = xp.count_nonzero(x[:, i])
|
324 |
+
loss += loss_batch.data * non_zeros
|
325 |
+
count += int(non_zeros)
|
326 |
+
# report validation loss
|
327 |
+
observation = {}
|
328 |
+
with reporter.report_scope(observation):
|
329 |
+
reporter.report({"loss": float(loss / count)}, target)
|
330 |
+
return observation
|
331 |
+
|
332 |
+
|
333 |
+
def train(args):
|
334 |
+
"""Train with the given args
|
335 |
+
|
336 |
+
:param Namespace args: The program arguments
|
337 |
+
"""
|
338 |
+
# TODO(karita): support this
|
339 |
+
if args.model_module != "default":
|
340 |
+
raise NotImplementedError("chainer backend does not support --model-module")
|
341 |
+
|
342 |
+
# display chainer version
|
343 |
+
logging.info("chainer version = " + chainer.__version__)
|
344 |
+
|
345 |
+
set_deterministic_chainer(args)
|
346 |
+
|
347 |
+
# check cuda and cudnn availability
|
348 |
+
if not chainer.cuda.available:
|
349 |
+
logging.warning("cuda is not available")
|
350 |
+
if not chainer.cuda.cudnn_enabled:
|
351 |
+
logging.warning("cudnn is not available")
|
352 |
+
|
353 |
+
# get special label ids
|
354 |
+
unk = args.char_list_dict["<unk>"]
|
355 |
+
eos = args.char_list_dict["<eos>"]
|
356 |
+
# read tokens as a sequence of sentences
|
357 |
+
train = read_tokens(args.train_label, args.char_list_dict)
|
358 |
+
val = read_tokens(args.valid_label, args.char_list_dict)
|
359 |
+
# count tokens
|
360 |
+
n_train_tokens, n_train_oovs = count_tokens(train, unk)
|
361 |
+
n_val_tokens, n_val_oovs = count_tokens(val, unk)
|
362 |
+
logging.info("#vocab = " + str(args.n_vocab))
|
363 |
+
logging.info("#sentences in the training data = " + str(len(train)))
|
364 |
+
logging.info("#tokens in the training data = " + str(n_train_tokens))
|
365 |
+
logging.info(
|
366 |
+
"oov rate in the training data = %.2f %%"
|
367 |
+
% (n_train_oovs / n_train_tokens * 100)
|
368 |
+
)
|
369 |
+
logging.info("#sentences in the validation data = " + str(len(val)))
|
370 |
+
logging.info("#tokens in the validation data = " + str(n_val_tokens))
|
371 |
+
logging.info(
|
372 |
+
"oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100)
|
373 |
+
)
|
374 |
+
|
375 |
+
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
|
376 |
+
|
377 |
+
# Create the dataset iterators
|
378 |
+
train_iter = ParallelSentenceIterator(
|
379 |
+
train,
|
380 |
+
args.batchsize,
|
381 |
+
max_length=args.maxlen,
|
382 |
+
sos=eos,
|
383 |
+
eos=eos,
|
384 |
+
shuffle=not use_sortagrad,
|
385 |
+
)
|
386 |
+
val_iter = ParallelSentenceIterator(
|
387 |
+
val, args.batchsize, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
|
388 |
+
)
|
389 |
+
epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
|
390 |
+
logging.info("#iterations per epoch = %d" % epoch_iters)
|
391 |
+
logging.info("#total iterations = " + str(args.epoch * epoch_iters))
|
392 |
+
# Prepare an RNNLM model
|
393 |
+
rnn = RNNLM(args.n_vocab, args.layer, args.unit, args.type)
|
394 |
+
model = ClassifierWithState(rnn)
|
395 |
+
if args.ngpu > 1:
|
396 |
+
logging.warning("currently, multi-gpu is not supported. use single gpu.")
|
397 |
+
if args.ngpu > 0:
|
398 |
+
# Make the specified GPU current
|
399 |
+
gpu_id = 0
|
400 |
+
chainer.cuda.get_device_from_id(gpu_id).use()
|
401 |
+
model.to_gpu()
|
402 |
+
else:
|
403 |
+
gpu_id = -1
|
404 |
+
|
405 |
+
# Save model conf to json
|
406 |
+
model_conf = args.outdir + "/model.json"
|
407 |
+
with open(model_conf, "wb") as f:
|
408 |
+
logging.info("writing a model config file to " + model_conf)
|
409 |
+
f.write(
|
410 |
+
json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
|
411 |
+
"utf_8"
|
412 |
+
)
|
413 |
+
)
|
414 |
+
|
415 |
+
# Set up an optimizer
|
416 |
+
opt_class = dynamic_import_optimizer(args.opt, args.backend)
|
417 |
+
optimizer = opt_class.from_args(model, args)
|
418 |
+
if args.schedulers is None:
|
419 |
+
schedulers = []
|
420 |
+
else:
|
421 |
+
schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]
|
422 |
+
|
423 |
+
optimizer.setup(model)
|
424 |
+
optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))
|
425 |
+
|
426 |
+
updater = BPTTUpdater(train_iter, optimizer, schedulers, gpu_id, args.accum_grad)
|
427 |
+
trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
|
428 |
+
trainer.extend(LMEvaluator(val_iter, model, device=gpu_id))
|
429 |
+
trainer.extend(
|
430 |
+
extensions.LogReport(
|
431 |
+
postprocess=compute_perplexity,
|
432 |
+
trigger=(args.report_interval_iters, "iteration"),
|
433 |
+
)
|
434 |
+
)
|
435 |
+
trainer.extend(
|
436 |
+
extensions.PrintReport(
|
437 |
+
["epoch", "iteration", "perplexity", "val_perplexity", "elapsed_time"]
|
438 |
+
),
|
439 |
+
trigger=(args.report_interval_iters, "iteration"),
|
440 |
+
)
|
441 |
+
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
|
442 |
+
trainer.extend(extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"))
|
443 |
+
trainer.extend(extensions.snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
|
444 |
+
# MEMO(Hori): we want to use MinValueTrigger, but it seems to fail when resuming
|
445 |
+
trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))
|
446 |
+
|
447 |
+
if use_sortagrad:
|
448 |
+
trainer.extend(
|
449 |
+
ShufflingEnabler([train_iter]),
|
450 |
+
trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
|
451 |
+
)
|
452 |
+
|
453 |
+
if args.resume:
|
454 |
+
logging.info("resumed from %s" % args.resume)
|
455 |
+
chainer.serializers.load_npz(args.resume, trainer)
|
456 |
+
|
457 |
+
set_early_stop(trainer, args, is_lm=True)
|
458 |
+
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
|
459 |
+
writer = SummaryWriter(args.tensorboard_dir)
|
460 |
+
trainer.extend(
|
461 |
+
TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
|
462 |
+
)
|
463 |
+
|
464 |
+
trainer.run()
|
465 |
+
check_early_stop(trainer, args.epoch)
|
466 |
+
|
467 |
+
# compute perplexity for test set
|
468 |
+
if args.test_label:
|
469 |
+
logging.info("test the best model")
|
470 |
+
chainer.serializers.load_npz(args.outdir + "/rnnlm.model.best", model)
|
471 |
+
test = read_tokens(args.test_label, args.char_list_dict)
|
472 |
+
n_test_tokens, n_test_oovs = count_tokens(test, unk)
|
473 |
+
logging.info("#sentences in the test data = " + str(len(test)))
|
474 |
+
logging.info("#tokens in the test data = " + str(n_test_tokens))
|
475 |
+
logging.info(
|
476 |
+
"oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
|
477 |
+
)
|
478 |
+
test_iter = ParallelSentenceIterator(
|
479 |
+
test, args.batchsize, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
|
480 |
+
)
|
481 |
+
evaluator = LMEvaluator(test_iter, model, device=gpu_id)
|
482 |
+
with chainer.using_config("train", False):
|
483 |
+
result = evaluator()
|
484 |
+
logging.info("test perplexity: " + str(np.exp(float(result["main/loss"]))))
|
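The chainer training loop above reports perplexity as the exponential of the average per-token loss, both through compute_perplexity during training and in the final test-set evaluation (np.exp of main/loss). A minimal sketch of that relation, with made-up token counts and losses (not from any recipe):

import numpy as np

# hypothetical accumulated statistics from an LM evaluator (illustrative numbers only)
total_loss = 1543.2   # sum of per-token negative log-likelihoods
n_tokens = 800        # number of scored tokens

avg_nll = total_loss / n_tokens
perplexity = np.exp(avg_nll)  # same exp(loss / count) relation used by compute_perplexity
print(f"average nll = {avg_nll:.3f}, perplexity = {perplexity:.1f}")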
espnet/lm/lm_utils.py
ADDED
@@ -0,0 +1,293 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
# This code is ported from the following implementation written in Torch.
|
7 |
+
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
|
8 |
+
|
9 |
+
import chainer
|
10 |
+
import h5py
|
11 |
+
import logging
|
12 |
+
import numpy as np
|
13 |
+
import os
|
14 |
+
import random
|
15 |
+
import six
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
from chainer.training import extension
|
19 |
+
|
20 |
+
|
21 |
+
def load_dataset(path, label_dict, outdir=None):
|
22 |
+
"""Load and save HDF5 that contains a dataset and stats for LM
|
23 |
+
|
24 |
+
Args:
|
25 |
+
path (str): The path of an input text dataset file
|
26 |
+
label_dict (dict[str, int]):
|
27 |
+
dictionary that maps token label string to its ID number
|
28 |
+
outdir (str): The path of an output dir
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
tuple[list[np.ndarray], int, int]: Tuple of
|
32 |
+
token IDs in np.int32 converted by `read_tokens`
|
33 |
+
the number of tokens by `count_tokens`,
|
34 |
+
and the number of OOVs by `count_tokens`
|
35 |
+
"""
|
36 |
+
if outdir is not None:
|
37 |
+
os.makedirs(outdir, exist_ok=True)
|
38 |
+
filename = outdir + "/" + os.path.basename(path) + ".h5"
|
39 |
+
if os.path.exists(filename):
|
40 |
+
logging.info(f"loading binary dataset: {filename}")
|
41 |
+
f = h5py.File(filename, "r")
|
42 |
+
return f["data"][:], f["n_tokens"][()], f["n_oovs"][()]
|
43 |
+
else:
|
44 |
+
logging.info("skip dump/load HDF5 because the output dir is not specified")
|
45 |
+
logging.info(f"reading text dataset: {path}")
|
46 |
+
ret = read_tokens(path, label_dict)
|
47 |
+
n_tokens, n_oovs = count_tokens(ret, label_dict["<unk>"])
|
48 |
+
if outdir is not None:
|
49 |
+
logging.info(f"saving binary dataset: {filename}")
|
50 |
+
with h5py.File(filename, "w") as f:
|
51 |
+
# http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data
|
52 |
+
data = f.create_dataset(
|
53 |
+
"data", (len(ret),), dtype=h5py.special_dtype(vlen=np.int32)
|
54 |
+
)
|
55 |
+
data[:] = ret
|
56 |
+
f["n_tokens"] = n_tokens
|
57 |
+
f["n_oovs"] = n_oovs
|
58 |
+
return ret, n_tokens, n_oovs
|
59 |
+
|
60 |
+
|
61 |
+
def read_tokens(filename, label_dict):
|
62 |
+
"""Read tokens as a sequence of sentences
|
63 |
+
|
64 |
+
:param str filename : The name of the input file
|
65 |
+
:param dict label_dict : dictionary that maps token label string to its ID number
|
66 |
+
:return list of ID sequences
|
67 |
+
:rtype list
|
68 |
+
"""
|
69 |
+
|
70 |
+
data = []
|
71 |
+
unk = label_dict["<unk>"]
|
72 |
+
for ln in tqdm(open(filename, "r", encoding="utf-8")):
|
73 |
+
data.append(
|
74 |
+
np.array(
|
75 |
+
[label_dict.get(label, unk) for label in ln.split()], dtype=np.int32
|
76 |
+
)
|
77 |
+
)
|
78 |
+
return data
|
79 |
+
|
80 |
+
|
81 |
+
def count_tokens(data, unk_id=None):
|
82 |
+
"""Count tokens and oovs in token ID sequences.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
data (list[np.ndarray]): list of token ID sequences
|
86 |
+
unk_id (int): ID of unknown token
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
tuple: tuple of number of token occurrences and number of oov tokens
|
90 |
+
|
91 |
+
"""
|
92 |
+
|
93 |
+
n_tokens = 0
|
94 |
+
n_oovs = 0
|
95 |
+
for sentence in data:
|
96 |
+
n_tokens += len(sentence)
|
97 |
+
if unk_id is not None:
|
98 |
+
n_oovs += np.count_nonzero(sentence == unk_id)
|
99 |
+
return n_tokens, n_oovs
|
100 |
+
|
101 |
+
|
102 |
+
def compute_perplexity(result):
|
103 |
+
"""Computes and add the perplexity to the LogReport
|
104 |
+
|
105 |
+
:param dict result: The current observations
|
106 |
+
"""
|
107 |
+
# Routine to rewrite the result dictionary of LogReport to add perplexity values
|
108 |
+
result["perplexity"] = np.exp(result["main/loss"] / result["main/count"])
|
109 |
+
if "validation/main/loss" in result:
|
110 |
+
result["val_perplexity"] = np.exp(result["validation/main/loss"])
|
111 |
+
|
112 |
+
|
113 |
+
class ParallelSentenceIterator(chainer.dataset.Iterator):
|
114 |
+
"""Dataset iterator to create a batch of sentences.
|
115 |
+
|
116 |
+
This iterator returns a pair of sentences, where one token is shifted
|
117 |
+
between the sentences like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
|
118 |
+
Sentence batches are made in order of longer sentences, and then
|
119 |
+
randomly shuffled.
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True
|
124 |
+
):
|
125 |
+
self.dataset = dataset
|
126 |
+
self.batch_size = batch_size # batch size
|
127 |
+
# Number of completed sweeps over the dataset. In this case, it is
|
128 |
+
# incremented if every word is visited at least once after the last
|
129 |
+
# increment.
|
130 |
+
self.epoch = 0
|
131 |
+
# True if the epoch is incremented at the last iteration.
|
132 |
+
self.is_new_epoch = False
|
133 |
+
self.repeat = repeat
|
134 |
+
length = len(dataset)
|
135 |
+
self.batch_indices = []
|
136 |
+
# make mini-batches
|
137 |
+
if batch_size > 1:
|
138 |
+
indices = sorted(range(len(dataset)), key=lambda i: -len(dataset[i]))
|
139 |
+
bs = 0
|
140 |
+
while bs < length:
|
141 |
+
be = min(bs + batch_size, length)
|
142 |
+
# batch size is automatically reduced if the sentence length
|
143 |
+
# is larger than max_length
|
144 |
+
if max_length > 0:
|
145 |
+
sent_length = len(dataset[indices[bs]])
|
146 |
+
be = min(
|
147 |
+
be, bs + max(batch_size // (sent_length // max_length + 1), 1)
|
148 |
+
)
|
149 |
+
self.batch_indices.append(np.array(indices[bs:be]))
|
150 |
+
bs = be
|
151 |
+
if shuffle:
|
152 |
+
# shuffle batches
|
153 |
+
random.shuffle(self.batch_indices)
|
154 |
+
else:
|
155 |
+
self.batch_indices = [np.array([i]) for i in six.moves.range(length)]
|
156 |
+
|
157 |
+
# NOTE: this is not a count of parameter updates. It is just a count of
|
158 |
+
# calls of ``__next__``.
|
159 |
+
self.iteration = 0
|
160 |
+
self.sos = sos
|
161 |
+
self.eos = eos
|
162 |
+
# use -1 instead of None internally
|
163 |
+
self._previous_epoch_detail = -1.0
|
164 |
+
|
165 |
+
def __next__(self):
|
166 |
+
# This iterator returns a list representing a mini-batch. Each item
|
167 |
+
# indicates a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
|
168 |
+
# represented by token IDs.
|
169 |
+
n_batches = len(self.batch_indices)
|
170 |
+
if not self.repeat and self.iteration >= n_batches:
|
171 |
+
# If not self.repeat, this iterator stops at the end of the first
|
172 |
+
# epoch (i.e., when all words are visited once).
|
173 |
+
raise StopIteration
|
174 |
+
|
175 |
+
batch = []
|
176 |
+
for idx in self.batch_indices[self.iteration % n_batches]:
|
177 |
+
batch.append(
|
178 |
+
(
|
179 |
+
np.append([self.sos], self.dataset[idx]),
|
180 |
+
np.append(self.dataset[idx], [self.eos]),
|
181 |
+
)
|
182 |
+
)
|
183 |
+
|
184 |
+
self._previous_epoch_detail = self.epoch_detail
|
185 |
+
self.iteration += 1
|
186 |
+
|
187 |
+
epoch = self.iteration // n_batches
|
188 |
+
self.is_new_epoch = self.epoch < epoch
|
189 |
+
if self.is_new_epoch:
|
190 |
+
self.epoch = epoch
|
191 |
+
|
192 |
+
return batch
|
193 |
+
|
194 |
+
def start_shuffle(self):
|
195 |
+
random.shuffle(self.batch_indices)
|
196 |
+
|
197 |
+
@property
|
198 |
+
def epoch_detail(self):
|
199 |
+
# Floating point version of epoch.
|
200 |
+
return self.iteration / len(self.batch_indices)
|
201 |
+
|
202 |
+
@property
|
203 |
+
def previous_epoch_detail(self):
|
204 |
+
if self._previous_epoch_detail < 0:
|
205 |
+
return None
|
206 |
+
return self._previous_epoch_detail
|
207 |
+
|
208 |
+
def serialize(self, serializer):
|
209 |
+
# It is important to serialize the state to be recovered on resume.
|
210 |
+
self.iteration = serializer("iteration", self.iteration)
|
211 |
+
self.epoch = serializer("epoch", self.epoch)
|
212 |
+
try:
|
213 |
+
self._previous_epoch_detail = serializer(
|
214 |
+
"previous_epoch_detail", self._previous_epoch_detail
|
215 |
+
)
|
216 |
+
except KeyError:
|
217 |
+
# guess previous_epoch_detail for older version
|
218 |
+
self._previous_epoch_detail = self.epoch + (
|
219 |
+
self.current_position - 1
|
220 |
+
) / len(self.batch_indices)
|
221 |
+
if self.epoch_detail > 0:
|
222 |
+
self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0)
|
223 |
+
else:
|
224 |
+
self._previous_epoch_detail = -1.0
|
225 |
+
|
226 |
+
|
227 |
+
class MakeSymlinkToBestModel(extension.Extension):
|
228 |
+
"""Extension that makes a symbolic link to the best model
|
229 |
+
|
230 |
+
:param str key: Key of value
|
231 |
+
:param str prefix: Prefix of model files and link target
|
232 |
+
:param str suffix: Suffix of link target
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, key, prefix="model", suffix="best"):
|
236 |
+
super(MakeSymlinkToBestModel, self).__init__()
|
237 |
+
self.best_model = -1
|
238 |
+
self.min_loss = 0.0
|
239 |
+
self.key = key
|
240 |
+
self.prefix = prefix
|
241 |
+
self.suffix = suffix
|
242 |
+
|
243 |
+
def __call__(self, trainer):
|
244 |
+
observation = trainer.observation
|
245 |
+
if self.key in observation:
|
246 |
+
loss = observation[self.key]
|
247 |
+
if self.best_model == -1 or loss < self.min_loss:
|
248 |
+
self.min_loss = loss
|
249 |
+
self.best_model = trainer.updater.epoch
|
250 |
+
src = "%s.%d" % (self.prefix, self.best_model)
|
251 |
+
dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix))
|
252 |
+
if os.path.lexists(dest):
|
253 |
+
os.remove(dest)
|
254 |
+
os.symlink(src, dest)
|
255 |
+
logging.info("best model is " + src)
|
256 |
+
|
257 |
+
def serialize(self, serializer):
|
258 |
+
if isinstance(serializer, chainer.serializer.Serializer):
|
259 |
+
serializer("_best_model", self.best_model)
|
260 |
+
serializer("_min_loss", self.min_loss)
|
261 |
+
serializer("_key", self.key)
|
262 |
+
serializer("_prefix", self.prefix)
|
263 |
+
serializer("_suffix", self.suffix)
|
264 |
+
else:
|
265 |
+
self.best_model = serializer("_best_model", -1)
|
266 |
+
self.min_loss = serializer("_min_loss", 0.0)
|
267 |
+
self.key = serializer("_key", "")
|
268 |
+
self.prefix = serializer("_prefix", "model")
|
269 |
+
self.suffix = serializer("_suffix", "best")
|
270 |
+
|
271 |
+
|
272 |
+
# TODO(Hori): currently it only works with character-word level LM.
|
273 |
+
# need to consider any types of subwords-to-word mapping.
|
274 |
+
def make_lexical_tree(word_dict, subword_dict, word_unk):
|
275 |
+
"""Make a lexical tree to compute word-level probabilities"""
|
276 |
+
# node [dict(subword_id -> node), word_id, word_set[start-1, end]]
|
277 |
+
root = [{}, -1, None]
|
278 |
+
for w, wid in word_dict.items():
|
279 |
+
if wid > 0 and wid != word_unk: # skip <blank> and <unk>
|
280 |
+
if True in [c not in subword_dict for c in w]: # skip unknown subword
|
281 |
+
continue
|
282 |
+
succ = root[0] # get successors from root node
|
283 |
+
for i, c in enumerate(w):
|
284 |
+
cid = subword_dict[c]
|
285 |
+
if cid not in succ: # if next node does not exist, make a new node
|
286 |
+
succ[cid] = [{}, -1, (wid - 1, wid)]
|
287 |
+
else:
|
288 |
+
prev = succ[cid][2]
|
289 |
+
succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
|
290 |
+
if i == len(w) - 1: # if word end, set word id
|
291 |
+
succ[cid][1] = wid
|
292 |
+
succ = succ[cid][0] # move to the child successors
|
293 |
+
return root
|
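As a usage sketch of the ParallelSentenceIterator defined above (toy token IDs, not from an actual recipe, assuming the espnet package in this commit is importable), each sentence is paired with a one-token-shifted copy of itself:

import numpy as np
from espnet.lm.lm_utils import ParallelSentenceIterator

# toy token-ID sentences; ID 0 plays the role of both <sos> and <eos> here
sentences = [np.array([5, 6, 7], dtype=np.int32), np.array([8, 9], dtype=np.int32)]
it = ParallelSentenceIterator(sentences, batch_size=2, sos=0, eos=0, repeat=False)

for batch in it:
    for x, t in batch:
        # x = <sos> w1 w2 ..., t = w1 w2 ... <eos>
        print(x, "->", t)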
espnet/lm/pytorch_backend/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/lm/pytorch_backend/extlm.py
ADDED
@@ -0,0 +1,218 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from espnet.lm.lm_utils import make_lexical_tree
|
14 |
+
from espnet.nets.pytorch_backend.nets_utils import to_device
|
15 |
+
|
16 |
+
|
17 |
+
# Definition of a multi-level (subword/word) language model
|
18 |
+
class MultiLevelLM(nn.Module):
|
19 |
+
logzero = -10000000000.0
|
20 |
+
zero = 1.0e-10
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
wordlm,
|
25 |
+
subwordlm,
|
26 |
+
word_dict,
|
27 |
+
subword_dict,
|
28 |
+
subwordlm_weight=0.8,
|
29 |
+
oov_penalty=1.0,
|
30 |
+
open_vocab=True,
|
31 |
+
):
|
32 |
+
super(MultiLevelLM, self).__init__()
|
33 |
+
self.wordlm = wordlm
|
34 |
+
self.subwordlm = subwordlm
|
35 |
+
self.word_eos = word_dict["<eos>"]
|
36 |
+
self.word_unk = word_dict["<unk>"]
|
37 |
+
self.var_word_eos = torch.LongTensor([self.word_eos])
|
38 |
+
self.var_word_unk = torch.LongTensor([self.word_unk])
|
39 |
+
self.space = subword_dict["<space>"]
|
40 |
+
self.eos = subword_dict["<eos>"]
|
41 |
+
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
|
42 |
+
self.log_oov_penalty = math.log(oov_penalty)
|
43 |
+
self.open_vocab = open_vocab
|
44 |
+
self.subword_dict_size = len(subword_dict)
|
45 |
+
self.subwordlm_weight = subwordlm_weight
|
46 |
+
self.normalized = True
|
47 |
+
|
48 |
+
def forward(self, state, x):
|
49 |
+
# update state with input label x
|
50 |
+
if state is None: # make initial states and log-prob vectors
|
51 |
+
self.var_word_eos = to_device(x, self.var_word_eos)
|
52 |
+
self.var_word_unk = to_device(x, self.var_word_unk)
|
53 |
+
wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
|
54 |
+
wlm_logprobs = F.log_softmax(z_wlm, dim=1)
|
55 |
+
clm_state, z_clm = self.subwordlm(None, x)
|
56 |
+
log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
|
57 |
+
new_node = self.lexroot
|
58 |
+
clm_logprob = 0.0
|
59 |
+
xi = self.space
|
60 |
+
else:
|
61 |
+
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
|
62 |
+
xi = int(x)
|
63 |
+
if xi == self.space: # inter-word transition
|
64 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
65 |
+
w = to_device(x, torch.LongTensor([node[1]]))
|
66 |
+
else: # this node is not a word end, which means <unk>
|
67 |
+
w = self.var_word_unk
|
68 |
+
# update wordlm state and log-prob vector
|
69 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
70 |
+
wlm_logprobs = F.log_softmax(z_wlm, dim=1)
|
71 |
+
new_node = self.lexroot # move to the tree root
|
72 |
+
clm_logprob = 0.0
|
73 |
+
elif node is not None and xi in node[0]: # intra-word transition
|
74 |
+
new_node = node[0][xi]
|
75 |
+
clm_logprob += log_y[0, xi]
|
76 |
+
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
|
77 |
+
new_node = None
|
78 |
+
clm_logprob += log_y[0, xi]
|
79 |
+
else: # if open_vocab flag is disabled, return 0 probabilities
|
80 |
+
log_y = to_device(
|
81 |
+
x, torch.full((1, self.subword_dict_size), self.logzero)
|
82 |
+
)
|
83 |
+
return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y
|
84 |
+
|
85 |
+
clm_state, z_clm = self.subwordlm(clm_state, x)
|
86 |
+
log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
|
87 |
+
|
88 |
+
# apply word-level probabilities for <space> and <eos> labels
|
89 |
+
if xi != self.space:
|
90 |
+
if new_node is not None and new_node[1] >= 0: # if new node is word end
|
91 |
+
wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
|
92 |
+
else:
|
93 |
+
wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
|
94 |
+
log_y[:, self.space] = wlm_logprob
|
95 |
+
log_y[:, self.eos] = wlm_logprob
|
96 |
+
else:
|
97 |
+
log_y[:, self.space] = self.logzero
|
98 |
+
log_y[:, self.eos] = self.logzero
|
99 |
+
|
100 |
+
return (
|
101 |
+
(clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
|
102 |
+
log_y,
|
103 |
+
)
|
104 |
+
|
105 |
+
def final(self, state):
|
106 |
+
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
|
107 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
108 |
+
w = to_device(wlm_logprobs, torch.LongTensor([node[1]]))
|
109 |
+
else: # this node is not a word end, which means <unk>
|
110 |
+
w = self.var_word_unk
|
111 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
112 |
+
return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
|
113 |
+
|
114 |
+
|
115 |
+
# Definition of a look-ahead word language model
|
116 |
+
class LookAheadWordLM(nn.Module):
|
117 |
+
logzero = -10000000000.0
|
118 |
+
zero = 1.0e-10
|
119 |
+
|
120 |
+
def __init__(
|
121 |
+
self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
|
122 |
+
):
|
123 |
+
super(LookAheadWordLM, self).__init__()
|
124 |
+
self.wordlm = wordlm
|
125 |
+
self.word_eos = word_dict["<eos>"]
|
126 |
+
self.word_unk = word_dict["<unk>"]
|
127 |
+
self.var_word_eos = torch.LongTensor([self.word_eos])
|
128 |
+
self.var_word_unk = torch.LongTensor([self.word_unk])
|
129 |
+
self.space = subword_dict["<space>"]
|
130 |
+
self.eos = subword_dict["<eos>"]
|
131 |
+
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
|
132 |
+
self.oov_penalty = oov_penalty
|
133 |
+
self.open_vocab = open_vocab
|
134 |
+
self.subword_dict_size = len(subword_dict)
|
135 |
+
self.zero_tensor = torch.FloatTensor([self.zero])
|
136 |
+
self.normalized = True
|
137 |
+
|
138 |
+
def forward(self, state, x):
|
139 |
+
# update state with input label x
|
140 |
+
if state is None: # make initial states and cumulative probability vector
|
141 |
+
self.var_word_eos = to_device(x, self.var_word_eos)
|
142 |
+
self.var_word_unk = to_device(x, self.var_word_unk)
|
143 |
+
self.zero_tensor = to_device(x, self.zero_tensor)
|
144 |
+
wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
|
145 |
+
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
|
146 |
+
new_node = self.lexroot
|
147 |
+
xi = self.space
|
148 |
+
else:
|
149 |
+
wlm_state, cumsum_probs, node = state
|
150 |
+
xi = int(x)
|
151 |
+
if xi == self.space: # inter-word transition
|
152 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
153 |
+
w = to_device(x, torch.LongTensor([node[1]]))
|
154 |
+
else: # this node is not a word end, which means <unk>
|
155 |
+
w = self.var_word_unk
|
156 |
+
# update wordlm state and cumulative probability vector
|
157 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
158 |
+
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
|
159 |
+
new_node = self.lexroot # move to the tree root
|
160 |
+
elif node is not None and xi in node[0]: # intra-word transition
|
161 |
+
new_node = node[0][xi]
|
162 |
+
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
|
163 |
+
new_node = None
|
164 |
+
else: # if open_vocab flag is disabled, return 0 probabilities
|
165 |
+
log_y = to_device(
|
166 |
+
x, torch.full((1, self.subword_dict_size), self.logzero)
|
167 |
+
)
|
168 |
+
return (wlm_state, None, None), log_y
|
169 |
+
|
170 |
+
if new_node is not None:
|
171 |
+
succ, wid, wids = new_node
|
172 |
+
# compute parent node probability
|
173 |
+
sum_prob = (
|
174 |
+
(cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
|
175 |
+
if wids is not None
|
176 |
+
else 1.0
|
177 |
+
)
|
178 |
+
if sum_prob < self.zero:
|
179 |
+
log_y = to_device(
|
180 |
+
x, torch.full((1, self.subword_dict_size), self.logzero)
|
181 |
+
)
|
182 |
+
return (wlm_state, cumsum_probs, new_node), log_y
|
183 |
+
# set <unk> probability as a default value
|
184 |
+
unk_prob = (
|
185 |
+
cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
|
186 |
+
)
|
187 |
+
y = to_device(
|
188 |
+
x,
|
189 |
+
torch.full(
|
190 |
+
(1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
|
191 |
+
),
|
192 |
+
)
|
193 |
+
# compute transition probabilities to child nodes
|
194 |
+
for cid, nd in succ.items():
|
195 |
+
y[:, cid] = (
|
196 |
+
cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
|
197 |
+
) / sum_prob
|
198 |
+
# apply word-level probabilities for <space> and <eos> labels
|
199 |
+
if wid >= 0:
|
200 |
+
wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
|
201 |
+
y[:, self.space] = wlm_prob
|
202 |
+
y[:, self.eos] = wlm_prob
|
203 |
+
elif xi == self.space:
|
204 |
+
y[:, self.space] = self.zero
|
205 |
+
y[:, self.eos] = self.zero
|
206 |
+
log_y = torch.log(torch.max(y, self.zero_tensor)) # clip to avoid log(0)
|
207 |
+
else: # if no path in the tree, transition probability is one
|
208 |
+
log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
|
209 |
+
return (wlm_state, cumsum_probs, new_node), log_y
|
210 |
+
|
211 |
+
def final(self, state):
|
212 |
+
wlm_state, cumsum_probs, node = state
|
213 |
+
if node is not None and node[1] >= 0: # check if the node is word end
|
214 |
+
w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
|
215 |
+
else: # this node is not a word end, which means <unk>
|
216 |
+
w = self.var_word_unk
|
217 |
+
wlm_state, z_wlm = self.wordlm(wlm_state, w)
|
218 |
+
return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
|
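Both language models above walk the lexical tree built by make_lexical_tree, where each node is [children, word_id, (start-1, end)] over the word-ID interval sharing that subword prefix. A small sketch with a toy vocabulary (IDs chosen only for illustration):

from espnet.lm.lm_utils import make_lexical_tree

# toy dictionaries; ID 0 is reserved (<blank>) and ID 1 is <unk>
word_dict = {"<blank>": 0, "<unk>": 1, "ab": 2, "ac": 3}
subword_dict = {"a": 0, "b": 1, "c": 2}

root = make_lexical_tree(word_dict, subword_dict, word_unk=1)
node_a = root[0][subword_dict["a"]]  # child node reached by the subword "a"
print(node_a[1])  # -1: "a" by itself is not a word in word_dict
print(node_a[2])  # (1, 3): word-ID interval covered by words starting with "a"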
espnet/lm/pytorch_backend/lm.py
ADDED
@@ -0,0 +1,410 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
3 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
4 |
+
# This code is ported from the following implementation written in Torch.
|
5 |
+
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
|
6 |
+
|
7 |
+
"""LM training in pytorch."""
|
8 |
+
|
9 |
+
import copy
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn.parallel import data_parallel
|
17 |
+
|
18 |
+
from chainer import Chain
|
19 |
+
from chainer.dataset import convert
|
20 |
+
from chainer import reporter
|
21 |
+
from chainer import training
|
22 |
+
from chainer.training import extensions
|
23 |
+
|
24 |
+
from espnet.lm.lm_utils import count_tokens
|
25 |
+
from espnet.lm.lm_utils import load_dataset
|
26 |
+
from espnet.lm.lm_utils import MakeSymlinkToBestModel
|
27 |
+
from espnet.lm.lm_utils import ParallelSentenceIterator
|
28 |
+
from espnet.lm.lm_utils import read_tokens
|
29 |
+
from espnet.nets.lm_interface import dynamic_import_lm
|
30 |
+
from espnet.nets.lm_interface import LMInterface
|
31 |
+
from espnet.optimizer.factory import dynamic_import_optimizer
|
32 |
+
from espnet.scheduler.pytorch import PyTorchScheduler
|
33 |
+
from espnet.scheduler.scheduler import dynamic_import_scheduler
|
34 |
+
|
35 |
+
from espnet.asr.asr_utils import snapshot_object
|
36 |
+
from espnet.asr.asr_utils import torch_load
|
37 |
+
from espnet.asr.asr_utils import torch_resume
|
38 |
+
from espnet.asr.asr_utils import torch_snapshot
|
39 |
+
|
40 |
+
from espnet.utils.training.tensorboard_logger import TensorboardLogger
|
41 |
+
from tensorboardX import SummaryWriter
|
42 |
+
|
43 |
+
from espnet.utils.deterministic_utils import set_deterministic_pytorch
|
44 |
+
from espnet.utils.training.evaluator import BaseEvaluator
|
45 |
+
from espnet.utils.training.iterators import ShufflingEnabler
|
46 |
+
from espnet.utils.training.train_utils import check_early_stop
|
47 |
+
from espnet.utils.training.train_utils import set_early_stop
|
48 |
+
|
49 |
+
|
50 |
+
def compute_perplexity(result):
|
51 |
+
"""Compute and add the perplexity to the LogReport.
|
52 |
+
|
53 |
+
:param dict result: The current observations
|
54 |
+
"""
|
55 |
+
# Routine to rewrite the result dictionary of LogReport to add perplexity values
|
56 |
+
result["perplexity"] = np.exp(result["main/nll"] / result["main/count"])
|
57 |
+
if "validation/main/nll" in result:
|
58 |
+
result["val_perplexity"] = np.exp(
|
59 |
+
result["validation/main/nll"] / result["validation/main/count"]
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
class Reporter(Chain):
|
64 |
+
"""Dummy module to use chainer's trainer."""
|
65 |
+
|
66 |
+
def report(self, loss):
|
67 |
+
"""Report nothing."""
|
68 |
+
pass
|
69 |
+
|
70 |
+
|
71 |
+
def concat_examples(batch, device=None, padding=None):
|
72 |
+
"""Concat examples in minibatch.
|
73 |
+
|
74 |
+
:param np.ndarray batch: The batch to concatenate
|
75 |
+
:param int device: The device to send to
|
76 |
+
:param Tuple[int,int] padding: The padding to use
|
77 |
+
:return: (inputs, targets)
|
78 |
+
:rtype (torch.Tensor, torch.Tensor)
|
79 |
+
"""
|
80 |
+
x, t = convert.concat_examples(batch, padding=padding)
|
81 |
+
x = torch.from_numpy(x)
|
82 |
+
t = torch.from_numpy(t)
|
83 |
+
if device is not None and device >= 0:
|
84 |
+
x = x.cuda(device)
|
85 |
+
t = t.cuda(device)
|
86 |
+
return x, t
|
87 |
+
|
88 |
+
|
89 |
+
class BPTTUpdater(training.StandardUpdater):
|
90 |
+
"""An updater for a pytorch LM."""
|
91 |
+
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
train_iter,
|
95 |
+
model,
|
96 |
+
optimizer,
|
97 |
+
schedulers,
|
98 |
+
device,
|
99 |
+
gradclip=None,
|
100 |
+
use_apex=False,
|
101 |
+
accum_grad=1,
|
102 |
+
):
|
103 |
+
"""Initialize class.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
train_iter (chainer.dataset.Iterator): The train iterator
|
107 |
+
model (LMInterface) : The model to update
|
108 |
+
optimizer (torch.optim.Optimizer): The optimizer for training
|
109 |
+
schedulers (espnet.scheduler.scheduler.SchedulerInterface):
|
110 |
+
The schedulers of `optimizer`
|
111 |
+
device (int): The device id
|
112 |
+
gradclip (float): The gradient clipping value to use
|
113 |
+
use_apex (bool): The flag to use Apex in backprop.
|
114 |
+
accum_grad (int): The number of gradient accumulation.
|
115 |
+
|
116 |
+
"""
|
117 |
+
super(BPTTUpdater, self).__init__(train_iter, optimizer)
|
118 |
+
self.model = model
|
119 |
+
self.device = device
|
120 |
+
self.gradclip = gradclip
|
121 |
+
self.use_apex = use_apex
|
122 |
+
self.scheduler = PyTorchScheduler(schedulers, optimizer)
|
123 |
+
self.accum_grad = accum_grad
|
124 |
+
|
125 |
+
# The core part of the update routine can be customized by overriding.
|
126 |
+
def update_core(self):
|
127 |
+
"""Update the model."""
|
128 |
+
# When we pass one iterator and optimizer to StandardUpdater.__init__,
|
129 |
+
# they are automatically named 'main'.
|
130 |
+
train_iter = self.get_iterator("main")
|
131 |
+
optimizer = self.get_optimizer("main")
|
132 |
+
# Progress the dataset iterator for sentences at each iteration.
|
133 |
+
self.model.zero_grad() # Clear the parameter gradients
|
134 |
+
accum = {"loss": 0.0, "nll": 0.0, "count": 0}
|
135 |
+
for _ in range(self.accum_grad):
|
136 |
+
batch = train_iter.__next__()
|
137 |
+
# Concatenate the token IDs to matrices and send them to the device
|
138 |
+
# self.converter does this job
|
139 |
+
# (it is chainer.dataset.concat_examples by default)
|
140 |
+
x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
|
141 |
+
if self.device[0] == -1:
|
142 |
+
loss, nll, count = self.model(x, t)
|
143 |
+
else:
|
144 |
+
# apex does not support torch.nn.DataParallel
|
145 |
+
loss, nll, count = data_parallel(self.model, (x, t), self.device)
|
146 |
+
|
147 |
+
# backward
|
148 |
+
loss = loss.mean() / self.accum_grad
|
149 |
+
if self.use_apex:
|
150 |
+
from apex import amp
|
151 |
+
|
152 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
153 |
+
scaled_loss.backward()
|
154 |
+
else:
|
155 |
+
loss.backward() # Backprop
|
156 |
+
# accumulate stats
|
157 |
+
accum["loss"] += float(loss)
|
158 |
+
accum["nll"] += float(nll.sum())
|
159 |
+
accum["count"] += int(count.sum())
|
160 |
+
|
161 |
+
for k, v in accum.items():
|
162 |
+
reporter.report({k: v}, optimizer.target)
|
163 |
+
if self.gradclip is not None:
|
164 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
|
165 |
+
optimizer.step() # Update the parameters
|
166 |
+
self.scheduler.step(n_iter=self.iteration)
|
167 |
+
|
168 |
+
|
169 |
+
class LMEvaluator(BaseEvaluator):
|
170 |
+
"""A custom evaluator for a pytorch LM."""
|
171 |
+
|
172 |
+
def __init__(self, val_iter, eval_model, reporter, device):
|
173 |
+
"""Initialize class.
|
174 |
+
|
175 |
+
:param chainer.dataset.Iterator val_iter : The validation iterator
|
176 |
+
:param LMInterface eval_model : The model to evaluate
|
177 |
+
:param chainer.Reporter reporter : The observations reporter
|
178 |
+
:param int device : The device id to use
|
179 |
+
|
180 |
+
"""
|
181 |
+
super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
|
182 |
+
self.model = eval_model
|
183 |
+
self.device = device
|
184 |
+
|
185 |
+
def evaluate(self):
|
186 |
+
"""Evaluate the model."""
|
187 |
+
val_iter = self.get_iterator("main")
|
188 |
+
loss = 0
|
189 |
+
nll = 0
|
190 |
+
count = 0
|
191 |
+
self.model.eval()
|
192 |
+
with torch.no_grad():
|
193 |
+
for batch in copy.copy(val_iter):
|
194 |
+
x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
|
195 |
+
if self.device[0] == -1:
|
196 |
+
l, n, c = self.model(x, t)
|
197 |
+
else:
|
198 |
+
# apex does not support torch.nn.DataParallel
|
199 |
+
l, n, c = data_parallel(self.model, (x, t), self.device)
|
200 |
+
loss += float(l.sum())
|
201 |
+
nll += float(n.sum())
|
202 |
+
count += int(c.sum())
|
203 |
+
self.model.train()
|
204 |
+
# report validation loss
|
205 |
+
observation = {}
|
206 |
+
with reporter.report_scope(observation):
|
207 |
+
reporter.report({"loss": loss}, self.model.reporter)
|
208 |
+
reporter.report({"nll": nll}, self.model.reporter)
|
209 |
+
reporter.report({"count": count}, self.model.reporter)
|
210 |
+
return observation
|
211 |
+
|
212 |
+
|
213 |
+
def train(args):
|
214 |
+
"""Train with the given args.
|
215 |
+
|
216 |
+
:param Namespace args: The program arguments
|
217 |
+
:param type model_class: LMInterface class for training
|
218 |
+
"""
|
219 |
+
model_class = dynamic_import_lm(args.model_module, args.backend)
|
220 |
+
assert issubclass(model_class, LMInterface), "model should implement LMInterface"
|
221 |
+
# display torch version
|
222 |
+
logging.info("torch version = " + torch.__version__)
|
223 |
+
|
224 |
+
set_deterministic_pytorch(args)
|
225 |
+
|
226 |
+
# check cuda and cudnn availability
|
227 |
+
if not torch.cuda.is_available():
|
228 |
+
logging.warning("cuda is not available")
|
229 |
+
|
230 |
+
# get special label ids
|
231 |
+
unk = args.char_list_dict["<unk>"]
|
232 |
+
eos = args.char_list_dict["<eos>"]
|
233 |
+
# read tokens as a sequence of sentences
|
234 |
+
val, n_val_tokens, n_val_oovs = load_dataset(
|
235 |
+
args.valid_label, args.char_list_dict, args.dump_hdf5_path
|
236 |
+
)
|
237 |
+
train, n_train_tokens, n_train_oovs = load_dataset(
|
238 |
+
args.train_label, args.char_list_dict, args.dump_hdf5_path
|
239 |
+
)
|
240 |
+
logging.info("#vocab = " + str(args.n_vocab))
|
241 |
+
logging.info("#sentences in the training data = " + str(len(train)))
|
242 |
+
logging.info("#tokens in the training data = " + str(n_train_tokens))
|
243 |
+
logging.info(
|
244 |
+
"oov rate in the training data = %.2f %%"
|
245 |
+
% (n_train_oovs / n_train_tokens * 100)
|
246 |
+
)
|
247 |
+
logging.info("#sentences in the validation data = " + str(len(val)))
|
248 |
+
logging.info("#tokens in the validation data = " + str(n_val_tokens))
|
249 |
+
logging.info(
|
250 |
+
"oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100)
|
251 |
+
)
|
252 |
+
|
253 |
+
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
|
254 |
+
# Create the dataset iterators
|
255 |
+
batch_size = args.batchsize * max(args.ngpu, 1)
|
256 |
+
if batch_size * args.accum_grad > args.batchsize:
|
257 |
+
logging.info(
|
258 |
+
f"batch size is automatically increased "
|
259 |
+
f"({args.batchsize} -> {batch_size * args.accum_grad})"
|
260 |
+
)
|
261 |
+
train_iter = ParallelSentenceIterator(
|
262 |
+
train,
|
263 |
+
batch_size,
|
264 |
+
max_length=args.maxlen,
|
265 |
+
sos=eos,
|
266 |
+
eos=eos,
|
267 |
+
shuffle=not use_sortagrad,
|
268 |
+
)
|
269 |
+
val_iter = ParallelSentenceIterator(
|
270 |
+
val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
|
271 |
+
)
|
272 |
+
epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
|
273 |
+
logging.info("#iterations per epoch = %d" % epoch_iters)
|
274 |
+
logging.info("#total iterations = " + str(args.epoch * epoch_iters))
|
275 |
+
# Prepare an RNNLM model
|
276 |
+
if args.train_dtype in ("float16", "float32", "float64"):
|
277 |
+
dtype = getattr(torch, args.train_dtype)
|
278 |
+
else:
|
279 |
+
dtype = torch.float32
|
280 |
+
model = model_class(args.n_vocab, args).to(dtype=dtype)
|
281 |
+
if args.ngpu > 0:
|
282 |
+
model.to("cuda")
|
283 |
+
gpu_id = list(range(args.ngpu))
|
284 |
+
else:
|
285 |
+
gpu_id = [-1]
|
286 |
+
|
287 |
+
# Save model conf to json
|
288 |
+
model_conf = args.outdir + "/model.json"
|
289 |
+
with open(model_conf, "wb") as f:
|
290 |
+
logging.info("writing a model config file to " + model_conf)
|
291 |
+
f.write(
|
292 |
+
json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
|
293 |
+
"utf_8"
|
294 |
+
)
|
295 |
+
)
|
296 |
+
|
297 |
+
logging.warning(
|
298 |
+
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
|
299 |
+
sum(p.numel() for p in model.parameters()),
|
300 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
301 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad)
|
302 |
+
* 100.0
|
303 |
+
/ sum(p.numel() for p in model.parameters()),
|
304 |
+
)
|
305 |
+
)
|
306 |
+
|
307 |
+
# Set up an optimizer
|
308 |
+
opt_class = dynamic_import_optimizer(args.opt, args.backend)
|
309 |
+
optimizer = opt_class.from_args(model.parameters(), args)
|
310 |
+
if args.schedulers is None:
|
311 |
+
schedulers = []
|
312 |
+
else:
|
313 |
+
schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]
|
314 |
+
|
315 |
+
# setup apex.amp
|
316 |
+
if args.train_dtype in ("O0", "O1", "O2", "O3"):
|
317 |
+
try:
|
318 |
+
from apex import amp
|
319 |
+
except ImportError as e:
|
320 |
+
logging.error(
|
321 |
+
f"You need to install apex for --train-dtype {args.train_dtype}. "
|
322 |
+
"See https://github.com/NVIDIA/apex#linux"
|
323 |
+
)
|
324 |
+
raise e
|
325 |
+
model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
|
326 |
+
use_apex = True
|
327 |
+
else:
|
328 |
+
use_apex = False
|
329 |
+
|
330 |
+
# FIXME: TOO DIRTY HACK
|
331 |
+
reporter = Reporter()
|
332 |
+
setattr(model, "reporter", reporter)
|
333 |
+
setattr(optimizer, "target", reporter)
|
334 |
+
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
|
335 |
+
|
336 |
+
updater = BPTTUpdater(
|
337 |
+
train_iter,
|
338 |
+
model,
|
339 |
+
optimizer,
|
340 |
+
schedulers,
|
341 |
+
gpu_id,
|
342 |
+
gradclip=args.gradclip,
|
343 |
+
use_apex=use_apex,
|
344 |
+
accum_grad=args.accum_grad,
|
345 |
+
)
|
346 |
+
trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
|
347 |
+
trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
|
348 |
+
trainer.extend(
|
349 |
+
extensions.LogReport(
|
350 |
+
postprocess=compute_perplexity,
|
351 |
+
trigger=(args.report_interval_iters, "iteration"),
|
352 |
+
)
|
353 |
+
)
|
354 |
+
trainer.extend(
|
355 |
+
extensions.PrintReport(
|
356 |
+
[
|
357 |
+
"epoch",
|
358 |
+
"iteration",
|
359 |
+
"main/loss",
|
360 |
+
"perplexity",
|
361 |
+
"val_perplexity",
|
362 |
+
"elapsed_time",
|
363 |
+
]
|
364 |
+
),
|
365 |
+
trigger=(args.report_interval_iters, "iteration"),
|
366 |
+
)
|
367 |
+
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
|
368 |
+
# Save best models
|
369 |
+
trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}"))
|
370 |
+
trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
|
371 |
+
# T.Hori: MinValueTrigger should be used, but it fails when resuming
|
372 |
+
trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))
|
373 |
+
|
374 |
+
if use_sortagrad:
|
375 |
+
trainer.extend(
|
376 |
+
ShufflingEnabler([train_iter]),
|
377 |
+
trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
|
378 |
+
)
|
379 |
+
if args.resume:
|
380 |
+
logging.info("resumed from %s" % args.resume)
|
381 |
+
torch_resume(args.resume, trainer)
|
382 |
+
|
383 |
+
set_early_stop(trainer, args, is_lm=True)
|
384 |
+
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
|
385 |
+
writer = SummaryWriter(args.tensorboard_dir)
|
386 |
+
trainer.extend(
|
387 |
+
TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
|
388 |
+
)
|
389 |
+
|
390 |
+
trainer.run()
|
391 |
+
check_early_stop(trainer, args.epoch)
|
392 |
+
|
393 |
+
# compute perplexity for test set
|
394 |
+
if args.test_label:
|
395 |
+
logging.info("test the best model")
|
396 |
+
torch_load(args.outdir + "/rnnlm.model.best", model)
|
397 |
+
test = read_tokens(args.test_label, args.char_list_dict)
|
398 |
+
n_test_tokens, n_test_oovs = count_tokens(test, unk)
|
399 |
+
logging.info("#sentences in the test data = " + str(len(test)))
|
400 |
+
logging.info("#tokens in the test data = " + str(n_test_tokens))
|
401 |
+
logging.info(
|
402 |
+
"oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
|
403 |
+
)
|
404 |
+
test_iter = ParallelSentenceIterator(
|
405 |
+
test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
|
406 |
+
)
|
407 |
+
evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
|
408 |
+
result = evaluator()
|
409 |
+
compute_perplexity(result)
|
410 |
+
logging.info(f"test perplexity: {result['perplexity']}")
|
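The pytorch updater and evaluator above both rely on concat_examples to pad a list of (input, target) pairs into two tensors, padding inputs with 0 and targets with -100 (which matches the default ignore_index of torch.nn.CrossEntropyLoss). A minimal sketch with made-up sentence pairs:

import numpy as np
from espnet.lm.pytorch_backend.lm import concat_examples

# two (input, target) pairs of unequal length, as ParallelSentenceIterator would yield
batch = [
    (np.array([0, 5, 6], dtype=np.int32), np.array([5, 6, 0], dtype=np.int32)),
    (np.array([0, 8], dtype=np.int32), np.array([8, 0], dtype=np.int32)),
]
x, t = concat_examples(batch, device=-1, padding=(0, -100))
print(x)  # inputs padded with 0
print(t)  # targets padded with -100, which the loss should ignore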
espnet/mt/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/mt/mt_utils.py
ADDED
@@ -0,0 +1,83 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Utility funcitons for the text translation task."""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
|
11 |
+
|
12 |
+
# * ------------------ recognition related ------------------ *
|
13 |
+
def parse_hypothesis(hyp, char_list):
|
14 |
+
"""Parse hypothesis.
|
15 |
+
|
16 |
+
:param list hyp: recognition hypothesis
|
17 |
+
:param list char_list: list of characters
|
18 |
+
:return: recognition text string
|
19 |
+
:return: recognition token string
|
20 |
+
:return: recognition tokenid string
|
21 |
+
"""
|
22 |
+
# remove sos and get results
|
23 |
+
tokenid_as_list = list(map(int, hyp["yseq"][1:]))
|
24 |
+
token_as_list = [char_list[idx] for idx in tokenid_as_list]
|
25 |
+
score = float(hyp["score"])
|
26 |
+
|
27 |
+
# convert to string
|
28 |
+
tokenid = " ".join([str(idx) for idx in tokenid_as_list])
|
29 |
+
token = " ".join(token_as_list)
|
30 |
+
text = "".join(token_as_list).replace("<space>", " ")
|
31 |
+
|
32 |
+
return text, token, tokenid, score
|
33 |
+
|
34 |
+
|
35 |
+
def add_results_to_json(js, nbest_hyps, char_list):
|
36 |
+
"""Add N-best results to json.
|
37 |
+
|
38 |
+
:param dict js: groundtruth utterance dict
|
39 |
+
:param list nbest_hyps: list of hypothesis
|
40 |
+
:param list char_list: list of characters
|
41 |
+
:return: N-best results added utterance dict
|
42 |
+
"""
|
43 |
+
# copy old json info
|
44 |
+
new_js = dict()
|
45 |
+
if "utt2spk" in js.keys():
|
46 |
+
new_js["utt2spk"] = js["utt2spk"]
|
47 |
+
new_js["output"] = []
|
48 |
+
|
49 |
+
for n, hyp in enumerate(nbest_hyps, 1):
|
50 |
+
# parse hypothesis
|
51 |
+
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
|
52 |
+
|
53 |
+
# copy ground-truth
|
54 |
+
if len(js["output"]) > 0:
|
55 |
+
out_dic = dict(js["output"][0].items())
|
56 |
+
else:
|
57 |
+
out_dic = {"name": ""}
|
58 |
+
|
59 |
+
# update name
|
60 |
+
out_dic["name"] += "[%d]" % n
|
61 |
+
|
62 |
+
# add recognition results
|
63 |
+
out_dic["rec_text"] = rec_text
|
64 |
+
out_dic["rec_token"] = rec_token
|
65 |
+
out_dic["rec_tokenid"] = rec_tokenid
|
66 |
+
out_dic["score"] = score
|
67 |
+
|
68 |
+
# add source reference
|
69 |
+
out_dic["text_src"] = js["output"][1]["text"]
|
70 |
+
out_dic["token_src"] = js["output"][1]["token"]
|
71 |
+
out_dic["tokenid_src"] = js["output"][1]["tokenid"]
|
72 |
+
|
73 |
+
# add to list of N-best result dicts
|
74 |
+
new_js["output"].append(out_dic)
|
75 |
+
|
76 |
+
# show 1-best result
|
77 |
+
if n == 1:
|
78 |
+
if "text" in out_dic.keys():
|
79 |
+
logging.info("groundtruth: %s" % out_dic["text"])
|
80 |
+
logging.info("prediction : %s" % out_dic["rec_text"])
|
81 |
+
logging.info("source : %s" % out_dic["token_src"])
|
82 |
+
|
83 |
+
return new_js
|
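A usage sketch of parse_hypothesis above, with a toy char_list and a single beam-search hypothesis (values are illustrative only):

from espnet.mt.mt_utils import parse_hypothesis

# toy character inventory and hypothesis; yseq[0] is the <sos>/<eos> symbol and is stripped
char_list = ["<blank>", "<unk>", "<space>", "h", "i", "<eos>"]
hyp = {"yseq": [5, 3, 4], "score": -1.23}

text, token, tokenid, score = parse_hypothesis(hyp, char_list)
print(text)     # "hi"
print(token)    # "h i"
print(tokenid)  # "3 4"
print(score)    # -1.23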
espnet/mt/pytorch_backend/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/mt/pytorch_backend/mt.py
ADDED
@@ -0,0 +1,600 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Training/decoding definition for the text translation task."""
|
8 |
+
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
|
14 |
+
from chainer import training
|
15 |
+
from chainer.training import extensions
|
16 |
+
import numpy as np
|
17 |
+
from tensorboardX import SummaryWriter
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from espnet.asr.asr_utils import adadelta_eps_decay
|
21 |
+
from espnet.asr.asr_utils import adam_lr_decay
|
22 |
+
from espnet.asr.asr_utils import add_results_to_json
|
23 |
+
from espnet.asr.asr_utils import CompareValueTrigger
|
24 |
+
from espnet.asr.asr_utils import restore_snapshot
|
25 |
+
from espnet.asr.asr_utils import snapshot_object
|
26 |
+
from espnet.asr.asr_utils import torch_load
|
27 |
+
from espnet.asr.asr_utils import torch_resume
|
28 |
+
from espnet.asr.asr_utils import torch_snapshot
|
29 |
+
from espnet.nets.mt_interface import MTInterface
|
30 |
+
from espnet.nets.pytorch_backend.e2e_asr import pad_list
|
31 |
+
from espnet.utils.dataset import ChainerDataLoader
|
32 |
+
from espnet.utils.dataset import TransformDataset
|
33 |
+
from espnet.utils.deterministic_utils import set_deterministic_pytorch
|
34 |
+
from espnet.utils.dynamic_import import dynamic_import
|
35 |
+
from espnet.utils.io_utils import LoadInputsAndTargets
|
36 |
+
from espnet.utils.training.batchfy import make_batchset
|
37 |
+
from espnet.utils.training.iterators import ShufflingEnabler
|
38 |
+
from espnet.utils.training.tensorboard_logger import TensorboardLogger
|
39 |
+
from espnet.utils.training.train_utils import check_early_stop
|
40 |
+
from espnet.utils.training.train_utils import set_early_stop
|
41 |
+
|
42 |
+
from espnet.asr.pytorch_backend.asr import CustomEvaluator
|
43 |
+
from espnet.asr.pytorch_backend.asr import CustomUpdater
|
44 |
+
from espnet.asr.pytorch_backend.asr import load_trained_model
|
45 |
+
|
46 |
+
import matplotlib
|
47 |
+
|
48 |
+
matplotlib.use("Agg")
|
49 |
+
|
50 |
+
if sys.version_info[0] == 2:
|
51 |
+
from itertools import izip_longest as zip_longest
|
52 |
+
else:
|
53 |
+
from itertools import zip_longest as zip_longest
|
54 |
+
|
55 |
+
|
56 |
+
class CustomConverter(object):
|
57 |
+
"""Custom batch converter for Pytorch."""
|
58 |
+
|
59 |
+
def __init__(self):
|
60 |
+
"""Construct a CustomConverter object."""
|
61 |
+
self.ignore_id = -1
|
62 |
+
self.pad = 0
|
63 |
+
# NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
|
64 |
+
# in ASR. However,
|
65 |
+
# blank labels are not used in NMT. To keep the vocabulary size,
|
66 |
+
# we use index:0 for padding instead of adding one more class.
|
67 |
+
|
68 |
+
def __call__(self, batch, device=torch.device("cpu")):
|
69 |
+
"""Transform a batch and send it to a device.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
batch (list): The batch to transform.
|
73 |
+
device (torch.device): The device to send to.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
|
77 |
+
|
78 |
+
"""
|
79 |
+
# batch should be located in list
|
80 |
+
assert len(batch) == 1
|
81 |
+
xs, ys = batch[0]
|
82 |
+
|
83 |
+
# get batch of lengths of input sequences
|
84 |
+
ilens = np.array([x.shape[0] for x in xs])
|
85 |
+
|
86 |
+
# perform padding and convert to tensor
|
87 |
+
xs_pad = pad_list([torch.from_numpy(x).long() for x in xs], self.pad).to(device)
|
88 |
+
ilens = torch.from_numpy(ilens).to(device)
|
89 |
+
ys_pad = pad_list([torch.from_numpy(y).long() for y in ys], self.ignore_id).to(
|
90 |
+
device
|
91 |
+
)
|
92 |
+
|
93 |
+
return xs_pad, ilens, ys_pad
|
94 |
+
|
95 |
+
|
96 |
+
def train(args):
|
97 |
+
"""Train with the given args.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
args (namespace): The program arguments.
|
101 |
+
|
102 |
+
"""
|
103 |
+
set_deterministic_pytorch(args)
|
104 |
+
|
105 |
+
# check cuda availability
|
106 |
+
if not torch.cuda.is_available():
|
107 |
+
logging.warning("cuda is not available")
|
108 |
+
|
109 |
+
# get input and output dimension info
|
110 |
+
with open(args.valid_json, "rb") as f:
|
111 |
+
valid_json = json.load(f)["utts"]
|
112 |
+
utts = list(valid_json.keys())
|
113 |
+
idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
|
114 |
+
odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
|
115 |
+
logging.info("#input dims : " + str(idim))
|
116 |
+
logging.info("#output dims: " + str(odim))
|
117 |
+
|
118 |
+
# specify model architecture
|
119 |
+
model_class = dynamic_import(args.model_module)
|
120 |
+
model = model_class(idim, odim, args)
|
121 |
+
assert isinstance(model, MTInterface)
|
122 |
+
|
123 |
+
# write model config
|
124 |
+
if not os.path.exists(args.outdir):
|
125 |
+
os.makedirs(args.outdir)
|
126 |
+
model_conf = args.outdir + "/model.json"
|
127 |
+
with open(model_conf, "wb") as f:
|
128 |
+
logging.info("writing a model config file to " + model_conf)
|
129 |
+
f.write(
|
130 |
+
json.dumps(
|
131 |
+
(idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
|
132 |
+
).encode("utf_8")
|
133 |
+
)
|
134 |
+
for key in sorted(vars(args).keys()):
|
135 |
+
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
|
136 |
+
|
137 |
+
reporter = model.reporter
|
138 |
+
|
139 |
+
# check the use of multi-gpu
|
140 |
+
if args.ngpu > 1:
|
141 |
+
if args.batch_size != 0:
|
142 |
+
logging.warning(
|
143 |
+
"batch size is automatically increased (%d -> %d)"
|
144 |
+
% (args.batch_size, args.batch_size * args.ngpu)
|
145 |
+
)
|
146 |
+
args.batch_size *= args.ngpu
|
147 |
+
|
148 |
+
# set torch device
|
149 |
+
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
|
150 |
+
if args.train_dtype in ("float16", "float32", "float64"):
|
151 |
+
dtype = getattr(torch, args.train_dtype)
|
152 |
+
else:
|
153 |
+
dtype = torch.float32
|
154 |
+
model = model.to(device=device, dtype=dtype)
|
155 |
+
|
156 |
+
logging.warning(
|
157 |
+
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
|
158 |
+
sum(p.numel() for p in model.parameters()),
|
159 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
160 |
+
sum(p.numel() for p in model.parameters() if p.requires_grad)
|
161 |
+
* 100.0
|
162 |
+
/ sum(p.numel() for p in model.parameters()),
|
163 |
+
)
|
164 |
+
)
|
165 |
+
|
166 |
+
# Setup an optimizer
|
167 |
+
if args.opt == "adadelta":
|
168 |
+
optimizer = torch.optim.Adadelta(
|
169 |
+
model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
|
170 |
+
)
|
171 |
+
elif args.opt == "adam":
|
172 |
+
optimizer = torch.optim.Adam(
|
173 |
+
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
|
174 |
+
)
|
175 |
+
elif args.opt == "noam":
|
176 |
+
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
|
177 |
+
|
178 |
+
optimizer = get_std_opt(
|
179 |
+
model.parameters(),
|
180 |
+
args.adim,
|
181 |
+
args.transformer_warmup_steps,
|
182 |
+
args.transformer_lr,
|
183 |
+
)
|
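# Note on the "noam" option: get_std_opt wraps Adam with the Transformer learning-rate
# schedule, roughly lr = transformer_lr * adim**-0.5 * min(step**-0.5,
# step * transformer_warmup_steps**-1.5), i.e. a linear warmup followed by
# inverse-square-root decay.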
184 |
+
else:
|
185 |
+
raise NotImplementedError("unknown optimizer: " + args.opt)
|
186 |
+
|
187 |
+
# setup apex.amp
|
188 |
+
if args.train_dtype in ("O0", "O1", "O2", "O3"):
|
189 |
+
try:
|
190 |
+
from apex import amp
|
191 |
+
except ImportError as e:
|
192 |
+
logging.error(
|
193 |
+
f"You need to install apex for --train-dtype {args.train_dtype}. "
|
194 |
+
"See https://github.com/NVIDIA/apex#linux"
|
195 |
+
)
|
196 |
+
raise e
|
197 |
+
if args.opt == "noam":
|
198 |
+
model, optimizer.optimizer = amp.initialize(
|
199 |
+
model, optimizer.optimizer, opt_level=args.train_dtype
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
model, optimizer = amp.initialize(
|
203 |
+
model, optimizer, opt_level=args.train_dtype
|
204 |
+
)
|
205 |
+
use_apex = True
|
206 |
+
else:
|
207 |
+
use_apex = False
|
208 |
+
|
209 |
+
# FIXME: TOO DIRTY HACK
|
210 |
+
setattr(optimizer, "target", reporter)
|
211 |
+
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
|
212 |
+
|
213 |
+
# Setup a converter
|
214 |
+
converter = CustomConverter()
|
215 |
+
|
216 |
+
# read json data
|
217 |
+
with open(args.train_json, "rb") as f:
|
218 |
+
train_json = json.load(f)["utts"]
|
219 |
+
with open(args.valid_json, "rb") as f:
|
220 |
+
valid_json = json.load(f)["utts"]
|
221 |
+
|
222 |
+
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
|
223 |
+
# make minibatch list (variable length)
|
224 |
+
train = make_batchset(
|
225 |
+
train_json,
|
226 |
+
args.batch_size,
|
227 |
+
args.maxlen_in,
|
228 |
+
args.maxlen_out,
|
229 |
+
args.minibatches,
|
230 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
231 |
+
shortest_first=use_sortagrad,
|
232 |
+
count=args.batch_count,
|
233 |
+
batch_bins=args.batch_bins,
|
234 |
+
batch_frames_in=args.batch_frames_in,
|
235 |
+
batch_frames_out=args.batch_frames_out,
|
236 |
+
batch_frames_inout=args.batch_frames_inout,
|
237 |
+
mt=True,
|
238 |
+
iaxis=1,
|
239 |
+
oaxis=0,
|
240 |
+
)
|
241 |
+
valid = make_batchset(
|
242 |
+
valid_json,
|
243 |
+
args.batch_size,
|
244 |
+
args.maxlen_in,
|
245 |
+
args.maxlen_out,
|
246 |
+
args.minibatches,
|
247 |
+
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
|
248 |
+
count=args.batch_count,
|
249 |
+
batch_bins=args.batch_bins,
|
250 |
+
batch_frames_in=args.batch_frames_in,
|
251 |
+
batch_frames_out=args.batch_frames_out,
|
252 |
+
batch_frames_inout=args.batch_frames_inout,
|
253 |
+
mt=True,
|
254 |
+
iaxis=1,
|
255 |
+
oaxis=0,
|
256 |
+
)
|
257 |
+
|
258 |
+
load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
|
259 |
+
load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
|
260 |
+
# hack to make batchsize argument as 1
|
261 |
+
# actual batchsize is included in a list
|
262 |
+
# default collate function converts numpy array to pytorch tensor
|
263 |
+
# we used an empty collate function instead which returns list
|
264 |
+
train_iter = ChainerDataLoader(
|
265 |
+
dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
|
266 |
+
batch_size=1,
|
267 |
+
num_workers=args.n_iter_processes,
|
268 |
+
shuffle=not use_sortagrad,
|
269 |
+
collate_fn=lambda x: x[0],
|
270 |
+
)
|
271 |
+
valid_iter = ChainerDataLoader(
|
272 |
+
dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
|
273 |
+
batch_size=1,
|
274 |
+
shuffle=False,
|
275 |
+
collate_fn=lambda x: x[0],
|
276 |
+
num_workers=args.n_iter_processes,
|
277 |
+
)
|
278 |
+
|
279 |
+
# Set up a trainer
|
280 |
+
updater = CustomUpdater(
|
281 |
+
model,
|
282 |
+
args.grad_clip,
|
283 |
+
{"main": train_iter},
|
284 |
+
optimizer,
|
285 |
+
device,
|
286 |
+
args.ngpu,
|
287 |
+
False,
|
288 |
+
args.accum_grad,
|
289 |
+
use_apex=use_apex,
|
290 |
+
)
|
291 |
+
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
|
292 |
+
|
293 |
+
if use_sortagrad:
|
294 |
+
trainer.extend(
|
295 |
+
ShufflingEnabler([train_iter]),
|
296 |
+
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
|
297 |
+
)
|
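# SortaGrad: when enabled, minibatches are built shortest-first (see shortest_first
# above), which tends to stabilize early training; ShufflingEnabler re-enables random
# shuffling after `sortagrad` epochs (with sortagrad == -1 the trigger is the final
# epoch, so shortest-first ordering is kept throughout).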
298 |
+
|
299 |
+
# Resume from a snapshot
|
300 |
+
if args.resume:
|
301 |
+
logging.info("resumed from %s" % args.resume)
|
302 |
+
torch_resume(args.resume, trainer)
|
303 |
+
|
304 |
+
# Evaluate the model with the test dataset for each epoch
|
305 |
+
if args.save_interval_iters > 0:
|
306 |
+
trainer.extend(
|
307 |
+
CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
|
308 |
+
trigger=(args.save_interval_iters, "iteration"),
|
309 |
+
)
|
310 |
+
else:
|
311 |
+
trainer.extend(
|
312 |
+
CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
|
313 |
+
)
|
314 |
+
|
315 |
+
# Save attention weight each epoch
|
316 |
+
if args.num_save_attention > 0:
|
317 |
+
# NOTE: sort it by output lengths
|
318 |
+
data = sorted(
|
319 |
+
list(valid_json.items())[: args.num_save_attention],
|
320 |
+
key=lambda x: int(x[1]["output"][0]["shape"][0]),
|
321 |
+
reverse=True,
|
322 |
+
)
|
323 |
+
if hasattr(model, "module"):
|
324 |
+
att_vis_fn = model.module.calculate_all_attentions
|
325 |
+
plot_class = model.module.attention_plot_class
|
326 |
+
else:
|
327 |
+
att_vis_fn = model.calculate_all_attentions
|
328 |
+
plot_class = model.attention_plot_class
|
329 |
+
att_reporter = plot_class(
|
330 |
+
att_vis_fn,
|
331 |
+
data,
|
332 |
+
args.outdir + "/att_ws",
|
333 |
+
converter=converter,
|
334 |
+
transform=load_cv,
|
335 |
+
device=device,
|
336 |
+
ikey="output",
|
337 |
+
iaxis=1,
|
338 |
+
)
|
339 |
+
trainer.extend(att_reporter, trigger=(1, "epoch"))
|
340 |
+
else:
|
341 |
+
att_reporter = None
|
342 |
+
|
343 |
+
# Make a plot for training and validation values
|
344 |
+
trainer.extend(
|
345 |
+
extensions.PlotReport(
|
346 |
+
["main/loss", "validation/main/loss"], "epoch", file_name="loss.png"
|
347 |
+
)
|
348 |
+
)
|
349 |
+
trainer.extend(
|
350 |
+
extensions.PlotReport(
|
351 |
+
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
|
352 |
+
)
|
353 |
+
)
|
354 |
+
trainer.extend(
|
355 |
+
extensions.PlotReport(
|
356 |
+
["main/ppl", "validation/main/ppl"], "epoch", file_name="ppl.png"
|
357 |
+
)
|
358 |
+
)
|
359 |
+
trainer.extend(
|
360 |
+
extensions.PlotReport(
|
361 |
+
["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png"
|
362 |
+
)
|
363 |
+
)
|
364 |
+
|
365 |
+
# Save best models
|
366 |
+
trainer.extend(
|
367 |
+
snapshot_object(model, "model.loss.best"),
|
368 |
+
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
|
369 |
+
)
|
370 |
+
trainer.extend(
|
371 |
+
snapshot_object(model, "model.acc.best"),
|
372 |
+
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
|
373 |
+
)
|
374 |
+
|
375 |
+
# save snapshot which contains model and optimizer states
|
376 |
+
if args.save_interval_iters > 0:
|
377 |
+
trainer.extend(
|
378 |
+
torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
|
379 |
+
trigger=(args.save_interval_iters, "iteration"),
|
380 |
+
)
|
381 |
+
else:
|
382 |
+
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
|
383 |
+
|
384 |
+
# epsilon decay in the optimizer
|
385 |
+
if args.opt == "adadelta":
|
386 |
+
if args.criterion == "acc":
|
387 |
+
trainer.extend(
|
388 |
+
restore_snapshot(
|
389 |
+
model, args.outdir + "/model.acc.best", load_fn=torch_load
|
390 |
+
),
|
391 |
+
trigger=CompareValueTrigger(
|
392 |
+
"validation/main/acc",
|
393 |
+
lambda best_value, current_value: best_value > current_value,
|
394 |
+
),
|
395 |
+
)
|
396 |
+
trainer.extend(
|
397 |
+
adadelta_eps_decay(args.eps_decay),
|
398 |
+
trigger=CompareValueTrigger(
|
399 |
+
"validation/main/acc",
|
400 |
+
lambda best_value, current_value: best_value > current_value,
|
401 |
+
),
|
402 |
+
)
|
403 |
+
elif args.criterion == "loss":
|
404 |
+
trainer.extend(
|
405 |
+
restore_snapshot(
|
406 |
+
model, args.outdir + "/model.loss.best", load_fn=torch_load
|
407 |
+
),
|
408 |
+
trigger=CompareValueTrigger(
|
409 |
+
"validation/main/loss",
|
410 |
+
lambda best_value, current_value: best_value < current_value,
|
411 |
+
),
|
412 |
+
)
|
413 |
+
trainer.extend(
|
414 |
+
adadelta_eps_decay(args.eps_decay),
|
415 |
+
trigger=CompareValueTrigger(
|
416 |
+
"validation/main/loss",
|
417 |
+
lambda best_value, current_value: best_value < current_value,
|
418 |
+
),
|
419 |
+
)
|
420 |
+
elif args.opt == "adam":
|
421 |
+
if args.criterion == "acc":
|
422 |
+
trainer.extend(
|
423 |
+
restore_snapshot(
|
424 |
+
model, args.outdir + "/model.acc.best", load_fn=torch_load
|
425 |
+
),
|
426 |
+
trigger=CompareValueTrigger(
|
427 |
+
"validation/main/acc",
|
428 |
+
lambda best_value, current_value: best_value > current_value,
|
429 |
+
),
|
430 |
+
)
|
431 |
+
trainer.extend(
|
432 |
+
adam_lr_decay(args.lr_decay),
|
433 |
+
trigger=CompareValueTrigger(
|
434 |
+
"validation/main/acc",
|
435 |
+
lambda best_value, current_value: best_value > current_value,
|
436 |
+
),
|
437 |
+
)
|
438 |
+
elif args.criterion == "loss":
|
439 |
+
trainer.extend(
|
440 |
+
restore_snapshot(
|
441 |
+
model, args.outdir + "/model.loss.best", load_fn=torch_load
|
442 |
+
),
|
443 |
+
trigger=CompareValueTrigger(
|
444 |
+
"validation/main/loss",
|
445 |
+
lambda best_value, current_value: best_value < current_value,
|
446 |
+
),
|
447 |
+
)
|
448 |
+
trainer.extend(
|
449 |
+
adam_lr_decay(args.lr_decay),
|
450 |
+
trigger=CompareValueTrigger(
|
451 |
+
"validation/main/loss",
|
452 |
+
lambda best_value, current_value: best_value < current_value,
|
453 |
+
),
|
454 |
+
)
|
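# Both optimizer branches follow the same scheme: when the monitored validation
# metric stops improving, the best snapshot so far is restored and the optimizer is
# made more conservative (eps decay for adadelta, lr decay for adam).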
455 |
+
|
456 |
+
# Write a log of evaluation statistics for each epoch
|
457 |
+
trainer.extend(
|
458 |
+
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
|
459 |
+
)
|
460 |
+
report_keys = [
|
461 |
+
"epoch",
|
462 |
+
"iteration",
|
463 |
+
"main/loss",
|
464 |
+
"validation/main/loss",
|
465 |
+
"main/acc",
|
466 |
+
"validation/main/acc",
|
467 |
+
"main/ppl",
|
468 |
+
"validation/main/ppl",
|
469 |
+
"elapsed_time",
|
470 |
+
]
|
471 |
+
if args.opt == "adadelta":
|
472 |
+
trainer.extend(
|
473 |
+
extensions.observe_value(
|
474 |
+
"eps",
|
475 |
+
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
|
476 |
+
"eps"
|
477 |
+
],
|
478 |
+
),
|
479 |
+
trigger=(args.report_interval_iters, "iteration"),
|
480 |
+
)
|
481 |
+
report_keys.append("eps")
|
482 |
+
elif args.opt in ["adam", "noam"]:
|
483 |
+
trainer.extend(
|
484 |
+
extensions.observe_value(
|
485 |
+
"lr",
|
486 |
+
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
|
487 |
+
"lr"
|
488 |
+
],
|
489 |
+
),
|
490 |
+
trigger=(args.report_interval_iters, "iteration"),
|
491 |
+
)
|
492 |
+
report_keys.append("lr")
|
493 |
+
if args.report_bleu:
|
494 |
+
report_keys.append("main/bleu")
|
495 |
+
report_keys.append("validation/main/bleu")
|
496 |
+
trainer.extend(
|
497 |
+
extensions.PrintReport(report_keys),
|
498 |
+
trigger=(args.report_interval_iters, "iteration"),
|
499 |
+
)
|
500 |
+
|
501 |
+
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
|
502 |
+
set_early_stop(trainer, args)
|
503 |
+
|
504 |
+
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
|
505 |
+
trainer.extend(
|
506 |
+
TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
|
507 |
+
trigger=(args.report_interval_iters, "iteration"),
|
508 |
+
)
|
509 |
+
# Run the training
|
510 |
+
trainer.run()
|
511 |
+
check_early_stop(trainer, args.epochs)
|
512 |
+
|
513 |
+
|
514 |
+
def trans(args):
|
515 |
+
"""Decode with the given args.
|
516 |
+
|
517 |
+
Args:
|
518 |
+
args (namespace): The program arguments.
|
519 |
+
|
520 |
+
"""
|
521 |
+
set_deterministic_pytorch(args)
|
522 |
+
model, train_args = load_trained_model(args.model)
|
523 |
+
assert isinstance(model, MTInterface)
|
524 |
+
model.trans_args = args
|
525 |
+
|
526 |
+
# gpu
|
527 |
+
if args.ngpu == 1:
|
528 |
+
gpu_id = list(range(args.ngpu))
|
529 |
+
logging.info("gpu id: " + str(gpu_id))
|
530 |
+
model.cuda()
|
531 |
+
|
532 |
+
# read json data
|
533 |
+
with open(args.trans_json, "rb") as f:
|
534 |
+
js = json.load(f)["utts"]
|
535 |
+
new_js = {}
|
536 |
+
|
537 |
+
# remove empty utterances
|
538 |
+
if train_args.multilingual:
|
539 |
+
js = {
|
540 |
+
k: v
|
541 |
+
for k, v in js.items()
|
542 |
+
if v["output"][0]["shape"][0] > 1 and v["output"][1]["shape"][0] > 1
|
543 |
+
}
|
544 |
+
else:
|
545 |
+
js = {
|
546 |
+
k: v
|
547 |
+
for k, v in js.items()
|
548 |
+
if v["output"][0]["shape"][0] > 0 and v["output"][1]["shape"][0] > 0
|
549 |
+
}
|
550 |
+
|
551 |
+
if args.batchsize == 0:
|
552 |
+
with torch.no_grad():
|
553 |
+
for idx, name in enumerate(js.keys(), 1):
|
554 |
+
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
|
555 |
+
feat = [js[name]["output"][1]["tokenid"].split()]
|
556 |
+
nbest_hyps = model.translate(feat, args, train_args.char_list)
|
557 |
+
new_js[name] = add_results_to_json(
|
558 |
+
js[name], nbest_hyps, train_args.char_list
|
559 |
+
)
|
560 |
+
|
561 |
+
else:
|
562 |
+
|
563 |
+
def grouper(n, iterable, fillvalue=None):
|
564 |
+
kargs = [iter(iterable)] * n
|
565 |
+
return zip_longest(*kargs, fillvalue=fillvalue)
|
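# e.g. list(grouper(2, ["u1", "u2", "u3"])) == [("u1", "u2"), ("u3", None)];
# the None fill values are filtered out just below.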
566 |
+
|
567 |
+
# sort data
|
568 |
+
keys = list(js.keys())
|
569 |
+
feat_lens = [js[key]["output"][1]["shape"][0] for key in keys]
|
570 |
+
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
|
571 |
+
keys = [keys[i] for i in sorted_index]
|
572 |
+
|
573 |
+
with torch.no_grad():
|
574 |
+
for names in grouper(args.batchsize, keys, None):
|
575 |
+
names = [name for name in names if name]
|
576 |
+
feats = [
|
577 |
+
np.fromiter(
|
578 |
+
map(int, js[name]["output"][1]["tokenid"].split()),
|
579 |
+
dtype=np.int64,
|
580 |
+
)
|
581 |
+
for name in names
|
582 |
+
]
|
583 |
+
nbest_hyps = model.translate_batch(
|
584 |
+
feats,
|
585 |
+
args,
|
586 |
+
train_args.char_list,
|
587 |
+
)
|
588 |
+
|
589 |
+
for i, nbest_hyp in enumerate(nbest_hyps):
|
590 |
+
name = names[i]
|
591 |
+
new_js[name] = add_results_to_json(
|
592 |
+
js[name], nbest_hyp, train_args.char_list
|
593 |
+
)
|
594 |
+
|
595 |
+
with open(args.result_label, "wb") as f:
|
596 |
+
f.write(
|
597 |
+
json.dumps(
|
598 |
+
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
|
599 |
+
).encode("utf_8")
|
600 |
+
)
|
espnet/nets/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/nets/asr_interface.py
ADDED
@@ -0,0 +1,172 @@
1 |
+
"""ASR Interface module."""
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
from espnet.bin.asr_train import get_parser
|
5 |
+
from espnet.utils.dynamic_import import dynamic_import
|
6 |
+
from espnet.utils.fill_missing_args import fill_missing_args
|
7 |
+
|
8 |
+
|
9 |
+
class ASRInterface:
|
10 |
+
"""ASR Interface for ESPnet model implementation."""
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def add_arguments(parser):
|
14 |
+
"""Add arguments to parser."""
|
15 |
+
return parser
|
16 |
+
|
17 |
+
@classmethod
|
18 |
+
def build(cls, idim: int, odim: int, **kwargs):
|
19 |
+
"""Initialize this class with python-level args.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
idim (int): The number of an input feature dim.
|
23 |
+
odim (int): The number of output vocab.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
ASRInterface: A new instance of ASRInterface.
|
27 |
+
|
28 |
+
"""
|
29 |
+
|
30 |
+
def wrap(parser):
|
31 |
+
return get_parser(parser, required=False)
|
32 |
+
|
33 |
+
args = argparse.Namespace(**kwargs)
|
34 |
+
args = fill_missing_args(args, wrap)
|
35 |
+
args = fill_missing_args(args, cls.add_arguments)
|
36 |
+
return cls(idim, odim, args)
|
37 |
+
|
38 |
+
def forward(self, xs, ilens, ys):
|
39 |
+
"""Compute loss for training.
|
40 |
+
|
41 |
+
:param xs:
|
42 |
+
For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
|
43 |
+
For chainer, list of source sequences chainer.Variable
|
44 |
+
:param ilens: batch of lengths of source sequences (B)
|
45 |
+
For pytorch, torch.Tensor
|
46 |
+
For chainer, list of int
|
47 |
+
:param ys:
|
48 |
+
For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
|
49 |
+
For chainer, list of source sequences chainer.Variable
|
50 |
+
:return: loss value
|
51 |
+
:rtype: torch.Tensor for pytorch, chainer.Variable for chainer
|
52 |
+
"""
|
53 |
+
raise NotImplementedError("forward method is not implemented")
|
54 |
+
|
55 |
+
def recognize(self, x, recog_args, char_list=None, rnnlm=None):
|
56 |
+
"""Recognize x for evaluation.
|
57 |
+
|
58 |
+
:param ndarray x: input acoustic feature (B, T, D) or (T, D)
|
59 |
+
:param namespace recog_args: argument namespace containing options
|
60 |
+
:param list char_list: list of characters
|
61 |
+
:param torch.nn.Module rnnlm: language model module
|
62 |
+
:return: N-best decoding results
|
63 |
+
:rtype: list
|
64 |
+
"""
|
65 |
+
raise NotImplementedError("recognize method is not implemented")
|
66 |
+
|
67 |
+
def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
|
68 |
+
"""Beam search implementation for batch.
|
69 |
+
|
70 |
+
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
|
71 |
+
:param namespace recog_args: argument namespace containing options
|
72 |
+
:param list char_list: list of characters
|
73 |
+
:param torch.nn.Module rnnlm: language model module
|
74 |
+
:return: N-best decoding results
|
75 |
+
:rtype: list
|
76 |
+
"""
|
77 |
+
raise NotImplementedError("Batch decoding is not supported yet.")
|
78 |
+
|
79 |
+
def calculate_all_attentions(self, xs, ilens, ys):
|
80 |
+
"""Caluculate attention.
|
81 |
+
|
82 |
+
:param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
|
83 |
+
:param ndarray ilens: batch of lengths of input sequences (B)
|
84 |
+
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
|
85 |
+
:return: attention weights (B, Lmax, Tmax)
|
86 |
+
:rtype: float ndarray
|
87 |
+
"""
|
88 |
+
raise NotImplementedError("calculate_all_attentions method is not implemented")
|
89 |
+
|
90 |
+
def calculate_all_ctc_probs(self, xs, ilens, ys):
|
91 |
+
"""Caluculate CTC probability.
|
92 |
+
|
93 |
+
:param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
|
94 |
+
:param ndarray ilens: batch of lengths of input sequences (B)
|
95 |
+
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
|
96 |
+
:return: CTC probabilities (B, Tmax, vocab)
|
97 |
+
:rtype: float ndarray
|
98 |
+
"""
|
99 |
+
raise NotImplementedError("calculate_all_ctc_probs method is not implemented")
|
100 |
+
|
101 |
+
@property
|
102 |
+
def attention_plot_class(self):
|
103 |
+
"""Get attention plot class."""
|
104 |
+
from espnet.asr.asr_utils import PlotAttentionReport
|
105 |
+
|
106 |
+
return PlotAttentionReport
|
107 |
+
|
108 |
+
@property
|
109 |
+
def ctc_plot_class(self):
|
110 |
+
"""Get CTC plot class."""
|
111 |
+
from espnet.asr.asr_utils import PlotCTCReport
|
112 |
+
|
113 |
+
return PlotCTCReport
|
114 |
+
|
115 |
+
def get_total_subsampling_factor(self):
|
116 |
+
"""Get total subsampling factor."""
|
117 |
+
raise NotImplementedError(
|
118 |
+
"get_total_subsampling_factor method is not implemented"
|
119 |
+
)
|
120 |
+
|
121 |
+
def encode(self, feat):
|
122 |
+
"""Encode feature in `beam_search` (optional).
|
123 |
+
|
124 |
+
Args:
|
125 |
+
feat (numpy.ndarray): input feature (T, D)
|
126 |
+
Returns:
|
127 |
+
torch.Tensor for pytorch, chainer.Variable for chainer:
|
128 |
+
encoded feature (T, D)
|
129 |
+
|
130 |
+
"""
|
131 |
+
raise NotImplementedError("encode method is not implemented")
|
132 |
+
|
133 |
+
def scorers(self):
|
134 |
+
"""Get scorers for `beam_search` (optional).
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
dict[str, ScorerInterface]: dict of `ScorerInterface` objects
|
138 |
+
|
139 |
+
"""
|
140 |
+
raise NotImplementedError("decoders method is not implemented")
|
141 |
+
|
142 |
+
|
143 |
+
predefined_asr = {
|
144 |
+
"pytorch": {
|
145 |
+
"rnn": "espnet.nets.pytorch_backend.e2e_asr:E2E",
|
146 |
+
"transducer": "espnet.nets.pytorch_backend.e2e_asr_transducer:E2E",
|
147 |
+
"transformer": "espnet.nets.pytorch_backend.e2e_asr_transformer:E2E",
|
148 |
+
"conformer": "espnet.nets.pytorch_backend.e2e_asr_conformer:E2E",
|
149 |
+
},
|
150 |
+
"chainer": {
|
151 |
+
"rnn": "espnet.nets.chainer_backend.e2e_asr:E2E",
|
152 |
+
"transformer": "espnet.nets.chainer_backend.e2e_asr_transformer:E2E",
|
153 |
+
},
|
154 |
+
}
|
155 |
+
|
156 |
+
|
157 |
+
def dynamic_import_asr(module, backend):
|
158 |
+
"""Import ASR models dynamically.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
module (str): module_name:class_name or alias in `predefined_asr`
|
162 |
+
backend (str): NN backend. e.g., pytorch, chainer
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
type: ASR class
|
166 |
+
|
167 |
+
"""
|
168 |
+
model_class = dynamic_import(module, predefined_asr.get(backend, dict()))
|
169 |
+
assert issubclass(
|
170 |
+
model_class, ASRInterface
|
171 |
+
), f"{module} does not implement ASRInterface"
|
172 |
+
return model_class
|
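# Usage sketch: aliases in predefined_asr resolve to module paths, so
# dynamic_import_asr("transformer", "pytorch") returns the E2E class from
# espnet.nets.pytorch_backend.e2e_asr_transformer, while a full "module:Class"
# string bypasses the alias table.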
espnet/nets/batch_beam_search.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
"""Parallel beam search module."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from typing import Any
|
5 |
+
from typing import Dict
|
6 |
+
from typing import List
|
7 |
+
from typing import NamedTuple
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.nn.utils.rnn import pad_sequence
|
12 |
+
|
13 |
+
from espnet.nets.beam_search import BeamSearch
|
14 |
+
from espnet.nets.beam_search import Hypothesis
|
15 |
+
|
16 |
+
|
17 |
+
class BatchHypothesis(NamedTuple):
|
18 |
+
"""Batchfied/Vectorized hypothesis data type."""
|
19 |
+
|
20 |
+
yseq: torch.Tensor = torch.tensor([]) # (batch, maxlen)
|
21 |
+
score: torch.Tensor = torch.tensor([]) # (batch,)
|
22 |
+
length: torch.Tensor = torch.tensor([]) # (batch,)
|
23 |
+
scores: Dict[str, torch.Tensor] = dict() # values: (batch,)
|
24 |
+
states: Dict[str, Dict] = dict()
|
25 |
+
|
26 |
+
def __len__(self) -> int:
|
27 |
+
"""Return a batch size."""
|
28 |
+
return len(self.length)
|
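# In this batched representation, yseq holds eos-padded token sequences and `length`
# records each hypothesis's true length, so individual hypotheses can be recovered
# exactly by unbatchfy/_select below.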
29 |
+
|
30 |
+
|
31 |
+
class BatchBeamSearch(BeamSearch):
|
32 |
+
"""Batch beam search implementation."""
|
33 |
+
|
34 |
+
def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
|
35 |
+
"""Convert list to batch."""
|
36 |
+
if len(hyps) == 0:
|
37 |
+
return BatchHypothesis()
|
38 |
+
return BatchHypothesis(
|
39 |
+
yseq=pad_sequence(
|
40 |
+
[h.yseq for h in hyps], batch_first=True, padding_value=self.eos
|
41 |
+
),
|
42 |
+
length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
|
43 |
+
score=torch.tensor([h.score for h in hyps]),
|
44 |
+
scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
|
45 |
+
states={k: [h.states[k] for h in hyps] for k in self.scorers},
|
46 |
+
)
|
47 |
+
|
48 |
+
def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
|
49 |
+
return BatchHypothesis(
|
50 |
+
yseq=hyps.yseq[ids],
|
51 |
+
score=hyps.score[ids],
|
52 |
+
length=hyps.length[ids],
|
53 |
+
scores={k: v[ids] for k, v in hyps.scores.items()},
|
54 |
+
states={
|
55 |
+
k: [self.scorers[k].select_state(v, i) for i in ids]
|
56 |
+
for k, v in hyps.states.items()
|
57 |
+
},
|
58 |
+
)
|
59 |
+
|
60 |
+
def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
|
61 |
+
return Hypothesis(
|
62 |
+
yseq=hyps.yseq[i, : hyps.length[i]],
|
63 |
+
score=hyps.score[i],
|
64 |
+
scores={k: v[i] for k, v in hyps.scores.items()},
|
65 |
+
states={
|
66 |
+
k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
|
67 |
+
},
|
68 |
+
)
|
69 |
+
|
70 |
+
def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
|
71 |
+
"""Revert batch to list."""
|
72 |
+
return [
|
73 |
+
Hypothesis(
|
74 |
+
yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
|
75 |
+
score=batch_hyps.score[i],
|
76 |
+
scores={k: batch_hyps.scores[k][i] for k in self.scorers},
|
77 |
+
states={
|
78 |
+
k: v.select_state(batch_hyps.states[k], i)
|
79 |
+
for k, v in self.scorers.items()
|
80 |
+
},
|
81 |
+
)
|
82 |
+
for i in range(len(batch_hyps.length))
|
83 |
+
]
|
84 |
+
|
85 |
+
def batch_beam(
|
86 |
+
self, weighted_scores: torch.Tensor, ids: torch.Tensor
|
87 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
88 |
+
"""Batch-compute topk full token ids and partial token ids.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
weighted_scores (torch.Tensor): The weighted sum scores for each token.
|
92 |
+
Its shape is `(n_beam, self.vocab_size)`.
|
93 |
+
ids (torch.Tensor): The partial token ids to compute topk.
|
94 |
+
Its shape is `(n_beam, self.pre_beam_size)`.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
98 |
+
The topk full (prev_hyp, new_token) ids
|
99 |
+
and partial (prev_hyp, new_token) ids.
|
100 |
+
Their shapes are all `(self.beam_size,)`
|
101 |
+
|
102 |
+
"""
|
103 |
+
top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
|
104 |
+
# Because of the flatten above, `top_ids` is organized as:
|
105 |
+
# [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
|
106 |
+
# where V is `self.n_vocab` and K is `self.beam_size`
|
107 |
+
prev_hyp_ids = top_ids // self.n_vocab
|
108 |
+
new_token_ids = top_ids % self.n_vocab
|
109 |
+
return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
|
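# Worked example: with n_vocab = 5, a flattened top-k index of 12 decodes to
# prev_hyp_id = 12 // 5 = 2 and new_token_id = 12 % 5 = 2, i.e. the third running
# hypothesis extended with token id 2.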
110 |
+
|
111 |
+
def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
|
112 |
+
"""Get an initial hypothesis data.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
x (torch.Tensor): The encoder output feature
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
Hypothesis: The initial hypothesis.
|
119 |
+
|
120 |
+
"""
|
121 |
+
init_states = dict()
|
122 |
+
init_scores = dict()
|
123 |
+
for k, d in self.scorers.items():
|
124 |
+
init_states[k] = d.batch_init_state(x)
|
125 |
+
init_scores[k] = 0.0
|
126 |
+
return self.batchfy(
|
127 |
+
[
|
128 |
+
Hypothesis(
|
129 |
+
score=0.0,
|
130 |
+
scores=init_scores,
|
131 |
+
states=init_states,
|
132 |
+
yseq=torch.tensor([self.sos], device=x.device),
|
133 |
+
)
|
134 |
+
]
|
135 |
+
)
|
136 |
+
|
137 |
+
def score_full(
|
138 |
+
self, hyp: BatchHypothesis, x: torch.Tensor
|
139 |
+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
140 |
+
"""Score new hypothesis by `self.full_scorers`.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
144 |
+
x (torch.Tensor): Corresponding input feature
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
148 |
+
score dict of `hyp` that has string keys of `self.full_scorers`
|
149 |
+
and tensor score values of shape: `(self.n_vocab,)`,
|
150 |
+
and state dict that has string keys
|
151 |
+
and state values of `self.full_scorers`
|
152 |
+
|
153 |
+
"""
|
154 |
+
scores = dict()
|
155 |
+
states = dict()
|
156 |
+
for k, d in self.full_scorers.items():
|
157 |
+
scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
|
158 |
+
return scores, states
|
159 |
+
|
160 |
+
def score_partial(
|
161 |
+
self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
|
162 |
+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
163 |
+
"""Score new hypothesis by `self.full_scorers`.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
167 |
+
ids (torch.Tensor): 2D tensor of new partial tokens to score
|
168 |
+
x (torch.Tensor): Corresponding input feature
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
172 |
+
score dict of `hyp` that has string keys of `self.part_scorers`
|
173 |
+
and tensor score values of shape: `(self.n_vocab,)`,
|
174 |
+
and state dict that has string keys
|
175 |
+
and state values of `self.part_scorers`
|
176 |
+
|
177 |
+
"""
|
178 |
+
scores = dict()
|
179 |
+
states = dict()
|
180 |
+
for k, d in self.part_scorers.items():
|
181 |
+
scores[k], states[k] = d.batch_score_partial(
|
182 |
+
hyp.yseq, ids, hyp.states[k], x
|
183 |
+
)
|
184 |
+
return scores, states
|
185 |
+
|
186 |
+
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
|
187 |
+
"""Merge states for new hypothesis.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
states: states of `self.full_scorers`
|
191 |
+
part_states: states of `self.part_scorers`
|
192 |
+
part_idx (int): The new token id for `part_scores`
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
Dict[str, torch.Tensor]: The new score dict.
|
196 |
+
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
197 |
+
Its values are states of the scorers.
|
198 |
+
|
199 |
+
"""
|
200 |
+
new_states = dict()
|
201 |
+
for k, v in states.items():
|
202 |
+
new_states[k] = v
|
203 |
+
for k, v in part_states.items():
|
204 |
+
new_states[k] = v
|
205 |
+
return new_states
|
206 |
+
|
207 |
+
def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
|
208 |
+
"""Search new tokens for running hypotheses and encoded speech x.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
running_hyps (BatchHypothesis): Running hypotheses on beam
|
212 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
BatchHypothesis: Best sorted hypotheses
|
216 |
+
|
217 |
+
"""
|
218 |
+
n_batch = len(running_hyps)
|
219 |
+
part_ids = None # no pre-beam
|
220 |
+
# batch scoring
|
221 |
+
weighted_scores = torch.zeros(
|
222 |
+
n_batch, self.n_vocab, dtype=x.dtype, device=x.device
|
223 |
+
)
|
224 |
+
scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
|
225 |
+
for k in self.full_scorers:
|
226 |
+
weighted_scores += self.weights[k] * scores[k]
|
227 |
+
# partial scoring
|
228 |
+
if self.do_pre_beam:
|
229 |
+
pre_beam_scores = (
|
230 |
+
weighted_scores
|
231 |
+
if self.pre_beam_score_key == "full"
|
232 |
+
else scores[self.pre_beam_score_key]
|
233 |
+
)
|
234 |
+
part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
|
235 |
+
# NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
|
236 |
+
# full-size score matrices, which has non-zero scores for part_ids and zeros
|
237 |
+
# for others.
|
238 |
+
part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
|
239 |
+
for k in self.part_scorers:
|
240 |
+
weighted_scores += self.weights[k] * part_scores[k]
|
241 |
+
# add previous hyp scores
|
242 |
+
weighted_scores += running_hyps.score.to(
|
243 |
+
dtype=x.dtype, device=x.device
|
244 |
+
).unsqueeze(1)
|
245 |
+
|
246 |
+
# TODO(karita): do not use list. use batch instead
|
247 |
+
# see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
|
248 |
+
# update hyps
|
249 |
+
best_hyps = []
|
250 |
+
prev_hyps = self.unbatchfy(running_hyps)
|
251 |
+
for (
|
252 |
+
full_prev_hyp_id,
|
253 |
+
full_new_token_id,
|
254 |
+
part_prev_hyp_id,
|
255 |
+
part_new_token_id,
|
256 |
+
) in zip(*self.batch_beam(weighted_scores, part_ids)):
|
257 |
+
prev_hyp = prev_hyps[full_prev_hyp_id]
|
258 |
+
best_hyps.append(
|
259 |
+
Hypothesis(
|
260 |
+
score=weighted_scores[full_prev_hyp_id, full_new_token_id],
|
261 |
+
yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
|
262 |
+
scores=self.merge_scores(
|
263 |
+
prev_hyp.scores,
|
264 |
+
{k: v[full_prev_hyp_id] for k, v in scores.items()},
|
265 |
+
full_new_token_id,
|
266 |
+
{k: v[part_prev_hyp_id] for k, v in part_scores.items()},
|
267 |
+
part_new_token_id,
|
268 |
+
),
|
269 |
+
states=self.merge_states(
|
270 |
+
{
|
271 |
+
k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
|
272 |
+
for k, v in states.items()
|
273 |
+
},
|
274 |
+
{
|
275 |
+
k: self.part_scorers[k].select_state(
|
276 |
+
v, part_prev_hyp_id, part_new_token_id
|
277 |
+
)
|
278 |
+
for k, v in part_states.items()
|
279 |
+
},
|
280 |
+
part_new_token_id,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
)
|
284 |
+
return self.batchfy(best_hyps)
|
285 |
+
|
286 |
+
def post_process(
|
287 |
+
self,
|
288 |
+
i: int,
|
289 |
+
maxlen: int,
|
290 |
+
maxlenratio: float,
|
291 |
+
running_hyps: BatchHypothesis,
|
292 |
+
ended_hyps: List[Hypothesis],
|
293 |
+
) -> BatchHypothesis:
|
294 |
+
"""Perform post-processing of beam search iterations.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
i (int): The length of hypothesis tokens.
|
298 |
+
maxlen (int): The maximum length of tokens in beam search.
|
299 |
+
maxlenratio (int): The maximum length ratio in beam search.
|
300 |
+
running_hyps (BatchHypothesis): The running hypotheses in beam search.
|
301 |
+
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
BatchHypothesis: The new running hypotheses.
|
305 |
+
|
306 |
+
"""
|
307 |
+
n_batch = running_hyps.yseq.shape[0]
|
308 |
+
logging.debug(f"the number of running hypothes: {n_batch}")
|
309 |
+
if self.token_list is not None:
|
310 |
+
logging.debug(
|
311 |
+
"best hypo: "
|
312 |
+
+ "".join(
|
313 |
+
[
|
314 |
+
self.token_list[x]
|
315 |
+
for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
|
316 |
+
]
|
317 |
+
)
|
318 |
+
)
|
319 |
+
# add eos in the final loop to avoid that there are no ended hyps
|
320 |
+
if i == maxlen - 1:
|
321 |
+
logging.info("adding <eos> in the last position in the loop")
|
322 |
+
yseq_eos = torch.cat(
|
323 |
+
(
|
324 |
+
running_hyps.yseq,
|
325 |
+
torch.full(
|
326 |
+
(n_batch, 1),
|
327 |
+
self.eos,
|
328 |
+
device=running_hyps.yseq.device,
|
329 |
+
dtype=torch.int64,
|
330 |
+
),
|
331 |
+
),
|
332 |
+
1,
|
333 |
+
)
|
334 |
+
running_hyps.yseq.resize_as_(yseq_eos)
|
335 |
+
running_hyps.yseq[:] = yseq_eos
|
336 |
+
running_hyps.length[:] = yseq_eos.shape[1]
|
337 |
+
|
338 |
+
# add ended hypotheses to a final list, and remove them from the current hypotheses
|
339 |
+
# (this can be a problem: the number of hyps becomes smaller than the beam size)
|
340 |
+
is_eos = (
|
341 |
+
running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
|
342 |
+
== self.eos
|
343 |
+
)
|
344 |
+
for b in torch.nonzero(is_eos).view(-1):
|
345 |
+
hyp = self._select(running_hyps, b)
|
346 |
+
ended_hyps.append(hyp)
|
347 |
+
remained_ids = torch.nonzero(is_eos == 0).view(-1)
|
348 |
+
return self._batch_select(running_hyps, remained_ids)
|
espnet/nets/batch_beam_search_online_sim.py
ADDED
@@ -0,0 +1,270 @@
1 |
+
"""Parallel beam search module for online simulation."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import yaml
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from espnet.nets.batch_beam_search import BatchBeamSearch
|
12 |
+
from espnet.nets.beam_search import Hypothesis
|
13 |
+
from espnet.nets.e2e_asr_common import end_detect
|
14 |
+
|
15 |
+
|
16 |
+
class BatchBeamSearchOnlineSim(BatchBeamSearch):
|
17 |
+
"""Online beam search implementation.
|
18 |
+
|
19 |
+
This simulates streaming decoding.
|
20 |
+
It requires the encoded features of the entire utterance and
|
21 |
+
extracts it block by block, as would be done
|
22 |
+
in streaming processing.
|
23 |
+
This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
|
24 |
+
WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
|
25 |
+
(https://arxiv.org/abs/2006.14941).
|
26 |
+
"""
|
27 |
+
|
28 |
+
def set_streaming_config(self, asr_config: str):
|
29 |
+
"""Set config file for streaming decoding.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
asr_config (str): The config file for asr training
|
33 |
+
|
34 |
+
"""
|
35 |
+
train_config_file = Path(asr_config)
|
36 |
+
self.block_size = None
|
37 |
+
self.hop_size = None
|
38 |
+
self.look_ahead = None
|
39 |
+
config = None
|
40 |
+
with train_config_file.open("r", encoding="utf-8") as f:
|
41 |
+
args = yaml.safe_load(f)
|
42 |
+
if "encoder_conf" in args.keys():
|
43 |
+
if "block_size" in args["encoder_conf"].keys():
|
44 |
+
self.block_size = args["encoder_conf"]["block_size"]
|
45 |
+
if "hop_size" in args["encoder_conf"].keys():
|
46 |
+
self.hop_size = args["encoder_conf"]["hop_size"]
|
47 |
+
if "look_ahead" in args["encoder_conf"].keys():
|
48 |
+
self.look_ahead = args["encoder_conf"]["look_ahead"]
|
49 |
+
elif "config" in args.keys():
|
50 |
+
config = args["config"]
|
51 |
+
if config is None:
|
52 |
+
logging.info(
|
53 |
+
"Cannot find config file for streaming decoding: "
|
54 |
+
+ "apply batch beam search instead."
|
55 |
+
)
|
56 |
+
return
|
57 |
+
if (
|
58 |
+
self.block_size is None or self.hop_size is None or self.look_ahead is None
|
59 |
+
) and config is not None:
|
60 |
+
config_file = Path(config)
|
61 |
+
with config_file.open("r", encoding="utf-8") as f:
|
62 |
+
args = yaml.safe_load(f)
|
63 |
+
if "encoder_conf" in args.keys():
|
64 |
+
enc_args = args["encoder_conf"]
|
65 |
+
if enc_args and "block_size" in enc_args:
|
66 |
+
self.block_size = enc_args["block_size"]
|
67 |
+
if enc_args and "hop_size" in enc_args:
|
68 |
+
self.hop_size = enc_args["hop_size"]
|
69 |
+
if enc_args and "look_ahead" in enc_args:
|
70 |
+
self.look_ahead = enc_args["look_ahead"]
|
71 |
+
|
72 |
+
def set_block_size(self, block_size: int):
|
73 |
+
"""Set block size for streaming decoding.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
block_size (int): The block size of encoder
|
77 |
+
"""
|
78 |
+
self.block_size = block_size
|
79 |
+
|
80 |
+
def set_hop_size(self, hop_size: int):
|
81 |
+
"""Set hop size for streaming decoding.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
hop_size (int): The hop size of encoder
|
85 |
+
"""
|
86 |
+
self.hop_size = hop_size
|
87 |
+
|
88 |
+
def set_look_ahead(self, look_ahead: int):
|
89 |
+
"""Set look ahead size for streaming decoding.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
look_ahead (int): The look ahead size of encoder
|
93 |
+
"""
|
94 |
+
self.look_ahead = look_ahead
|
95 |
+
|
96 |
+
def forward(
|
97 |
+
self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
|
98 |
+
) -> List[Hypothesis]:
|
99 |
+
"""Perform beam search.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
103 |
+
maxlenratio (float): Input length ratio to obtain max output length.
|
104 |
+
If maxlenratio=0.0 (default), it uses a end-detect function
|
105 |
+
to automatically find maximum hypothesis lengths
|
106 |
+
minlenratio (float): Input length ratio to obtain min output length.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
list[Hypothesis]: N-best decoding results
|
110 |
+
|
111 |
+
"""
|
112 |
+
self.conservative = True # always true
|
113 |
+
|
114 |
+
if self.block_size and self.hop_size and self.look_ahead:
|
115 |
+
cur_end_frame = int(self.block_size - self.look_ahead)
|
116 |
+
else:
|
117 |
+
cur_end_frame = x.shape[0]
|
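# Example of the block arithmetic: with block_size=40, hop_size=16 and look_ahead=16,
# decoding starts on frames [0, 24) and each move to the next block extends the
# visible encoder output by hop_size frames until the whole utterance is covered.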
118 |
+
process_idx = 0
|
119 |
+
if cur_end_frame < x.shape[0]:
|
120 |
+
h = x.narrow(0, 0, cur_end_frame)
|
121 |
+
else:
|
122 |
+
h = x
|
123 |
+
|
124 |
+
# set length bounds
|
125 |
+
if maxlenratio == 0:
|
126 |
+
maxlen = x.shape[0]
|
127 |
+
else:
|
128 |
+
maxlen = max(1, int(maxlenratio * x.size(0)))
|
129 |
+
minlen = int(minlenratio * x.size(0))
|
130 |
+
logging.info("decoder input length: " + str(x.shape[0]))
|
131 |
+
logging.info("max output length: " + str(maxlen))
|
132 |
+
logging.info("min output length: " + str(minlen))
|
133 |
+
|
134 |
+
# main loop of prefix search
|
135 |
+
running_hyps = self.init_hyp(h)
|
136 |
+
prev_hyps = []
|
137 |
+
ended_hyps = []
|
138 |
+
prev_repeat = False
|
139 |
+
|
140 |
+
continue_decode = True
|
141 |
+
|
142 |
+
while continue_decode:
|
143 |
+
move_to_next_block = False
|
144 |
+
if cur_end_frame < x.shape[0]:
|
145 |
+
h = x.narrow(0, 0, cur_end_frame)
|
146 |
+
else:
|
147 |
+
h = x
|
148 |
+
|
149 |
+
# extend states for ctc
|
150 |
+
self.extend(h, running_hyps)
|
151 |
+
|
152 |
+
while process_idx < maxlen:
|
153 |
+
logging.debug("position " + str(process_idx))
|
154 |
+
best = self.search(running_hyps, h)
|
155 |
+
|
156 |
+
if process_idx == maxlen - 1:
|
157 |
+
# end decoding
|
158 |
+
running_hyps = self.post_process(
|
159 |
+
process_idx, maxlen, maxlenratio, best, ended_hyps
|
160 |
+
)
|
161 |
+
n_batch = best.yseq.shape[0]
|
162 |
+
local_ended_hyps = []
|
163 |
+
is_local_eos = (
|
164 |
+
best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
|
165 |
+
)
|
166 |
+
for i in range(is_local_eos.shape[0]):
|
167 |
+
if is_local_eos[i]:
|
168 |
+
hyp = self._select(best, i)
|
169 |
+
local_ended_hyps.append(hyp)
|
170 |
+
# NOTE(tsunoo): check repetitions here
|
171 |
+
# This is an implicit implementation of
|
172 |
+
# Eq (11) in https://arxiv.org/abs/2006.14941
|
173 |
+
# A flag prev_repeat is used instead of using set
|
174 |
+
elif (
|
175 |
+
not prev_repeat
|
176 |
+
and best.yseq[i, -1] in best.yseq[i, :-1]
|
177 |
+
and cur_end_frame < x.shape[0]
|
178 |
+
):
|
179 |
+
move_to_next_block = True
|
180 |
+
prev_repeat = True
|
181 |
+
if maxlenratio == 0.0 and end_detect(
|
182 |
+
[lh.asdict() for lh in local_ended_hyps], process_idx
|
183 |
+
):
|
184 |
+
logging.info(f"end detected at {process_idx}")
|
185 |
+
continue_decode = False
|
186 |
+
break
|
187 |
+
if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
|
188 |
+
move_to_next_block = True
|
189 |
+
|
190 |
+
if move_to_next_block:
|
191 |
+
if (
|
192 |
+
self.hop_size
|
193 |
+
and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
|
194 |
+
< x.shape[0]
|
195 |
+
):
|
196 |
+
cur_end_frame += int(self.hop_size)
|
197 |
+
else:
|
198 |
+
cur_end_frame = x.shape[0]
|
199 |
+
logging.debug("Going to next block: %d", cur_end_frame)
|
200 |
+
if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
|
201 |
+
running_hyps = prev_hyps
|
202 |
+
process_idx -= 1
|
203 |
+
prev_hyps = []
|
204 |
+
break
|
205 |
+
|
206 |
+
prev_repeat = False
|
207 |
+
prev_hyps = running_hyps
|
208 |
+
running_hyps = self.post_process(
|
209 |
+
process_idx, maxlen, maxlenratio, best, ended_hyps
|
210 |
+
)
|
211 |
+
|
212 |
+
if cur_end_frame >= x.shape[0]:
|
213 |
+
for hyp in local_ended_hyps:
|
214 |
+
ended_hyps.append(hyp)
|
215 |
+
|
216 |
+
if len(running_hyps) == 0:
|
217 |
+
logging.info("no hypothesis. Finish decoding.")
|
218 |
+
continue_decode = False
|
219 |
+
break
|
220 |
+
else:
|
221 |
+
logging.debug(f"remained hypotheses: {len(running_hyps)}")
|
222 |
+
# increment number
|
223 |
+
process_idx += 1
|
224 |
+
|
225 |
+
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
|
226 |
+
# check the number of hypotheses reaching to eos
|
227 |
+
if len(nbest_hyps) == 0:
|
228 |
+
logging.warning(
|
229 |
+
"there is no N-best results, perform recognition "
|
230 |
+
"again with smaller minlenratio."
|
231 |
+
)
|
232 |
+
return (
|
233 |
+
[]
|
234 |
+
if minlenratio < 0.1
|
235 |
+
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
|
236 |
+
)
|
237 |
+
|
238 |
+
# report the best result
|
239 |
+
best = nbest_hyps[0]
|
240 |
+
for k, v in best.scores.items():
|
241 |
+
logging.info(
|
242 |
+
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
|
243 |
+
)
|
244 |
+
logging.info(f"total log probability: {best.score:.2f}")
|
245 |
+
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
|
246 |
+
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
|
247 |
+
if self.token_list is not None:
|
248 |
+
logging.info(
|
249 |
+
"best hypo: "
|
250 |
+
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
|
251 |
+
+ "\n"
|
252 |
+
)
|
253 |
+
return nbest_hyps
|
254 |
+
|
255 |
+
def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
|
256 |
+
"""Extend probabilities and states with more encoded chunks.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
x (torch.Tensor): The extended encoder output feature
|
260 |
+
hyps (Hypothesis): Current list of hypothesis
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
Hypothesis: The extended hypothesis
|
264 |
+
|
265 |
+
"""
|
266 |
+
for k, d in self.scorers.items():
|
267 |
+
if hasattr(d, "extend_prob"):
|
268 |
+
d.extend_prob(x)
|
269 |
+
if hasattr(d, "extend_state"):
|
270 |
+
hyps.states[k] = d.extend_state(hyps.states[k])
|
espnet/nets/beam_search.py
ADDED
@@ -0,0 +1,512 @@
1 |
+
"""Beam search module."""
|
2 |
+
|
3 |
+
from itertools import chain
|
4 |
+
import logging
|
5 |
+
from typing import Any
|
6 |
+
from typing import Dict
|
7 |
+
from typing import List
|
8 |
+
from typing import NamedTuple
|
9 |
+
from typing import Tuple
|
10 |
+
from typing import Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from espnet.nets.e2e_asr_common import end_detect
|
15 |
+
from espnet.nets.scorer_interface import PartialScorerInterface
|
16 |
+
from espnet.nets.scorer_interface import ScorerInterface
|
17 |
+
|
18 |
+
|
19 |
+
class Hypothesis(NamedTuple):
|
20 |
+
"""Hypothesis data type."""
|
21 |
+
|
22 |
+
yseq: torch.Tensor
|
23 |
+
score: Union[float, torch.Tensor] = 0
|
24 |
+
scores: Dict[str, Union[float, torch.Tensor]] = dict()
|
25 |
+
states: Dict[str, Any] = dict()
|
26 |
+
|
27 |
+
def asdict(self) -> dict:
|
28 |
+
"""Convert data to JSON-friendly dict."""
|
29 |
+
return self._replace(
|
30 |
+
yseq=self.yseq.tolist(),
|
31 |
+
score=float(self.score),
|
32 |
+
scores={k: float(v) for k, v in self.scores.items()},
|
33 |
+
)._asdict()
|
34 |
+
|
35 |
+
|
36 |
+
class BeamSearch(torch.nn.Module):
|
37 |
+
"""Beam search implementation."""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
scorers: Dict[str, ScorerInterface],
|
42 |
+
weights: Dict[str, float],
|
43 |
+
beam_size: int,
|
44 |
+
vocab_size: int,
|
45 |
+
sos: int,
|
46 |
+
eos: int,
|
47 |
+
token_list: List[str] = None,
|
48 |
+
pre_beam_ratio: float = 1.5,
|
49 |
+
pre_beam_score_key: str = None,
|
50 |
+
):
|
51 |
+
"""Initialize beam search.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
scorers (dict[str, ScorerInterface]): Dict of decoder modules
|
55 |
+
e.g., Decoder, CTCPrefixScorer, LM
|
56 |
+
The scorer will be ignored if it is `None`
|
57 |
+
weights (dict[str, float]): Dict of weights for each scorers
|
58 |
+
The scorer will be ignored if its weight is 0
|
59 |
+
beam_size (int): The number of hypotheses kept during search
|
60 |
+
vocab_size (int): The number of vocabulary
|
61 |
+
sos (int): Start of sequence id
|
62 |
+
eos (int): End of sequence id
|
63 |
+
token_list (list[str]): List of tokens for debug log
|
64 |
+
pre_beam_score_key (str): key of scores to perform pre-beam search
|
65 |
+
pre_beam_ratio (float): beam size in the pre-beam search
|
66 |
+
will be `int(pre_beam_ratio * beam_size)`
|
67 |
+
|
68 |
+
"""
|
69 |
+
super().__init__()
|
70 |
+
# set scorers
|
71 |
+
self.weights = weights
|
72 |
+
self.scorers = dict()
|
73 |
+
self.full_scorers = dict()
|
74 |
+
self.part_scorers = dict()
|
75 |
+
# this module dict is required for recursive cast
|
76 |
+
# `self.to(device, dtype)` in `recog.py`
|
77 |
+
self.nn_dict = torch.nn.ModuleDict()
|
78 |
+
for k, v in scorers.items():
|
79 |
+
w = weights.get(k, 0)
|
80 |
+
if w == 0 or v is None:
|
81 |
+
continue
|
82 |
+
assert isinstance(
|
83 |
+
v, ScorerInterface
|
84 |
+
), f"{k} ({type(v)}) does not implement ScorerInterface"
|
85 |
+
self.scorers[k] = v
|
86 |
+
if isinstance(v, PartialScorerInterface):
|
87 |
+
self.part_scorers[k] = v
|
88 |
+
else:
|
89 |
+
self.full_scorers[k] = v
|
90 |
+
if isinstance(v, torch.nn.Module):
|
91 |
+
self.nn_dict[k] = v
|
92 |
+
|
93 |
+
# set configurations
|
94 |
+
self.sos = sos
|
95 |
+
self.eos = eos
|
96 |
+
self.token_list = token_list
|
97 |
+
self.pre_beam_size = int(pre_beam_ratio * beam_size)
|
98 |
+
self.beam_size = beam_size
|
99 |
+
self.n_vocab = vocab_size
|
100 |
+
if (
|
101 |
+
pre_beam_score_key is not None
|
102 |
+
and pre_beam_score_key != "full"
|
103 |
+
and pre_beam_score_key not in self.full_scorers
|
104 |
+
):
|
105 |
+
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
|
106 |
+
self.pre_beam_score_key = pre_beam_score_key
|
107 |
+
self.do_pre_beam = (
|
108 |
+
self.pre_beam_score_key is not None
|
109 |
+
and self.pre_beam_size < self.n_vocab
|
110 |
+
and len(self.part_scorers) > 0
|
111 |
+
)
|
112 |
+
|
113 |
+
def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
|
114 |
+
"""Get an initial hypothesis data.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
x (torch.Tensor): The encoder output feature
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
Hypothesis: The initial hypothesis.
|
121 |
+
|
122 |
+
"""
|
123 |
+
init_states = dict()
|
124 |
+
init_scores = dict()
|
125 |
+
for k, d in self.scorers.items():
|
126 |
+
init_states[k] = d.init_state(x)
|
127 |
+
init_scores[k] = 0.0
|
128 |
+
return [
|
129 |
+
Hypothesis(
|
130 |
+
score=0.0,
|
131 |
+
scores=init_scores,
|
132 |
+
states=init_states,
|
133 |
+
yseq=torch.tensor([self.sos], device=x.device),
|
134 |
+
)
|
135 |
+
]
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
|
139 |
+
"""Append new token to prefix tokens.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
xs (torch.Tensor): The prefix token
|
143 |
+
x (int): The new token to append
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
|
147 |
+
|
148 |
+
"""
|
149 |
+
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
|
150 |
+
return torch.cat((xs, x))
|
151 |
+
|
152 |
+
def score_full(
|
153 |
+
self, hyp: Hypothesis, x: torch.Tensor
|
154 |
+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
155 |
+
"""Score new hypothesis by `self.full_scorers`.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
159 |
+
x (torch.Tensor): Corresponding input feature
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
163 |
+
score dict of `hyp` that has string keys of `self.full_scorers`
|
164 |
+
and tensor score values of shape: `(self.n_vocab,)`,
|
165 |
+
and state dict that has string keys
|
166 |
+
and state values of `self.full_scorers`
|
167 |
+
|
168 |
+
"""
|
169 |
+
scores = dict()
|
170 |
+
states = dict()
|
171 |
+
for k, d in self.full_scorers.items():
|
172 |
+
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
|
173 |
+
return scores, states
|
174 |
+
|
175 |
+
def score_partial(
|
176 |
+
self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
|
177 |
+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
178 |
+
"""Score new hypothesis by `self.part_scorers`.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
182 |
+
ids (torch.Tensor): 1D tensor of new partial tokens to score
|
183 |
+
x (torch.Tensor): Corresponding input feature
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
187 |
+
score dict of `hyp` that has string keys of `self.part_scorers`
|
188 |
+
and tensor score values of shape: `(len(ids),)`,
|
189 |
+
and state dict that has string keys
|
190 |
+
and state values of `self.part_scorers`
|
191 |
+
|
192 |
+
"""
|
193 |
+
scores = dict()
|
194 |
+
states = dict()
|
195 |
+
for k, d in self.part_scorers.items():
|
196 |
+
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
|
197 |
+
return scores, states
|
198 |
+
|
199 |
+
def beam(
|
200 |
+
self, weighted_scores: torch.Tensor, ids: torch.Tensor
|
201 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
202 |
+
"""Compute topk full token ids and partial token ids.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
weighted_scores (torch.Tensor): The weighted sum scores for each token.
|
206 |
+
Its shape is `(self.n_vocab,)`.
|
207 |
+
ids (torch.Tensor): The partial token ids to compute topk
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
Tuple[torch.Tensor, torch.Tensor]:
|
211 |
+
The topk full token ids and partial token ids.
|
212 |
+
Their shapes are `(self.beam_size,)`
|
213 |
+
|
214 |
+
"""
|
215 |
+
# no pre beam performed
|
216 |
+
if weighted_scores.size(0) == ids.size(0):
|
217 |
+
top_ids = weighted_scores.topk(self.beam_size)[1]
|
218 |
+
return top_ids, top_ids
|
219 |
+
|
220 |
+
# mask pruned in pre-beam not to select in topk
|
221 |
+
tmp = weighted_scores[ids]
|
222 |
+
weighted_scores[:] = -float("inf")
|
223 |
+
weighted_scores[ids] = tmp
|
224 |
+
top_ids = weighted_scores.topk(self.beam_size)[1]
|
225 |
+
local_ids = weighted_scores[ids].topk(self.beam_size)[1]
|
226 |
+
return top_ids, local_ids
|
227 |
+
|
228 |
+
@staticmethod
|
229 |
+
def merge_scores(
|
230 |
+
prev_scores: Dict[str, float],
|
231 |
+
next_full_scores: Dict[str, torch.Tensor],
|
232 |
+
full_idx: int,
|
233 |
+
next_part_scores: Dict[str, torch.Tensor],
|
234 |
+
part_idx: int,
|
235 |
+
) -> Dict[str, torch.Tensor]:
|
236 |
+
"""Merge scores for new hypothesis.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
prev_scores (Dict[str, float]):
|
240 |
+
The previous hypothesis scores by `self.scorers`
|
241 |
+
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
|
242 |
+
full_idx (int): The next token id for `next_full_scores`
|
243 |
+
next_part_scores (Dict[str, torch.Tensor]):
|
244 |
+
scores of partial tokens by `self.part_scorers`
|
245 |
+
part_idx (int): The new token id for `next_part_scores`
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Dict[str, torch.Tensor]: The new score dict.
|
249 |
+
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
250 |
+
Its values are scalar tensors by the scorers.
|
251 |
+
|
252 |
+
"""
|
253 |
+
new_scores = dict()
|
254 |
+
for k, v in next_full_scores.items():
|
255 |
+
new_scores[k] = prev_scores[k] + v[full_idx]
|
256 |
+
for k, v in next_part_scores.items():
|
257 |
+
new_scores[k] = prev_scores[k] + v[part_idx]
|
258 |
+
return new_scores
|
259 |
+
|
260 |
+
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
|
261 |
+
"""Merge states for new hypothesis.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
states: states of `self.full_scorers`
|
265 |
+
part_states: states of `self.part_scorers`
|
266 |
+
part_idx (int): The new token id for `part_scores`
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
Dict[str, Any]: The new state dict.
|
270 |
+
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
271 |
+
Its values are states of the scorers.
|
272 |
+
|
273 |
+
"""
|
274 |
+
new_states = dict()
|
275 |
+
for k, v in states.items():
|
276 |
+
new_states[k] = v
|
277 |
+
for k, d in self.part_scorers.items():
|
278 |
+
new_states[k] = d.select_state(part_states[k], part_idx)
|
279 |
+
return new_states
|
280 |
+
|
281 |
+
def search(
|
282 |
+
self, running_hyps: List[Hypothesis], x: torch.Tensor
|
283 |
+
) -> List[Hypothesis]:
|
284 |
+
"""Search new tokens for running hypotheses and encoded speech x.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
running_hyps (List[Hypothesis]): Running hypotheses on beam
|
288 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
List[Hypothesis]: Best sorted hypotheses
|
292 |
+
|
293 |
+
"""
|
294 |
+
best_hyps = []
|
295 |
+
part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
|
296 |
+
for hyp in running_hyps:
|
297 |
+
# scoring
|
298 |
+
weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
|
299 |
+
scores, states = self.score_full(hyp, x)
|
300 |
+
for k in self.full_scorers:
|
301 |
+
weighted_scores += self.weights[k] * scores[k]
|
302 |
+
# partial scoring
|
303 |
+
if self.do_pre_beam:
|
304 |
+
pre_beam_scores = (
|
305 |
+
weighted_scores
|
306 |
+
if self.pre_beam_score_key == "full"
|
307 |
+
else scores[self.pre_beam_score_key]
|
308 |
+
)
|
309 |
+
part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
|
310 |
+
part_scores, part_states = self.score_partial(hyp, part_ids, x)
|
311 |
+
for k in self.part_scorers:
|
312 |
+
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
|
313 |
+
# add previous hyp score
|
314 |
+
weighted_scores += hyp.score
|
315 |
+
|
316 |
+
# update hyps
|
317 |
+
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
|
318 |
+
# will be (2 x beam at most)
|
319 |
+
best_hyps.append(
|
320 |
+
Hypothesis(
|
321 |
+
score=weighted_scores[j],
|
322 |
+
yseq=self.append_token(hyp.yseq, j),
|
323 |
+
scores=self.merge_scores(
|
324 |
+
hyp.scores, scores, j, part_scores, part_j
|
325 |
+
),
|
326 |
+
states=self.merge_states(states, part_states, part_j),
|
327 |
+
)
|
328 |
+
)
|
329 |
+
|
330 |
+
# sort and prune 2 x beam -> beam
|
331 |
+
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
|
332 |
+
: min(len(best_hyps), self.beam_size)
|
333 |
+
]
|
334 |
+
return best_hyps
|
335 |
+
|
336 |
+
def forward(
|
337 |
+
self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
|
338 |
+
) -> List[Hypothesis]:
|
339 |
+
"""Perform beam search.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
343 |
+
maxlenratio (float): Input length ratio to obtain max output length.
|
344 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
345 |
+
to automatically find maximum hypothesis lengths
|
346 |
+
minlenratio (float): Input length ratio to obtain min output length.
|
347 |
+
|
348 |
+
Returns:
|
349 |
+
list[Hypothesis]: N-best decoding results
|
350 |
+
|
351 |
+
"""
|
352 |
+
# set length bounds
|
353 |
+
if maxlenratio == 0:
|
354 |
+
maxlen = x.shape[0]
|
355 |
+
else:
|
356 |
+
maxlen = max(1, int(maxlenratio * x.size(0)))
|
357 |
+
minlen = int(minlenratio * x.size(0))
|
358 |
+
logging.info("decoder input length: " + str(x.shape[0]))
|
359 |
+
logging.info("max output length: " + str(maxlen))
|
360 |
+
logging.info("min output length: " + str(minlen))
|
361 |
+
|
362 |
+
# main loop of prefix search
|
363 |
+
running_hyps = self.init_hyp(x)
|
364 |
+
ended_hyps = []
|
365 |
+
for i in range(maxlen):
|
366 |
+
logging.debug("position " + str(i))
|
367 |
+
best = self.search(running_hyps, x)
|
368 |
+
# post process of one iteration
|
369 |
+
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
|
370 |
+
# end detection
|
371 |
+
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
|
372 |
+
logging.info(f"end detected at {i}")
|
373 |
+
break
|
374 |
+
if len(running_hyps) == 0:
|
375 |
+
logging.info("no hypothesis. Finish decoding.")
|
376 |
+
break
|
377 |
+
else:
|
378 |
+
logging.debug(f"remained hypotheses: {len(running_hyps)}")
|
379 |
+
|
380 |
+
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
|
381 |
+
# check the number of hypotheses reaching to eos
|
382 |
+
if len(nbest_hyps) == 0:
|
383 |
+
logging.warning(
|
384 |
+
"there is no N-best results, perform recognition "
|
385 |
+
"again with smaller minlenratio."
|
386 |
+
)
|
387 |
+
return (
|
388 |
+
[]
|
389 |
+
if minlenratio < 0.1
|
390 |
+
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
|
391 |
+
)
|
392 |
+
|
393 |
+
# report the best result
|
394 |
+
best = nbest_hyps[0]
|
395 |
+
for k, v in best.scores.items():
|
396 |
+
logging.info(
|
397 |
+
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
|
398 |
+
)
|
399 |
+
logging.info(f"total log probability: {best.score:.2f}")
|
400 |
+
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
|
401 |
+
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
|
402 |
+
if self.token_list is not None:
|
403 |
+
logging.info(
|
404 |
+
"best hypo: "
|
405 |
+
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
|
406 |
+
+ "\n"
|
407 |
+
)
|
408 |
+
return nbest_hyps
|
409 |
+
|
410 |
+
def post_process(
|
411 |
+
self,
|
412 |
+
i: int,
|
413 |
+
maxlen: int,
|
414 |
+
maxlenratio: float,
|
415 |
+
running_hyps: List[Hypothesis],
|
416 |
+
ended_hyps: List[Hypothesis],
|
417 |
+
) -> List[Hypothesis]:
|
418 |
+
"""Perform post-processing of beam search iterations.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
i (int): The length of hypothesis tokens.
|
422 |
+
maxlen (int): The maximum length of tokens in beam search.
|
423 |
+
maxlenratio (float): The maximum length ratio in beam search.
|
424 |
+
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
|
425 |
+
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
List[Hypothesis]: The new running hypotheses.
|
429 |
+
|
430 |
+
"""
|
431 |
+
logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
|
432 |
+
if self.token_list is not None:
|
433 |
+
logging.debug(
|
434 |
+
"best hypo: "
|
435 |
+
+ "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
|
436 |
+
)
|
437 |
+
# add eos in the final loop to avoid that there are no ended hyps
|
438 |
+
if i == maxlen - 1:
|
439 |
+
logging.info("adding <eos> in the last position in the loop")
|
440 |
+
running_hyps = [
|
441 |
+
h._replace(yseq=self.append_token(h.yseq, self.eos))
|
442 |
+
for h in running_hyps
|
443 |
+
]
|
444 |
+
|
445 |
+
# add ended hypotheses to a final list, and remove them from current hypotheses
|
446 |
+
# (this can be a problem: the number of hyps may drop below the beam size)
|
447 |
+
remained_hyps = []
|
448 |
+
for hyp in running_hyps:
|
449 |
+
if hyp.yseq[-1] == self.eos:
|
450 |
+
# e.g., Word LM needs to add final <eos> score
|
451 |
+
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
|
452 |
+
s = d.final_score(hyp.states[k])
|
453 |
+
hyp.scores[k] += s
|
454 |
+
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
|
455 |
+
ended_hyps.append(hyp)
|
456 |
+
else:
|
457 |
+
remained_hyps.append(hyp)
|
458 |
+
return remained_hyps
|
459 |
+
|
460 |
+
|
461 |
+
def beam_search(
|
462 |
+
x: torch.Tensor,
|
463 |
+
sos: int,
|
464 |
+
eos: int,
|
465 |
+
beam_size: int,
|
466 |
+
vocab_size: int,
|
467 |
+
scorers: Dict[str, ScorerInterface],
|
468 |
+
weights: Dict[str, float],
|
469 |
+
token_list: List[str] = None,
|
470 |
+
maxlenratio: float = 0.0,
|
471 |
+
minlenratio: float = 0.0,
|
472 |
+
pre_beam_ratio: float = 1.5,
|
473 |
+
pre_beam_score_key: str = "full",
|
474 |
+
) -> list:
|
475 |
+
"""Perform beam search with scorers.
|
476 |
+
|
477 |
+
Args:
|
478 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
479 |
+
sos (int): Start of sequence id
|
480 |
+
eos (int): End of sequence id
|
481 |
+
beam_size (int): The number of hypotheses kept during search
|
482 |
+
vocab_size (int): The size of the vocabulary
|
483 |
+
scorers (dict[str, ScorerInterface]): Dict of decoder modules
|
484 |
+
e.g., Decoder, CTCPrefixScorer, LM
|
485 |
+
The scorer will be ignored if it is `None`
|
486 |
+
weights (dict[str, float]): Dict of weights for each scorers
|
487 |
+
The scorer will be ignored if its weight is 0
|
488 |
+
token_list (list[str]): List of tokens for debug log
|
489 |
+
maxlenratio (float): Input length ratio to obtain max output length.
|
490 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
491 |
+
to automatically find maximum hypothesis lengths
|
492 |
+
minlenratio (float): Input length ratio to obtain min output length.
|
493 |
+
pre_beam_score_key (str): key of scores to perform pre-beam search
|
494 |
+
pre_beam_ratio (float): beam size in the pre-beam search
|
495 |
+
will be `int(pre_beam_ratio * beam_size)`
|
496 |
+
|
497 |
+
Returns:
|
498 |
+
list: N-best decoding results
|
499 |
+
|
500 |
+
"""
|
501 |
+
ret = BeamSearch(
|
502 |
+
scorers,
|
503 |
+
weights,
|
504 |
+
beam_size=beam_size,
|
505 |
+
vocab_size=vocab_size,
|
506 |
+
pre_beam_ratio=pre_beam_ratio,
|
507 |
+
pre_beam_score_key=pre_beam_score_key,
|
508 |
+
sos=sos,
|
509 |
+
eos=eos,
|
510 |
+
token_list=token_list,
|
511 |
+
).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
|
512 |
+
return [h.asdict() for h in ret]
|
espnet/nets/beam_search_transducer.py
ADDED
@@ -0,0 +1,629 @@
1 |
+
"""Search algorithms for transducer models."""
|
2 |
+
|
3 |
+
from typing import List
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from espnet.nets.pytorch_backend.transducer.utils import create_lm_batch_state
|
10 |
+
from espnet.nets.pytorch_backend.transducer.utils import init_lm_state
|
11 |
+
from espnet.nets.pytorch_backend.transducer.utils import is_prefix
|
12 |
+
from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps
|
13 |
+
from espnet.nets.pytorch_backend.transducer.utils import select_lm_state
|
14 |
+
from espnet.nets.pytorch_backend.transducer.utils import substract
|
15 |
+
from espnet.nets.transducer_decoder_interface import Hypothesis
|
16 |
+
from espnet.nets.transducer_decoder_interface import NSCHypothesis
|
17 |
+
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface
|
18 |
+
|
19 |
+
|
20 |
+
class BeamSearchTransducer:
|
21 |
+
"""Beam search implementation for transducer."""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
decoder: Union[TransducerDecoderInterface, torch.nn.Module],
|
26 |
+
joint_network: torch.nn.Module,
|
27 |
+
beam_size: int,
|
28 |
+
lm: torch.nn.Module = None,
|
29 |
+
lm_weight: float = 0.1,
|
30 |
+
search_type: str = "default",
|
31 |
+
max_sym_exp: int = 2,
|
32 |
+
u_max: int = 50,
|
33 |
+
nstep: int = 1,
|
34 |
+
prefix_alpha: int = 1,
|
35 |
+
score_norm: bool = True,
|
36 |
+
nbest: int = 1,
|
37 |
+
):
|
38 |
+
"""Initialize transducer beam search.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
decoder: Decoder class to use
|
42 |
+
joint_network: Joint Network class
|
43 |
+
beam_size: Number of hypotheses kept during search
|
44 |
+
lm: LM class to use
|
45 |
+
lm_weight: lm weight for soft fusion
|
46 |
+
search_type: type of algorithm to use for search
|
47 |
+
max_sym_exp: maximum number of symbol expansions at each time step ("tsd")
|
48 |
+
u_max: maximum output sequence length ("alsd")
|
49 |
+
nstep: maximum number of expansion steps at each time step ("nsc")
|
50 |
+
prefix_alpha: maximum prefix length in prefix search ("nsc")
|
51 |
+
score_norm: normalize final scores by length ("default")
|
52 |
+
nbest: number of returned final hypothesis
|
53 |
+
"""
|
54 |
+
self.decoder = decoder
|
55 |
+
self.joint_network = joint_network
|
56 |
+
|
57 |
+
self.beam_size = beam_size
|
58 |
+
self.hidden_size = decoder.dunits
|
59 |
+
self.vocab_size = decoder.odim
|
60 |
+
self.blank = decoder.blank
|
61 |
+
|
62 |
+
if self.beam_size <= 1:
|
63 |
+
self.search_algorithm = self.greedy_search
|
64 |
+
elif search_type == "default":
|
65 |
+
self.search_algorithm = self.default_beam_search
|
66 |
+
elif search_type == "tsd":
|
67 |
+
self.search_algorithm = self.time_sync_decoding
|
68 |
+
elif search_type == "alsd":
|
69 |
+
self.search_algorithm = self.align_length_sync_decoding
|
70 |
+
elif search_type == "nsc":
|
71 |
+
self.search_algorithm = self.nsc_beam_search
|
72 |
+
else:
|
73 |
+
raise NotImplementedError
|
74 |
+
|
75 |
+
self.lm = lm
|
76 |
+
self.lm_weight = lm_weight
|
77 |
+
|
78 |
+
if lm is not None:
|
79 |
+
self.use_lm = True
|
80 |
+
self.is_wordlm = True if hasattr(lm.predictor, "wordlm") else False
|
81 |
+
self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
|
82 |
+
self.lm_layers = len(self.lm_predictor.rnn)
|
83 |
+
else:
|
84 |
+
self.use_lm = False
|
85 |
+
|
86 |
+
self.max_sym_exp = max_sym_exp
|
87 |
+
self.u_max = u_max
|
88 |
+
self.nstep = nstep
|
89 |
+
self.prefix_alpha = prefix_alpha
|
90 |
+
self.score_norm = score_norm
|
91 |
+
|
92 |
+
self.nbest = nbest
|
93 |
+
|
94 |
+
def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NSCHypothesis]]:
|
95 |
+
"""Perform beam search.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
h: Encoded speech features (T_max, D_enc)
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
nbest_hyps: N-best decoding results
|
102 |
+
|
103 |
+
"""
|
104 |
+
self.decoder.set_device(h.device)
|
105 |
+
|
106 |
+
if not hasattr(self.decoder, "decoders"):
|
107 |
+
self.decoder.set_data_type(h.dtype)
|
108 |
+
|
109 |
+
nbest_hyps = self.search_algorithm(h)
|
110 |
+
|
111 |
+
return nbest_hyps
|
112 |
+
|
113 |
+
def sort_nbest(
|
114 |
+
self, hyps: Union[List[Hypothesis], List[NSCHypothesis]]
|
115 |
+
) -> Union[List[Hypothesis], List[NSCHypothesis]]:
|
116 |
+
"""Sort hypotheses by score or score given sequence length.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
hyps: list of hypotheses
|
120 |
+
|
121 |
+
Return:
|
122 |
+
hyps: sorted list of hypotheses
|
123 |
+
|
124 |
+
"""
|
125 |
+
if self.score_norm:
|
126 |
+
hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
|
127 |
+
else:
|
128 |
+
hyps.sort(key=lambda x: x.score, reverse=True)
|
129 |
+
|
130 |
+
return hyps[: self.nbest]
|
131 |
+
|
132 |
+
def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
|
133 |
+
"""Greedy search implementation for transformer-transducer.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
h: Encoded speech features (T_max, D_enc)
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
hyp: 1-best decoding results
|
140 |
+
|
141 |
+
"""
|
142 |
+
dec_state = self.decoder.init_state(1)
|
143 |
+
|
144 |
+
hyp = Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)
|
145 |
+
cache = {}
|
146 |
+
|
147 |
+
y, state, _ = self.decoder.score(hyp, cache)
|
148 |
+
|
149 |
+
for i, hi in enumerate(h):
|
150 |
+
ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
|
151 |
+
logp, pred = torch.max(ytu, dim=-1)
|
152 |
+
|
153 |
+
if pred != self.blank:
|
154 |
+
hyp.yseq.append(int(pred))
|
155 |
+
hyp.score += float(logp)
|
156 |
+
|
157 |
+
hyp.dec_state = state
|
158 |
+
|
159 |
+
y, state, _ = self.decoder.score(hyp, cache)
|
160 |
+
|
161 |
+
return [hyp]
|
162 |
+
|
163 |
+
def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
|
164 |
+
"""Beam search implementation.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
h: Encoded speech features (T_max, D_enc)
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
nbest_hyps: N-best decoding results
|
171 |
+
|
172 |
+
"""
|
173 |
+
beam = min(self.beam_size, self.vocab_size)
|
174 |
+
beam_k = min(beam, (self.vocab_size - 1))
|
175 |
+
|
176 |
+
dec_state = self.decoder.init_state(1)
|
177 |
+
|
178 |
+
kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)]
|
179 |
+
cache = {}
|
180 |
+
|
181 |
+
for hi in h:
|
182 |
+
hyps = kept_hyps
|
183 |
+
kept_hyps = []
|
184 |
+
|
185 |
+
while True:
|
186 |
+
max_hyp = max(hyps, key=lambda x: x.score)
|
187 |
+
hyps.remove(max_hyp)
|
188 |
+
|
189 |
+
y, state, lm_tokens = self.decoder.score(max_hyp, cache)
|
190 |
+
|
191 |
+
ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
|
192 |
+
top_k = ytu[1:].topk(beam_k, dim=-1)
|
193 |
+
|
194 |
+
kept_hyps.append(
|
195 |
+
Hypothesis(
|
196 |
+
score=(max_hyp.score + float(ytu[0:1])),
|
197 |
+
yseq=max_hyp.yseq[:],
|
198 |
+
dec_state=max_hyp.dec_state,
|
199 |
+
lm_state=max_hyp.lm_state,
|
200 |
+
)
|
201 |
+
)
|
202 |
+
|
203 |
+
if self.use_lm:
|
204 |
+
lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
|
205 |
+
else:
|
206 |
+
lm_state = max_hyp.lm_state
|
207 |
+
|
208 |
+
for logp, k in zip(*top_k):
|
209 |
+
score = max_hyp.score + float(logp)
|
210 |
+
|
211 |
+
if self.use_lm:
|
212 |
+
score += self.lm_weight * lm_scores[0][k + 1]
|
213 |
+
|
214 |
+
hyps.append(
|
215 |
+
Hypothesis(
|
216 |
+
score=score,
|
217 |
+
yseq=max_hyp.yseq[:] + [int(k + 1)],
|
218 |
+
dec_state=state,
|
219 |
+
lm_state=lm_state,
|
220 |
+
)
|
221 |
+
)
|
222 |
+
|
223 |
+
hyps_max = float(max(hyps, key=lambda x: x.score).score)
|
224 |
+
kept_most_prob = sorted(
|
225 |
+
[hyp for hyp in kept_hyps if hyp.score > hyps_max],
|
226 |
+
key=lambda x: x.score,
|
227 |
+
)
|
228 |
+
if len(kept_most_prob) >= beam:
|
229 |
+
kept_hyps = kept_most_prob
|
230 |
+
break
|
231 |
+
|
232 |
+
return self.sort_nbest(kept_hyps)
|
233 |
+
|
234 |
+
def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
|
235 |
+
"""Time synchronous beam search implementation.
|
236 |
+
|
237 |
+
Based on https://ieeexplore.ieee.org/document/9053040
|
238 |
+
|
239 |
+
Args:
|
240 |
+
h: Encoded speech features (T_max, D_enc)
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
nbest_hyps: N-best decoding results
|
244 |
+
|
245 |
+
"""
|
246 |
+
beam = min(self.beam_size, self.vocab_size)
|
247 |
+
|
248 |
+
beam_state = self.decoder.init_state(beam)
|
249 |
+
|
250 |
+
B = [
|
251 |
+
Hypothesis(
|
252 |
+
yseq=[self.blank],
|
253 |
+
score=0.0,
|
254 |
+
dec_state=self.decoder.select_state(beam_state, 0),
|
255 |
+
)
|
256 |
+
]
|
257 |
+
cache = {}
|
258 |
+
|
259 |
+
if self.use_lm and not self.is_wordlm:
|
260 |
+
B[0].lm_state = init_lm_state(self.lm_predictor)
|
261 |
+
|
262 |
+
for hi in h:
|
263 |
+
A = []
|
264 |
+
C = B
|
265 |
+
|
266 |
+
h_enc = hi.unsqueeze(0)
|
267 |
+
|
268 |
+
for v in range(self.max_sym_exp):
|
269 |
+
D = []
|
270 |
+
|
271 |
+
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
|
272 |
+
C,
|
273 |
+
beam_state,
|
274 |
+
cache,
|
275 |
+
self.use_lm,
|
276 |
+
)
|
277 |
+
|
278 |
+
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
|
279 |
+
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
|
280 |
+
|
281 |
+
seq_A = [h.yseq for h in A]
|
282 |
+
|
283 |
+
for i, hyp in enumerate(C):
|
284 |
+
if hyp.yseq not in seq_A:
|
285 |
+
A.append(
|
286 |
+
Hypothesis(
|
287 |
+
score=(hyp.score + float(beam_logp[i, 0])),
|
288 |
+
yseq=hyp.yseq[:],
|
289 |
+
dec_state=hyp.dec_state,
|
290 |
+
lm_state=hyp.lm_state,
|
291 |
+
)
|
292 |
+
)
|
293 |
+
else:
|
294 |
+
dict_pos = seq_A.index(hyp.yseq)
|
295 |
+
|
296 |
+
A[dict_pos].score = np.logaddexp(
|
297 |
+
A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
|
298 |
+
)
|
299 |
+
|
300 |
+
if v < (self.max_sym_exp - 1):
|
301 |
+
if self.use_lm:
|
302 |
+
beam_lm_states = create_lm_batch_state(
|
303 |
+
[c.lm_state for c in C], self.lm_layers, self.is_wordlm
|
304 |
+
)
|
305 |
+
|
306 |
+
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
|
307 |
+
beam_lm_states, beam_lm_tokens, len(C)
|
308 |
+
)
|
309 |
+
|
310 |
+
for i, hyp in enumerate(C):
|
311 |
+
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
|
312 |
+
new_hyp = Hypothesis(
|
313 |
+
score=(hyp.score + float(logp)),
|
314 |
+
yseq=(hyp.yseq + [int(k)]),
|
315 |
+
dec_state=self.decoder.select_state(beam_state, i),
|
316 |
+
lm_state=hyp.lm_state,
|
317 |
+
)
|
318 |
+
|
319 |
+
if self.use_lm:
|
320 |
+
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
|
321 |
+
|
322 |
+
new_hyp.lm_state = select_lm_state(
|
323 |
+
beam_lm_states, i, self.lm_layers, self.is_wordlm
|
324 |
+
)
|
325 |
+
|
326 |
+
D.append(new_hyp)
|
327 |
+
|
328 |
+
C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
|
329 |
+
|
330 |
+
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
|
331 |
+
|
332 |
+
return self.sort_nbest(B)
|
333 |
+
|
334 |
+
def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
|
335 |
+
"""Alignment-length synchronous beam search implementation.
|
336 |
+
|
337 |
+
Based on https://ieeexplore.ieee.org/document/9053040
|
338 |
+
|
339 |
+
Args:
|
340 |
+
h: Encoded speech features (T_max, D_enc)
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
nbest_hyps: N-best decoding results
|
344 |
+
|
345 |
+
"""
|
346 |
+
beam = min(self.beam_size, self.vocab_size)
|
347 |
+
|
348 |
+
h_length = int(h.size(0))
|
349 |
+
u_max = min(self.u_max, (h_length - 1))
|
350 |
+
|
351 |
+
beam_state = self.decoder.init_state(beam)
|
352 |
+
|
353 |
+
B = [
|
354 |
+
Hypothesis(
|
355 |
+
yseq=[self.blank],
|
356 |
+
score=0.0,
|
357 |
+
dec_state=self.decoder.select_state(beam_state, 0),
|
358 |
+
)
|
359 |
+
]
|
360 |
+
final = []
|
361 |
+
cache = {}
|
362 |
+
|
363 |
+
if self.use_lm and not self.is_wordlm:
|
364 |
+
B[0].lm_state = init_lm_state(self.lm_predictor)
|
365 |
+
|
366 |
+
for i in range(h_length + u_max):
|
367 |
+
A = []
|
368 |
+
|
369 |
+
B_ = []
|
370 |
+
h_states = []
|
371 |
+
for hyp in B:
|
372 |
+
u = len(hyp.yseq) - 1
|
373 |
+
t = i - u + 1
|
374 |
+
|
375 |
+
if t > (h_length - 1):
|
376 |
+
continue
|
377 |
+
|
378 |
+
B_.append(hyp)
|
379 |
+
h_states.append((t, h[t]))
|
380 |
+
|
381 |
+
if B_:
|
382 |
+
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
|
383 |
+
B_,
|
384 |
+
beam_state,
|
385 |
+
cache,
|
386 |
+
self.use_lm,
|
387 |
+
)
|
388 |
+
|
389 |
+
h_enc = torch.stack([h[1] for h in h_states])
|
390 |
+
|
391 |
+
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
|
392 |
+
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
|
393 |
+
|
394 |
+
if self.use_lm:
|
395 |
+
beam_lm_states = create_lm_batch_state(
|
396 |
+
[b.lm_state for b in B_], self.lm_layers, self.is_wordlm
|
397 |
+
)
|
398 |
+
|
399 |
+
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
|
400 |
+
beam_lm_states, beam_lm_tokens, len(B_)
|
401 |
+
)
|
402 |
+
|
403 |
+
for i, hyp in enumerate(B_):
|
404 |
+
new_hyp = Hypothesis(
|
405 |
+
score=(hyp.score + float(beam_logp[i, 0])),
|
406 |
+
yseq=hyp.yseq[:],
|
407 |
+
dec_state=hyp.dec_state,
|
408 |
+
lm_state=hyp.lm_state,
|
409 |
+
)
|
410 |
+
|
411 |
+
A.append(new_hyp)
|
412 |
+
|
413 |
+
if h_states[i][0] == (h_length - 1):
|
414 |
+
final.append(new_hyp)
|
415 |
+
|
416 |
+
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
|
417 |
+
new_hyp = Hypothesis(
|
418 |
+
score=(hyp.score + float(logp)),
|
419 |
+
yseq=(hyp.yseq[:] + [int(k)]),
|
420 |
+
dec_state=self.decoder.select_state(beam_state, i),
|
421 |
+
lm_state=hyp.lm_state,
|
422 |
+
)
|
423 |
+
|
424 |
+
if self.use_lm:
|
425 |
+
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
|
426 |
+
|
427 |
+
new_hyp.lm_state = select_lm_state(
|
428 |
+
beam_lm_states, i, self.lm_layers, self.is_wordlm
|
429 |
+
)
|
430 |
+
|
431 |
+
A.append(new_hyp)
|
432 |
+
|
433 |
+
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
|
434 |
+
B = recombine_hyps(B)
|
435 |
+
|
436 |
+
if final:
|
437 |
+
return self.sort_nbest(final)
|
438 |
+
else:
|
439 |
+
return B
|
440 |
+
|
441 |
+
def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
|
442 |
+
"""N-step constrained beam search implementation.
|
443 |
+
|
444 |
+
Based on and modified from https://arxiv.org/pdf/2002.03577.pdf.
|
445 |
+
Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
|
446 |
+
until further modifications.
|
447 |
+
|
448 |
+
Note: the algorithm is not in its "complete" form but works almost as
|
449 |
+
intended.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
h: Encoded speech features (T_max, D_enc)
|
453 |
+
|
454 |
+
Returns:
|
455 |
+
nbest_hyps: N-best decoding results
|
456 |
+
|
457 |
+
"""
|
458 |
+
beam = min(self.beam_size, self.vocab_size)
|
459 |
+
beam_k = min(beam, (self.vocab_size - 1))
|
460 |
+
|
461 |
+
beam_state = self.decoder.init_state(beam)
|
462 |
+
|
463 |
+
init_tokens = [
|
464 |
+
NSCHypothesis(
|
465 |
+
yseq=[self.blank],
|
466 |
+
score=0.0,
|
467 |
+
dec_state=self.decoder.select_state(beam_state, 0),
|
468 |
+
)
|
469 |
+
]
|
470 |
+
|
471 |
+
cache = {}
|
472 |
+
|
473 |
+
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
|
474 |
+
init_tokens,
|
475 |
+
beam_state,
|
476 |
+
cache,
|
477 |
+
self.use_lm,
|
478 |
+
)
|
479 |
+
|
480 |
+
state = self.decoder.select_state(beam_state, 0)
|
481 |
+
|
482 |
+
if self.use_lm:
|
483 |
+
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
|
484 |
+
None, beam_lm_tokens, 1
|
485 |
+
)
|
486 |
+
lm_state = select_lm_state(
|
487 |
+
beam_lm_states, 0, self.lm_layers, self.is_wordlm
|
488 |
+
)
|
489 |
+
lm_scores = beam_lm_scores[0]
|
490 |
+
else:
|
491 |
+
lm_state = None
|
492 |
+
lm_scores = None
|
493 |
+
|
494 |
+
kept_hyps = [
|
495 |
+
NSCHypothesis(
|
496 |
+
yseq=[self.blank],
|
497 |
+
score=0.0,
|
498 |
+
dec_state=state,
|
499 |
+
y=[beam_y[0]],
|
500 |
+
lm_state=lm_state,
|
501 |
+
lm_scores=lm_scores,
|
502 |
+
)
|
503 |
+
]
|
504 |
+
|
505 |
+
for hi in h:
|
506 |
+
hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
|
507 |
+
kept_hyps = []
|
508 |
+
|
509 |
+
h_enc = hi.unsqueeze(0)
|
510 |
+
|
511 |
+
for j, hyp_j in enumerate(hyps[:-1]):
|
512 |
+
for hyp_i in hyps[(j + 1) :]:
|
513 |
+
curr_id = len(hyp_j.yseq)
|
514 |
+
next_id = len(hyp_i.yseq)
|
515 |
+
|
516 |
+
if (
|
517 |
+
is_prefix(hyp_j.yseq, hyp_i.yseq)
|
518 |
+
and (curr_id - next_id) <= self.prefix_alpha
|
519 |
+
):
|
520 |
+
ytu = torch.log_softmax(
|
521 |
+
self.joint_network(hi, hyp_i.y[-1]), dim=-1
|
522 |
+
)
|
523 |
+
|
524 |
+
curr_score = hyp_i.score + float(ytu[hyp_j.yseq[next_id]])
|
525 |
+
|
526 |
+
for k in range(next_id, (curr_id - 1)):
|
527 |
+
ytu = torch.log_softmax(
|
528 |
+
self.joint_network(hi, hyp_j.y[k]), dim=-1
|
529 |
+
)
|
530 |
+
|
531 |
+
curr_score += float(ytu[hyp_j.yseq[k + 1]])
|
532 |
+
|
533 |
+
hyp_j.score = np.logaddexp(hyp_j.score, curr_score)
|
534 |
+
|
535 |
+
S = []
|
536 |
+
V = []
|
537 |
+
for n in range(self.nstep):
|
538 |
+
beam_y = torch.stack([hyp.y[-1] for hyp in hyps])
|
539 |
+
|
540 |
+
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
|
541 |
+
beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)
|
542 |
+
|
543 |
+
for i, hyp in enumerate(hyps):
|
544 |
+
S.append(
|
545 |
+
NSCHypothesis(
|
546 |
+
yseq=hyp.yseq[:],
|
547 |
+
score=hyp.score + float(beam_logp[i, 0:1]),
|
548 |
+
y=hyp.y[:],
|
549 |
+
dec_state=hyp.dec_state,
|
550 |
+
lm_state=hyp.lm_state,
|
551 |
+
lm_scores=hyp.lm_scores,
|
552 |
+
)
|
553 |
+
)
|
554 |
+
|
555 |
+
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
|
556 |
+
score = hyp.score + float(logp)
|
557 |
+
|
558 |
+
if self.use_lm:
|
559 |
+
score += self.lm_weight * float(hyp.lm_scores[k])
|
560 |
+
|
561 |
+
V.append(
|
562 |
+
NSCHypothesis(
|
563 |
+
yseq=hyp.yseq[:] + [int(k)],
|
564 |
+
score=score,
|
565 |
+
y=hyp.y[:],
|
566 |
+
dec_state=hyp.dec_state,
|
567 |
+
lm_state=hyp.lm_state,
|
568 |
+
lm_scores=hyp.lm_scores,
|
569 |
+
)
|
570 |
+
)
|
571 |
+
|
572 |
+
V.sort(key=lambda x: x.score, reverse=True)
|
573 |
+
V = substract(V, hyps)[:beam]
|
574 |
+
|
575 |
+
beam_state = self.decoder.create_batch_states(
|
576 |
+
beam_state,
|
577 |
+
[v.dec_state for v in V],
|
578 |
+
[v.yseq for v in V],
|
579 |
+
)
|
580 |
+
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
|
581 |
+
V,
|
582 |
+
beam_state,
|
583 |
+
cache,
|
584 |
+
self.use_lm,
|
585 |
+
)
|
586 |
+
|
587 |
+
if self.use_lm:
|
588 |
+
beam_lm_states = create_lm_batch_state(
|
589 |
+
[v.lm_state for v in V], self.lm_layers, self.is_wordlm
|
590 |
+
)
|
591 |
+
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
|
592 |
+
beam_lm_states, beam_lm_tokens, len(V)
|
593 |
+
)
|
594 |
+
|
595 |
+
if n < (self.nstep - 1):
|
596 |
+
for i, v in enumerate(V):
|
597 |
+
v.y.append(beam_y[i])
|
598 |
+
|
599 |
+
v.dec_state = self.decoder.select_state(beam_state, i)
|
600 |
+
|
601 |
+
if self.use_lm:
|
602 |
+
v.lm_state = select_lm_state(
|
603 |
+
beam_lm_states, i, self.lm_layers, self.is_wordlm
|
604 |
+
)
|
605 |
+
v.lm_scores = beam_lm_scores[i]
|
606 |
+
|
607 |
+
hyps = V[:]
|
608 |
+
else:
|
609 |
+
beam_logp = torch.log_softmax(
|
610 |
+
self.joint_network(h_enc, beam_y), dim=-1
|
611 |
+
)
|
612 |
+
|
613 |
+
for i, v in enumerate(V):
|
614 |
+
if self.nstep != 1:
|
615 |
+
v.score += float(beam_logp[i, 0])
|
616 |
+
|
617 |
+
v.y.append(beam_y[i])
|
618 |
+
|
619 |
+
v.dec_state = self.decoder.select_state(beam_state, i)
|
620 |
+
|
621 |
+
if self.use_lm:
|
622 |
+
v.lm_state = select_lm_state(
|
623 |
+
beam_lm_states, i, self.lm_layers, self.is_wordlm
|
624 |
+
)
|
625 |
+
v.lm_scores = beam_lm_scores[i]
|
626 |
+
|
627 |
+
kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]
|
628 |
+
|
629 |
+
return self.sort_nbest(kept_hyps)
|
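
As a small aside, the length normalization applied in `sort_nbest` above can be illustrated with toy data. The `ToyHyp` class below is hypothetical and stands in for the transducer `Hypothesis`; it shows how the normalized ranking (score_norm=True) can prefer a longer hypothesis whose total log-probability is lower.

from dataclasses import dataclass
from typing import List


@dataclass
class ToyHyp:
    yseq: List[int]   # blank-initial token sequence, as in the transducer search
    score: float      # accumulated log-probability


def sort_nbest(hyps: List[ToyHyp], score_norm: bool = True, nbest: int = 1) -> List[ToyHyp]:
    # same ranking criterion as BeamSearchTransducer.sort_nbest
    if score_norm:
        hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
    else:
        hyps.sort(key=lambda x: x.score, reverse=True)
    return hyps[:nbest]


short = ToyHyp(yseq=[0, 5], score=-2.0)        # per-token score: -1.0
long = ToyHyp(yseq=[0, 5, 7, 3], score=-3.2)   # per-token score: -0.8
print(sort_nbest([short, long], score_norm=True)[0].yseq)   # -> [0, 5, 7, 3]
print(sort_nbest([short, long], score_norm=False)[0].yseq)  # -> [0, 5]
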
espnet/nets/chainer_backend/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/nets/chainer_backend/asr_interface.py
ADDED
@@ -0,0 +1,29 @@
1 |
+
"""ASR Interface module."""
|
2 |
+
import chainer
|
3 |
+
|
4 |
+
from espnet.nets.asr_interface import ASRInterface
|
5 |
+
|
6 |
+
|
7 |
+
class ChainerASRInterface(ASRInterface, chainer.Chain):
|
8 |
+
"""ASR Interface for ESPnet model implementation."""
|
9 |
+
|
10 |
+
@staticmethod
|
11 |
+
def custom_converter(*args, **kw):
|
12 |
+
"""Get customconverter of the model (Chainer only)."""
|
13 |
+
raise NotImplementedError("custom converter method is not implemented")
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
def custom_updater(*args, **kw):
|
17 |
+
"""Get custom_updater of the model (Chainer only)."""
|
18 |
+
raise NotImplementedError("custom updater method is not implemented")
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def custom_parallel_updater(*args, **kw):
|
22 |
+
"""Get custom_parallel_updater of the model (Chainer only)."""
|
23 |
+
raise NotImplementedError("custom parallel updater method is not implemented")
|
24 |
+
|
25 |
+
def get_total_subsampling_factor(self):
|
26 |
+
"""Get total subsampling factor."""
|
27 |
+
raise NotImplementedError(
|
28 |
+
"get_total_subsampling_factor method is not implemented"
|
29 |
+
)
|
espnet/nets/chainer_backend/ctc.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
import logging
|
2 |
+
|
3 |
+
import chainer
|
4 |
+
from chainer import cuda
|
5 |
+
import chainer.functions as F
|
6 |
+
import chainer.links as L
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
class CTC(chainer.Chain):
|
11 |
+
"""Chainer implementation of ctc layer.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
odim (int): The output dimension.
|
15 |
+
eprojs (int | None): Dimension of input vectors from encoder.
|
16 |
+
dropout_rate (float): Dropout rate.
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, odim, eprojs, dropout_rate):
|
21 |
+
super(CTC, self).__init__()
|
22 |
+
self.dropout_rate = dropout_rate
|
23 |
+
self.loss = None
|
24 |
+
|
25 |
+
with self.init_scope():
|
26 |
+
self.ctc_lo = L.Linear(eprojs, odim)
|
27 |
+
|
28 |
+
def __call__(self, hs, ys):
|
29 |
+
"""CTC forward.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
hs (list of chainer.Variable | N-dimension array):
|
33 |
+
Input variable from encoder.
|
34 |
+
ys (list of chainer.Variable | N-dimension array):
|
35 |
+
Input variable of decoder.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
chainer.Variable: A variable holding a scalar value of the CTC loss.
|
39 |
+
|
40 |
+
"""
|
41 |
+
self.loss = None
|
42 |
+
ilens = [x.shape[0] for x in hs]
|
43 |
+
olens = [x.shape[0] for x in ys]
|
44 |
+
|
45 |
+
# zero padding for hs
|
46 |
+
y_hat = self.ctc_lo(
|
47 |
+
F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2
|
48 |
+
)
|
49 |
+
y_hat = F.separate(y_hat, axis=1) # ilen list of batch x hdim
|
50 |
+
|
51 |
+
# zero padding for ys
|
52 |
+
y_true = F.pad_sequence(ys, padding=-1) # batch x olen
|
53 |
+
|
54 |
+
# get length info
|
55 |
+
input_length = chainer.Variable(self.xp.array(ilens, dtype=np.int32))
|
56 |
+
label_length = chainer.Variable(self.xp.array(olens, dtype=np.int32))
|
57 |
+
logging.info(
|
58 |
+
self.__class__.__name__ + " input lengths: " + str(input_length.data)
|
59 |
+
)
|
60 |
+
logging.info(
|
61 |
+
self.__class__.__name__ + " output lengths: " + str(label_length.data)
|
62 |
+
)
|
63 |
+
|
64 |
+
# get ctc loss
|
65 |
+
self.loss = F.connectionist_temporal_classification(
|
66 |
+
y_hat, y_true, 0, input_length, label_length
|
67 |
+
)
|
68 |
+
logging.info("ctc loss:" + str(self.loss.data))
|
69 |
+
|
70 |
+
return self.loss
|
71 |
+
|
72 |
+
def log_softmax(self, hs):
|
73 |
+
"""Log_softmax of frame activations.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
hs (list of chainer.Variable | N-dimension array):
|
77 |
+
Input variable from encoder.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
chainer.Variable: A n-dimension float array.
|
81 |
+
|
82 |
+
"""
|
83 |
+
y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
|
84 |
+
return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape)
|
85 |
+
|
86 |
+
|
87 |
+
class WarpCTC(chainer.Chain):
|
88 |
+
"""Chainer implementation of warp-ctc layer.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
odim (int): The output dimension.
|
92 |
+
eprojs (int | None): Dimension of input vectors from encoder.
|
93 |
+
dropout_rate (float): Dropout rate.
|
94 |
+
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self, odim, eprojs, dropout_rate):
|
98 |
+
super(WarpCTC, self).__init__()
|
99 |
+
self.dropout_rate = dropout_rate
|
100 |
+
self.loss = None
|
101 |
+
|
102 |
+
with self.init_scope():
|
103 |
+
self.ctc_lo = L.Linear(eprojs, odim)
|
104 |
+
|
105 |
+
def __call__(self, hs, ys):
|
106 |
+
"""Core function of the Warp-CTC layer.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
hs (iterable of chainer.Variable | N-dimension array):
|
110 |
+
Input variable from encoder.
|
111 |
+
ys (iterable of chainer.Variable | N-dimension array):
|
112 |
+
Input variable of decoder.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
chainer.Variable: A variable holding a scalar value of the CTC loss.
|
116 |
+
|
117 |
+
"""
|
118 |
+
self.loss = None
|
119 |
+
ilens = [x.shape[0] for x in hs]
|
120 |
+
olens = [x.shape[0] for x in ys]
|
121 |
+
|
122 |
+
# zero padding for hs
|
123 |
+
y_hat = self.ctc_lo(
|
124 |
+
F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2
|
125 |
+
)
|
126 |
+
y_hat = y_hat.transpose(1, 0, 2) # batch x frames x hdim
|
127 |
+
|
128 |
+
# get length info
|
129 |
+
logging.info(self.__class__.__name__ + " input lengths: " + str(ilens))
|
130 |
+
logging.info(self.__class__.__name__ + " output lengths: " + str(olens))
|
131 |
+
|
132 |
+
# get ctc loss
|
133 |
+
from chainer_ctc.warpctc import ctc as warp_ctc
|
134 |
+
|
135 |
+
self.loss = warp_ctc(y_hat, ilens, [cuda.to_cpu(y.data) for y in ys])[0]
|
136 |
+
logging.info("ctc loss:" + str(self.loss.data))
|
137 |
+
|
138 |
+
return self.loss
|
139 |
+
|
140 |
+
def log_softmax(self, hs):
|
141 |
+
"""Log_softmax of frame activations.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
hs (list of chainer.Variable | N-dimension array):
|
145 |
+
Input variable from encoder.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
chainer.Variable: A n-dimension float array.
|
149 |
+
|
150 |
+
"""
|
151 |
+
y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
|
152 |
+
return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape)
|
153 |
+
|
154 |
+
def argmax(self, hs_pad):
|
155 |
+
"""argmax of frame activations
|
156 |
+
|
157 |
+
:param chainer variable hs_pad: 3d tensor (B, Tmax, eprojs)
|
158 |
+
:return: argmax applied 2d tensor (B, Tmax)
|
159 |
+
:rtype: chainer.Variable
|
160 |
+
"""
|
161 |
+
return F.argmax(self.ctc_lo(F.pad_sequence(hs_pad), n_batch_axes=2), axis=-1)
|
162 |
+
|
163 |
+
|
164 |
+
def ctc_for(args, odim):
|
165 |
+
"""Return the CTC layer corresponding to the args.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
args (Namespace): The program arguments.
|
169 |
+
odim (int): The output dimension.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
The CTC module.
|
173 |
+
|
174 |
+
"""
|
175 |
+
ctc_type = args.ctc_type
|
176 |
+
if ctc_type == "builtin":
|
177 |
+
logging.info("Using chainer CTC implementation")
|
178 |
+
ctc = CTC(odim, args.eprojs, args.dropout_rate)
|
179 |
+
elif ctc_type == "warpctc":
|
180 |
+
logging.info("Using warpctc CTC implementation")
|
181 |
+
ctc = WarpCTC(odim, args.eprojs, args.dropout_rate)
|
182 |
+
else:
|
183 |
+
raise ValueError('ctc_type must be "builtin" or "warpctc": {}'.format(ctc_type))
|
184 |
+
return ctc
|
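
A brief sketch (with a hypothetical argument namespace) of how `ctc_for` above selects the CTC implementation from the training arguments; "builtin" builds the Chainer `CTC` class, while "warpctc" additionally requires the chainer_ctc package at call time.

from types import SimpleNamespace

from espnet.nets.chainer_backend.ctc import ctc_for  # requires chainer to be installed

# hypothetical subset of the training arguments consumed by ctc_for
args = SimpleNamespace(ctc_type="builtin", eprojs=320, dropout_rate=0.1)
ctc = ctc_for(args, odim=52)  # odim: output vocabulary size including blank
print(type(ctc).__name__)     # -> "CTC"
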
espnet/nets/chainer_backend/deterministic_embed_id.py
ADDED
@@ -0,0 +1,253 @@
1 |
+
import numpy
|
2 |
+
import six
|
3 |
+
|
4 |
+
import chainer
|
5 |
+
from chainer import cuda
|
6 |
+
from chainer import function_node
|
7 |
+
from chainer.initializers import normal
|
8 |
+
|
9 |
+
# from chainer.functions.connection import embed_id
|
10 |
+
from chainer import link
|
11 |
+
from chainer.utils import type_check
|
12 |
+
from chainer import variable
|
13 |
+
|
14 |
+
"""Deterministic EmbedID link and function
|
15 |
+
|
16 |
+
copied from chainer/links/connection/embed_id.py
|
17 |
+
and chainer/functions/connection/embed_id.py,
|
18 |
+
and modified not to use atomicAdd operation
|
19 |
+
"""
|
20 |
+
|
21 |
+
|
22 |
+
class EmbedIDFunction(function_node.FunctionNode):
|
23 |
+
def __init__(self, ignore_label=None):
|
24 |
+
self.ignore_label = ignore_label
|
25 |
+
self._w_shape = None
|
26 |
+
|
27 |
+
def check_type_forward(self, in_types):
|
28 |
+
type_check.expect(in_types.size() == 2)
|
29 |
+
x_type, w_type = in_types
|
30 |
+
type_check.expect(
|
31 |
+
x_type.dtype.kind == "i",
|
32 |
+
x_type.ndim >= 1,
|
33 |
+
)
|
34 |
+
type_check.expect(w_type.dtype == numpy.float32, w_type.ndim == 2)
|
35 |
+
|
36 |
+
def forward(self, inputs):
|
37 |
+
self.retain_inputs((0,))
|
38 |
+
x, W = inputs
|
39 |
+
self._w_shape = W.shape
|
40 |
+
|
41 |
+
if not type_check.same_types(*inputs):
|
42 |
+
raise ValueError(
|
43 |
+
"numpy and cupy must not be used together\n"
|
44 |
+
"type(W): {0}, type(x): {1}".format(type(W), type(x))
|
45 |
+
)
|
46 |
+
|
47 |
+
xp = cuda.get_array_module(*inputs)
|
48 |
+
if chainer.is_debug():
|
49 |
+
valid_x = xp.logical_and(0 <= x, x < len(W))
|
50 |
+
if self.ignore_label is not None:
|
51 |
+
valid_x = xp.logical_or(valid_x, x == self.ignore_label)
|
52 |
+
if not valid_x.all():
|
53 |
+
raise ValueError(
|
54 |
+
"Each not ignored `x` value need to satisfy" "`0 <= x < len(W)`"
|
55 |
+
)
|
56 |
+
|
57 |
+
if self.ignore_label is not None:
|
58 |
+
mask = x == self.ignore_label
|
59 |
+
return (xp.where(mask[..., None], 0, W[xp.where(mask, 0, x)]),)
|
60 |
+
|
61 |
+
return (W[x],)
|
62 |
+
|
63 |
+
def backward(self, indexes, grad_outputs):
|
64 |
+
inputs = self.get_retained_inputs()
|
65 |
+
gW = EmbedIDGrad(self._w_shape, self.ignore_label).apply(inputs + grad_outputs)[
|
66 |
+
0
|
67 |
+
]
|
68 |
+
return None, gW
|
69 |
+
|
70 |
+
|
71 |
+
class EmbedIDGrad(function_node.FunctionNode):
|
72 |
+
def __init__(self, w_shape, ignore_label=None):
|
73 |
+
self.w_shape = w_shape
|
74 |
+
self.ignore_label = ignore_label
|
75 |
+
self._gy_shape = None
|
76 |
+
|
77 |
+
def forward(self, inputs):
|
78 |
+
self.retain_inputs((0,))
|
79 |
+
xp = cuda.get_array_module(*inputs)
|
80 |
+
x, gy = inputs
|
81 |
+
self._gy_shape = gy.shape
|
82 |
+
gW = xp.zeros(self.w_shape, dtype=gy.dtype)
|
83 |
+
|
84 |
+
if xp is numpy:
|
85 |
+
# It is equivalent to `numpy.add.at(gW, x, gy)` but ufunc.at is
|
86 |
+
# too slow.
|
87 |
+
for ix, igy in six.moves.zip(x.ravel(), gy.reshape(x.size, -1)):
|
88 |
+
if ix == self.ignore_label:
|
89 |
+
continue
|
90 |
+
gW[ix] += igy
|
91 |
+
else:
|
92 |
+
"""
|
93 |
+
# original code based on cuda elementwise method
|
94 |
+
if self.ignore_label is None:
|
95 |
+
cuda.elementwise(
|
96 |
+
'T gy, S x, S n_out', 'raw T gW',
|
97 |
+
'ptrdiff_t w_ind[] = {x, i % n_out};'
|
98 |
+
'atomicAdd(&gW[w_ind], gy)',
|
99 |
+
'embed_id_bwd')(
|
100 |
+
gy, xp.expand_dims(x, -1), gW.shape[1], gW)
|
101 |
+
else:
|
102 |
+
cuda.elementwise(
|
103 |
+
'T gy, S x, S n_out, S ignore', 'raw T gW',
|
104 |
+
'''
|
105 |
+
if (x != ignore) {
|
106 |
+
ptrdiff_t w_ind[] = {x, i % n_out};
|
107 |
+
atomicAdd(&gW[w_ind], gy);
|
108 |
+
}
|
109 |
+
''',
|
110 |
+
'embed_id_bwd_ignore_label')(
|
111 |
+
gy, xp.expand_dims(x, -1), gW.shape[1],
|
112 |
+
self.ignore_label, gW)
|
113 |
+
"""
|
114 |
+
# EmbedID gradient alternative without atomicAdd, which simply
|
115 |
+
# creates a one-hot vector and applies dot product
|
116 |
+
xi = xp.zeros((x.size, len(gW)), dtype=numpy.float32)
|
117 |
+
idx = xp.arange(x.size, dtype=numpy.int32) * len(gW) + x.ravel()
|
118 |
+
xi.ravel()[idx] = 1.0
|
119 |
+
if self.ignore_label is not None:
|
120 |
+
xi[:, self.ignore_label] = 0.0
|
121 |
+
gW = xi.T.dot(gy.reshape(x.size, -1)).astype(gW.dtype, copy=False)
|
122 |
+
|
123 |
+
return (gW,)
|
124 |
+
|
125 |
+
def backward(self, indexes, grads):
|
126 |
+
xp = cuda.get_array_module(*grads)
|
127 |
+
x = self.get_retained_inputs()[0].data
|
128 |
+
ggW = grads[0]
|
129 |
+
|
130 |
+
if self.ignore_label is not None:
|
131 |
+
mask = x == self.ignore_label
|
132 |
+
# To prevent index out of bounds, we need to check if ignore_label
|
133 |
+
# is inside of W.
|
134 |
+
if not (0 <= self.ignore_label < self.w_shape[1]):
|
135 |
+
x = xp.where(mask, 0, x)
|
136 |
+
|
137 |
+
ggy = ggW[x]
|
138 |
+
|
139 |
+
if self.ignore_label is not None:
|
140 |
+
mask, zero, _ = xp.broadcast_arrays(
|
141 |
+
mask[..., None], xp.zeros((), "f"), ggy.data
|
142 |
+
)
|
143 |
+
ggy = chainer.functions.where(mask, zero, ggy)
|
144 |
+
return None, ggy
|
145 |
+
|
146 |
+
|
147 |
+
def embed_id(x, W, ignore_label=None):
|
148 |
+
r"""Efficient linear function for one-hot input.
|
149 |
+
|
150 |
+
This function implements so called *word embeddings*. It takes two
|
151 |
+
arguments: a set of IDs (words) ``x`` in :math:`B` dimensional integer
|
152 |
+
vector, and a set of all ID (word) embeddings ``W`` in :math:`V \\times d`
|
153 |
+
float32 matrix. It outputs :math:`B \\times d` matrix whose ``i``-th
|
154 |
+
column is the ``x[i]``-th column of ``W``.
|
155 |
+
This function is only differentiable on the input ``W``.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
x (chainer.Variable | np.ndarray): Batch vectors of IDs. Each
|
159 |
+
element must be signed integer.
|
160 |
+
W (chainer.Variable | np.ndarray): Distributed representation
|
161 |
+
of each ID (a.k.a. word embeddings).
|
162 |
+
ignore_label (int): If ignore_label is an int value, i-th column
|
163 |
+
of return value is filled with 0.
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
chainer.Variable: Embedded variable.
|
167 |
+
|
168 |
+
|
169 |
+
.. rubric:: :class:`~chainer.links.EmbedID`
|
170 |
+
|
171 |
+
Examples:
|
172 |
+
|
173 |
+
>>> x = np.array([2, 1]).astype('i')
|
174 |
+
>>> x
|
175 |
+
array([2, 1], dtype=int32)
|
176 |
+
>>> W = np.array([[0, 0, 0],
|
177 |
+
... [1, 1, 1],
|
178 |
+
... [2, 2, 2]]).astype('f')
|
179 |
+
>>> W
|
180 |
+
array([[ 0., 0., 0.],
|
181 |
+
[ 1., 1., 1.],
|
182 |
+
[ 2., 2., 2.]], dtype=float32)
|
183 |
+
>>> F.embed_id(x, W).data
|
184 |
+
array([[ 2., 2., 2.],
|
185 |
+
[ 1., 1., 1.]], dtype=float32)
|
186 |
+
>>> F.embed_id(x, W, ignore_label=1).data
|
187 |
+
array([[ 2., 2., 2.],
|
188 |
+
[ 0., 0., 0.]], dtype=float32)
|
189 |
+
|
190 |
+
"""
|
191 |
+
return EmbedIDFunction(ignore_label=ignore_label).apply((x, W))[0]
|
192 |
+
|
193 |
+
|
194 |
+
class EmbedID(link.Link):
|
195 |
+
"""Efficient linear layer for one-hot input.
|
196 |
+
|
197 |
+
This is a link that wraps the :func:`~chainer.functions.embed_id` function.
|
198 |
+
This link holds the ID (word) embedding matrix ``W`` as a parameter.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
in_size (int): Number of different identifiers (a.k.a. vocabulary size).
|
202 |
+
out_size (int): Output dimension.
|
203 |
+
initialW (Initializer): Initializer to initialize the weight.
|
204 |
+
ignore_label (int): If `ignore_label` is an int value, i-th column of
|
205 |
+
return value is filled with 0.
|
206 |
+
|
207 |
+
.. rubric:: :func:`~chainer.functions.embed_id`
|
208 |
+
|
209 |
+
Attributes:
|
210 |
+
W (~chainer.Variable): Embedding parameter matrix.
|
211 |
+
|
212 |
+
Examples:
|
213 |
+
|
214 |
+
>>> W = np.array([[0, 0, 0],
|
215 |
+
... [1, 1, 1],
|
216 |
+
... [2, 2, 2]]).astype('f')
|
217 |
+
>>> W
|
218 |
+
array([[ 0., 0., 0.],
|
219 |
+
[ 1., 1., 1.],
|
220 |
+
[ 2., 2., 2.]], dtype=float32)
|
221 |
+
>>> l = L.EmbedID(W.shape[0], W.shape[1], initialW=W)
|
222 |
+
>>> x = np.array([2, 1]).astype('i')
|
223 |
+
>>> x
|
224 |
+
array([2, 1], dtype=int32)
|
225 |
+
>>> y = l(x)
|
226 |
+
>>> y.data
|
227 |
+
array([[ 2., 2., 2.],
|
228 |
+
[ 1., 1., 1.]], dtype=float32)
|
229 |
+
|
230 |
+
"""
|
231 |
+
|
232 |
+
ignore_label = None
|
233 |
+
|
234 |
+
def __init__(self, in_size, out_size, initialW=None, ignore_label=None):
|
235 |
+
super(EmbedID, self).__init__()
|
236 |
+
self.ignore_label = ignore_label
|
237 |
+
|
238 |
+
with self.init_scope():
|
239 |
+
if initialW is None:
|
240 |
+
initialW = normal.Normal(1.0)
|
241 |
+
self.W = variable.Parameter(initialW, (in_size, out_size))
|
242 |
+
|
243 |
+
def __call__(self, x):
|
244 |
+
"""Extracts the word embedding of given IDs.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
x (chainer.Variable): Batch vectors of IDs.
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
chainer.Variable: Batch of corresponding embeddings.
|
251 |
+
|
252 |
+
"""
|
253 |
+
return embed_id(x, self.W, ignore_label=self.ignore_label)
|
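
The comment in EmbedIDGrad.forward describes replacing the atomicAdd-based scatter with a one-hot matrix product. The standalone numpy sketch below (toy shapes, not part of the file above) checks that the two formulations agree.

import numpy as np

n_vocab, dim = 5, 3
x = np.array([2, 1, 2], dtype=np.int32)                 # token IDs (with a repeat)
gy = np.random.randn(len(x), dim).astype(np.float32)    # upstream gradients

# reference: scatter-add of each gradient row into its embedding row
gW_ref = np.zeros((n_vocab, dim), dtype=np.float32)
np.add.at(gW_ref, x, gy)

# one-hot formulation, as in the CUDA-friendly branch of EmbedIDGrad
xi = np.zeros((x.size, n_vocab), dtype=np.float32)
xi[np.arange(x.size), x] = 1.0
gW_onehot = xi.T.dot(gy)

assert np.allclose(gW_ref, gW_onehot)
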