pfluo committed
Commit 6eb615a
1 Parent(s): 1a97bfb

update export models and scripts (#6)


- Use the latest pruned_transducer_stateless7_streaming script in icefall to export the model, and update the export scripts in this repo accordingly (427c4b37e1885d621595389453cad0717fbad7bd)
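
For context, a minimal sketch of loading the re-exported TorchScript checkpoint with torch.jit.load; the filename matches exp/cpu_jit.pt below, and feature extraction/decoding are assumed to be handled elsewhere (e.g. by sherpa):

import torch

# Load the TorchScript transducer produced by export.py; the path matches
# the LFS-tracked file updated in this commit.
model = torch.jit.load("exp/cpu_jit.pt", map_location="cpu")
model.eval()

# Assumption: the scripted model exposes encoder/decoder/joiner submodules,
# which a streaming decoder drives chunk by chunk.
print(type(model))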

exp/cpu_jit.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5fe5d637dc8f8a8717c4cad0200909271ef4b30248542896f0257a726c289631
-size 379190270
+oid sha256:3b41fa49583b69438105016d68672604e0359498925dd0c7b5965184a445cc8c
+size 379196926
exp/decoder_jit_trace.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f6298c4c09b88b0020b4fab955ff9dea41a00bbf931db2007f6a70b412346d33
-size 12831269
+oid sha256:83cc6f5cbf4e3e7a518546c2ee4e8d9c17d479ceebcac3648a031985d44ec89d
+size 12831333
exp/encoder_jit_trace.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d748e243b4dd566196c9488346ff3863cad04d6ade6f3336c81491a5b5f0f9f0
-size 330595617
+oid sha256:83e2ec17083607da4085d9653c7d15c908e28fd50488e62c6f24e11e795a90a4
+size 330607841
exp/export-stateless7-streaming-zh.sh CHANGED
@@ -2,8 +2,8 @@
 
 . path.sh
 
-./pruned_transducer_stateless7_streaming/export-zh.py \
-  --lang-dir ./k2fsa-zipformer-chinese-english-mixed/data/lang_char_bpe \
+./pruned_transducer_stateless7_streaming/export.py \
+  --tokens ./k2fsa-zipformer-chinese-english-mixed/data/lang_char_bpe/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
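
The switch from --lang-dir to --tokens means export.py now only needs the tokens.txt symbol table rather than a full lang directory. As a rough sketch, such a table holds one token and its integer id per line, and the vocabulary size can be derived from it much as the old script derived it from the lexicon; the path below is the one passed in the updated script, and the "<blk>" blank symbol is an assumption:

# Sketch: read a tokens.txt symbol table ("<symbol> <id>" per line) and derive
# the quantities the old --lang-dir based export computed from the lexicon.
tokens_path = "./k2fsa-zipformer-chinese-english-mixed/data/lang_char_bpe/tokens.txt"

token2id = {}
with open(tokens_path, encoding="utf-8") as f:
    for line in f:
        sym, idx = line.split()
        token2id[sym] = int(idx)

vocab_size = max(token2id.values()) + 1  # mirrors max(lexicon.tokens) + 1 in the old script
blank_id = token2id.get("<blk>", 0)      # assumption: blank symbol is "<blk>"
print(vocab_size, blank_id)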
exp/jit_trace_export-zh.py DELETED
@@ -1,323 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-Usage:
-./pruned_transducer_stateless7_streaming/jit_trace_export-zh.py \
-  --exp-dir $dir/exp \
-  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
-  --lang-dir ./data/lang_char_bpe \
-  --epoch 99 \
-  --avg 1 \
-  --use-averaged-model 0 \
-  \
-  --decode-chunk-len 32 \
-  --num-encoder-layers "2,4,3,2,4" \
-  --feedforward-dims "1024,1024,1536,1536,1024" \
-  --nhead "8,8,8,8,8" \
-  --encoder-dims "384,384,384,384,384" \
-  --attention-dims "192,192,192,192,192" \
-  --encoder-unmasked-dims "256,256,256,256,256" \
-  --zipformer-downsampling-factors "1,2,4,8,2" \
-  --cnn-module-kernels "31,31,31,31,31" \
-  --decoder-dim 512 \
-  --joiner-dim 512
-"""
-
-import argparse
-import logging
-from pathlib import Path
-
-import sentencepiece as spm
-import torch
-from scaling_converter import convert_scaled_to_non_scaled
-from train import add_model_arguments, get_params, get_transducer_model
-from icefall.lexicon import Lexicon
-
-from icefall.checkpoint import (
-    average_checkpoints,
-    average_checkpoints_with_averaged_model,
-    find_checkpoints,
-    load_checkpoint,
-)
-from icefall.utils import AttributeDict, str2bool
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--epoch",
-        type=int,
-        default=28,
-        help="""It specifies the checkpoint to use for averaging.
-        Note: Epoch counts from 0.
-        You can specify --avg to use more checkpoints for model averaging.""",
-    )
-
-    parser.add_argument(
-        "--iter",
-        type=int,
-        default=0,
-        help="""If positive, --epoch is ignored and it
-        will use the checkpoint exp_dir/checkpoint-iter.pt.
-        You can specify --avg to use more checkpoints for model averaging.
-        """,
-    )
-
-    parser.add_argument(
-        "--avg",
-        type=int,
-        default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
-    )
-
-    parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="pruned_transducer_stateless2/exp",
-        help="""It specifies the directory where all training related
-        files, e.g., checkpoints, log, etc, are saved
-        """,
-    )
-
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_char",
-        help="The lang dir",
-    )
-
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--use-averaged-model",
-        type=str2bool,
-        default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
-    )
-
-    add_model_arguments(parser)
-
-    return parser
-
-
-def export_encoder_model_jit_trace(
-    encoder_model: torch.nn.Module,
-    encoder_filename: str,
-    params: AttributeDict,
-) -> None:
-    """Export the given encoder model with torch.jit.trace()
-
-    Note: The warmup argument is fixed to 1.
-
-    Args:
-      encoder_model:
-        The input encoder model
-      encoder_filename:
-        The filename to save the exported model.
-    """
-    decode_chunk_len = params.decode_chunk_len  # before subsampling
-    pad_length = 7
-    s = f"decode_chunk_len: {decode_chunk_len}"
-    logging.info(s)
-    assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
-        encoder_model.decode_chunk_size,
-        decode_chunk_len,
-    )
-
-    T = decode_chunk_len + pad_length
-
-    x = torch.zeros(1, T, 80, dtype=torch.float32)
-    x_lens = torch.full((1,), T, dtype=torch.int32)
-    states = encoder_model.get_init_state(device=x.device)
-
-    encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
-    traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
-    traced_model.save(encoder_filename)
-    logging.info(f"Saved to {encoder_filename}")
-
-
-def export_decoder_model_jit_trace(
-    decoder_model: torch.nn.Module,
-    decoder_filename: str,
-) -> None:
-    """Export the given decoder model with torch.jit.trace()
-
-    Note: The argument need_pad is fixed to False.
-
-    Args:
-      decoder_model:
-        The input decoder model
-      decoder_filename:
-        The filename to save the exported model.
-    """
-    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
-    need_pad = torch.tensor([False])
-
-    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
-    traced_model.save(decoder_filename)
-    logging.info(f"Saved to {decoder_filename}")
-
-
-def export_joiner_model_jit_trace(
-    joiner_model: torch.nn.Module,
-    joiner_filename: str,
-) -> None:
-    """Export the given joiner model with torch.jit.trace()
-
-    Note: The argument project_input is fixed to True. A user should not
-    project the encoder_out/decoder_out by himself/herself. The exported joiner
-    will do that for the user.
-
-    Args:
-      joiner_model:
-        The input joiner model
-      joiner_filename:
-        The filename to save the exported model.
-
-    """
-    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
-    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
-    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
-    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
-
-    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
-    traced_model.save(joiner_filename)
-    logging.info(f"Saved to {joiner_filename}")
-
-
-@torch.no_grad()
-def main():
-    args = get_parser().parse_args()
-    args.exp_dir = Path(args.exp_dir)
-
-    params = get_params()
-    params.update(vars(args))
-
-    device = torch.device("cpu")
-
-    logging.info(f"device: {device}")
-
-    lexicon = Lexicon(params.lang_dir)
-    params.blank_id = 0
-    params.vocab_size = max(lexicon.tokens) + 1
-
-    logging.info(params)
-
-    logging.info("About to create model")
-    model = get_transducer_model(params)
-
-    if not params.use_averaged_model:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        else:
-            start = params.epoch - params.avg + 1
-            filenames = []
-            for i in range(start, params.epoch + 1):
-                if i >= 1:
-                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-    else:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg + 1:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            filename_start = filenames[-1]
-            filename_end = filenames[0]
-            logging.info(
-                "Calculating the averaged model over iteration checkpoints"
-                f" from {filename_start} (excluded) to {filename_end}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-        else:
-            assert params.avg > 0, params.avg
-            start = params.epoch - params.avg
-            assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-
-    model.to("cpu")
-    model.eval()
-
-    convert_scaled_to_non_scaled(model, inplace=True)
-    logging.info("Using torch.jit.trace()")
-
-    logging.info("Exporting encoder")
-    encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
-    export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
-
-    logging.info("Exporting decoder")
-    decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
-    export_decoder_model_jit_trace(model.decoder, decoder_filename)
-
-    logging.info("Exporting joiner")
-    joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
-    export_joiner_model_jit_trace(model.joiner, joiner_filename)
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
exp/jit_trace_export-zh.sh CHANGED
@@ -15,9 +15,9 @@ if [ ! -f $dir/exp/epoch-99.pt ]; then
   popd
 fi
 
-./pruned_transducer_stateless7_streaming/jit_trace_export-zh.py \
+./pruned_transducer_stateless7_streaming/jit_trace_export.py \
   --exp-dir $dir/exp \
-  --lang-dir $dir/data/lang_char_bpe \
+  --bpe-model $dir/data/lang_char_bpe/bpe.model \
   --epoch 99 \
   --avg 1 \
   --use-averaged-model 0 \
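
For reference, a hedged sketch of loading the three traced components this script regenerates; the filenames match the LFS-tracked files in exp/, and wiring them into a full streaming recognizer is assumed to be done by downstream tooling (e.g. sherpa):

import torch

# Filenames match the artifacts updated in this commit.
encoder = torch.jit.load("exp/encoder_jit_trace.pt", map_location="cpu")
decoder = torch.jit.load("exp/decoder_jit_trace.pt", map_location="cpu")
joiner = torch.jit.load("exp/joiner_jit_trace.pt", map_location="cpu")

for m in (encoder, decoder, joiner):
    m.eval()

# Per the deleted jit_trace_export-zh.py above, the decoder was traced with
# need_pad fixed to False and the joiner projects encoder_out/decoder_out
# itself, so a single greedy-search step would look roughly like:
#   decoder_out = decoder(hyp_tokens, torch.tensor([False]))
#   logits = joiner(encoder_out_frame, decoder_out.squeeze(1))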
exp/joiner_jit_trace.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1765e66b163ba43b0d1825e8ecb3f2edc23170ba5104b7e3827bf0bad318b12e
-size 14680987
+oid sha256:bdd58624fd2df70b5583c684c6705cd28b27d73bc15ef95484356315d1579043
+size 14681115