supawichwac committed on
Commit
55f3766
1 Parent(s): 96bad36

Saving train state of step 50

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +1 -0
  2. .ipynb_checkpoints/run_distillation-checkpoint.py +1693 -0
  3. .ipynb_checkpoints/setup-checkpoint.py +52 -0
  4. Makefile +9 -0
  5. README.md +563 -0
  6. added_tokens.json +1611 -0
  7. checkpoint-50-epoch-0/model.safetensors +3 -0
  8. checkpoint-50-epoch-0/model_1.safetensors +3 -0
  9. checkpoint-50-epoch-0/optimizer.bin +3 -0
  10. checkpoint-50-epoch-0/random_states_0.pkl +3 -0
  11. checkpoint-50-epoch-0/scheduler.bin +3 -0
  12. config.json +50 -0
  13. create_student_model.py +215 -0
  14. distil-large-v3-init/added_tokens.json +1611 -0
  15. distil-large-v3-init/config.json +50 -0
  16. distil-large-v3-init/generation_config.json +255 -0
  17. distil-large-v3-init/merges.txt +0 -0
  18. distil-large-v3-init/model.safetensors +3 -0
  19. distil-large-v3-init/normalizer.json +1742 -0
  20. distil-large-v3-init/preprocessor_config.json +14 -0
  21. distil-large-v3-init/special_tokens_map.json +139 -0
  22. distil-large-v3-init/tokenizer_config.json +0 -0
  23. distil-large-v3-init/vocab.json +0 -0
  24. distil-whisper/events.out.tfevents.1714645175.server02.624510.0 +3 -0
  25. distil-whisper/events.out.tfevents.1715051424.server02.1325731.0 +3 -0
  26. distil-whisper/events.out.tfevents.1715051868.server02.1327224.0 +3 -0
  27. distil_whisper.egg-info/PKG-INFO +580 -0
  28. distil_whisper.egg-info/SOURCES.txt +8 -0
  29. distil_whisper.egg-info/dependency_links.txt +1 -0
  30. distil_whisper.egg-info/requires.txt +12 -0
  31. distil_whisper.egg-info/top_level.txt +1 -0
  32. flax/LICENSE +201 -0
  33. flax/Makefile +9 -0
  34. flax/README.md +293 -0
  35. flax/conversion_scripts/run_convert_distilled_train_state_to_hf.sh +8 -0
  36. flax/convert_train_state_to_hf.py +327 -0
  37. flax/create_student_model.py +226 -0
  38. flax/distil_whisper/__init__.py +21 -0
  39. flax/distil_whisper/layers.py +1338 -0
  40. flax/distil_whisper/modeling_flax_whisper.py +2135 -0
  41. flax/distil_whisper/partitioner.py +965 -0
  42. flax/distil_whisper/pipeline.py +527 -0
  43. flax/distil_whisper/train_state.py +118 -0
  44. flax/distillation_scripts/run_32_2_pt.sh +38 -0
  45. flax/distillation_scripts/run_bs_sweep.yaml +67 -0
  46. flax/distillation_scripts/run_dataset_sweep.yaml +77 -0
  47. flax/distillation_scripts/run_decoder_sweep.yaml +72 -0
  48. flax/distillation_scripts/run_distillation_12_2_timestamped.sh +42 -0
  49. flax/distillation_scripts/run_distillation_15s_context.sh +43 -0
  50. flax/distillation_scripts/run_distillation_16_2.sh +41 -0
.gitignore ADDED
@@ -0,0 +1 @@
1
+ wandb
.ipynb_checkpoints/run_distillation-checkpoint.py ADDED
@@ -0,0 +1,1693 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training the Whisper model for sequence-to-sequence speech recognition via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import re
24
+ import shutil
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from functools import partial
29
+ from pathlib import Path
30
+ from typing import Any, Dict, List, Optional, Union
31
+
32
+ import datasets
33
+ import evaluate
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ DatasetDict,
42
+ IterableDataset,
43
+ IterableDatasetDict,
44
+ concatenate_datasets,
45
+ interleave_datasets,
46
+ load_dataset,
47
+ )
48
+ from huggingface_hub import create_repo, get_full_repo_name, upload_folder
49
+ from torch.utils.data import DataLoader
50
+ from tqdm import tqdm
51
+ from transformers import (
52
+ AddedToken,
53
+ HfArgumentParser,
54
+ Seq2SeqTrainingArguments,
55
+ WhisperConfig,
56
+ WhisperFeatureExtractor,
57
+ WhisperForConditionalGeneration,
58
+ WhisperProcessor,
59
+ WhisperTokenizerFast,
60
+ get_scheduler,
61
+ set_seed,
62
+ )
63
+ from transformers.modeling_outputs import BaseModelOutput
64
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
65
+ from transformers.utils import check_min_version
66
+ from transformers.utils.versions import require_version
67
+
68
+
69
+ # Will error if the minimum version of Transformers is not installed. Remove at your own risk.
70
+ check_min_version("4.34.0.dev0")
71
+
72
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
81
+ """
82
+
83
+ model_name_or_path: str = field(
84
+ metadata={"help": "Path to pretrained Whisper model or model identifier from huggingface.co/models"}
85
+ )
86
+ teacher_model_name_or_path: str = field(
87
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
88
+ )
89
+ config_name: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
96
+ )
97
+ feature_extractor_name: Optional[str] = field(
98
+ default=None,
99
+ metadata={"help": "feature extractor name or path if not the same as model_name"},
100
+ )
101
+ cache_dir: Optional[str] = field(
102
+ default=None,
103
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
104
+ )
105
+ use_fast_tokenizer: bool = field(
106
+ default=True,
107
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
108
+ )
109
+ model_revision: str = field(
110
+ default="main",
111
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
112
+ )
113
+ subfolder: str = field(
114
+ default="",
115
+ metadata={
116
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
117
+ "specify the folder name here."
118
+ },
119
+ )
120
+ token: str = field(
121
+ default=None,
122
+ metadata={
123
+ "help": (
124
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
125
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
126
+ )
127
+ },
128
+ )
129
+ attn_implementation: Optional[str] = field(
130
+ default=None,
131
+ metadata={
132
+ "help": (
133
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
134
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
135
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
136
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
137
+ )
138
+ },
139
+ )
140
+
141
+ def __post_init__(self):
142
+ if self.attn_implementation not in [None, "eager", "sdpa", "flash_attention_2"]:
143
+ raise ValueError(
144
+ f"Got `--attn_implementation={self.attn_implementation}`, which is an invalid attention type. Should be one of:\n"
145
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
146
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
147
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
148
+ )
149
+
150
+
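For reference, a minimal sketch (not part of this commit) of how the `--attn_implementation` flag flows into model loading, mirroring the `from_pretrained` calls made further down in this script; the checkpoint name and dtype are example values only:

    # illustrative sketch, assuming a transformers version with SDPA support
    import torch
    from transformers import WhisperForConditionalGeneration

    teacher = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-large-v3",        # example teacher checkpoint
        torch_dtype=torch.float16,        # matches --dtype float16
        low_cpu_mem_usage=True,
        attn_implementation="sdpa",       # or "eager" / "flash_attention_2"
    )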
151
+ @dataclass
152
+ class DataTrainingArguments:
153
+ """
154
+ Arguments pertaining to what data we are going to input our model for training and eval.
155
+ """
156
+
157
+ train_dataset_name: str = field(
158
+ default=None,
159
+ metadata={
160
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
161
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
162
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
163
+ },
164
+ )
165
+ train_dataset_config_name: Optional[str] = field(
166
+ default=None,
167
+ metadata={
168
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
169
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
170
+ "match the order of the datasets."
171
+ },
172
+ )
173
+ train_dataset_samples: str = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
177
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
178
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
179
+ "sample from every dataset is used once per epoch."
180
+ },
181
+ )
182
+ eval_dataset_name: str = field(
183
+ default=None,
184
+ metadata={
185
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
186
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
187
+ "ids by a '+' symbol."
188
+ },
189
+ )
190
+ eval_dataset_config_name: Optional[str] = field(
191
+ default=None,
192
+ metadata={
193
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
194
+ "training dataset config name if unspecified."
195
+ },
196
+ )
197
+ dataset_cache_dir: Optional[str] = field(
198
+ default=None,
199
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
200
+ )
201
+ overwrite_cache: bool = field(
202
+ default=False,
203
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
204
+ )
205
+ preprocessing_num_workers: Optional[int] = field(
206
+ default=None,
207
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
208
+ )
209
+ preprocessing_batch_size: Optional[int] = field(
210
+ default=256,
211
+ metadata={"help": "Number of examples per batch provided to the `prepare_dataset` function."},
212
+ )
213
+ max_train_samples: Optional[int] = field(
214
+ default=None,
215
+ metadata={
216
+ "help": (
217
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
218
+ )
219
+ },
220
+ )
221
+ max_eval_samples: Optional[int] = field(
222
+ default=None,
223
+ metadata={
224
+ "help": (
225
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
226
+ )
227
+ },
228
+ )
229
+ audio_column_name: str = field(
230
+ default="audio",
231
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
232
+ )
233
+ text_column_name: str = field(
234
+ default=None,
235
+ metadata={"help": "The name of the dataset column containing the text data in the training set."},
236
+ )
237
+ eval_text_column_name: str = field(
238
+ default="text",
239
+ metadata={"help": ("The name of the dataset column containing the text data in the evaluation set.")},
240
+ )
241
+ max_duration_in_seconds: float = field(
242
+ default=30.0,
243
+ metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
244
+ )
245
+ min_duration_in_seconds: float = field(
246
+ default=0.0,
247
+ metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
248
+ )
249
+ max_label_length: int = field(
250
+ default=448,
251
+ metadata={"help": "Truncate transcriptions that are longer `max_label_length` tokens."},
252
+ )
253
+ pad_target_to_multiple_of: Optional[int] = field(
254
+ default=None,
255
+ metadata={
256
+ "help": (
257
+ "If set will pad the target sequence to a multiple of the provided"
258
+ " value. This is important to avoid triggering recompilations on TPU."
259
+ " If unspecified, will default to padding the targets to max length."
260
+ )
261
+ },
262
+ )
263
+ preprocessing_only: bool = field(
264
+ default=False,
265
+ metadata={
266
+ "help": (
267
+ "Whether to only do data preprocessing and skip training. This is"
268
+ " especially useful when data preprocessing errors out in distributed"
269
+ " training due to timeout. In this case, one should run the"
270
+ " preprocessing in a non-distributed setup with"
271
+ " `preprocessing_only=True` so that the cached datasets can"
272
+ " consequently be loaded in distributed training"
273
+ )
274
+ },
275
+ )
276
+ train_split_name: str = field(
277
+ default="train",
278
+ metadata={
279
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
280
+ },
281
+ )
282
+ eval_split_name: str = field(
283
+ default="validation",
284
+ metadata={
285
+ "help": (
286
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
287
+ )
288
+ },
289
+ )
290
+ streaming: bool = field(
291
+ default=True,
292
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
293
+ )
294
+ wer_threshold: float = field(
295
+ default=None,
296
+ metadata={
297
+ "help": "Filter training data with Whisper transcriptions that have greater than `wer_threshold` "
298
+ "WER with the normalised transcriptions. This only takes effect if training on pseudo-labels targets."
299
+ "If `--use_pseudo_labels=False`, then no WER filtering is performed, since we train directly on the text"
300
+ "transcriptions."
301
+ },
302
+ )
303
+ use_pseudo_labels: bool = field(
304
+ default=True,
305
+ metadata={
306
+ "help": "Whether or not to use pseudo-label transcriptions as the targets. If True, the pseudo-labels "
307
+ "must be in the dataset column `whisper_transcript` from the previous pseudo-labelling step. This is "
308
+ "not currently yet configurable."
309
+ },
310
+ )
311
+ timestamp_probability: float = field(
312
+ default=0.2, metadata={"help": "Probability for training on timestamped tokens if the data contains it."}
313
+ )
314
+ condition_on_prev_probability: float = field(
315
+ default=0.2, metadata={"help": "Probability for conditioning on the previous text example."}
316
+ )
317
+ return_timestamps: bool = field(
318
+ default=False, metadata={"help": "Whether or not to predict timestamps in the generation step."}
319
+ )
320
+ language: str = field(
321
+ default=None,
322
+ metadata={
323
+ "help": (
324
+ "Language for multilingual distillation. This argument should be set for multilingual distillation "
325
+ "only. For English speech recognition, it should be left as `None`."
326
+ )
327
+ },
328
+ )
329
+ task: str = field(
330
+ default="transcribe",
331
+ metadata={
332
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."
333
+ "This argument should be set for multilingual distillation only. For English speech recognition, it should be left as `None`."
334
+ },
335
+ )
336
+ wandb_project: str = field(
337
+ default="distil-whisper",
338
+ metadata={"help": "The name of the wandb project."},
339
+ )
340
+
341
+
342
+ @dataclass
343
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
344
+ freeze_encoder: Optional[bool] = field(
345
+ default=False,
346
+ metadata={
347
+ "help": (
348
+ "Whether to freeze the entire encoder model. Only recommended when the entire encoder has been "
349
+ "copied from the teacher model."
350
+ )
351
+ },
352
+ )
353
+ freeze_embed_positions: Optional[bool] = field(
354
+ default=False,
355
+ metadata={"help": "Whether to freeze the decoder embedding positions."},
356
+ )
357
+ temperature: Optional[float] = field(
358
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
359
+ )
360
+ kl_weight: Optional[float] = field(
361
+ default=1.0,
362
+ metadata={
363
+ "help": (
364
+ "Weighting assigned to the MSE loss in the KD formulation. MSE loss is "
365
+ "computed between the teacher-student hidden states and attentions."
366
+ )
367
+ },
368
+ )
369
+ dtype: Optional[str] = field(
370
+ default="float32",
371
+ metadata={
372
+ "help": (
373
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
374
+ "`float16` or `bfloat16` (both half-precision)."
375
+ )
376
+ },
377
+ )
378
+
379
+
380
+ @dataclass
381
+ class DataCollatorSpeechSeq2SeqWithPadding:
382
+ """
383
+ Data collator that will dynamically pad the inputs received.
384
+ Args:
385
+ processor ([`Wav2Vec2Processor`])
386
+ The processor used for processing the data.
387
+ decoder_start_token_id (:obj: `int`)
388
+ The start-of-sequence token id of the decoder.
389
+ decoder_prev_token_id (:obj: `int`)
390
+ The start-of-prompt token id of the decoder
391
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
392
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
393
+ among:
394
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
395
+ sequence is provided).
396
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
397
+ maximum acceptable input length for the model if that argument is not provided.
398
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
399
+ different lengths).
400
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
401
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
402
+ See above for details.
403
+ max_target_length (:obj:`int`, `optional`):
404
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
405
+ """
406
+
407
+ processor: Any
408
+ decoder_start_token_id: int
409
+ decoder_prev_token_id: int
410
+ input_padding: Union[bool, str] = "max_length"
411
+ target_padding: Union[bool, str] = "max_length"
412
+ max_target_length: Optional[int] = None
413
+
414
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
415
+ # split inputs and labels since they have to be of different lengths and need
416
+ # different padding methods
417
+
418
+ # dataloader returns a list of features which we convert to a dict
419
+ input_features = {"input_features": [feature["input_features"] for feature in features]}
420
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
421
+
422
+ # reformat list to dict and set to pytorch format
423
+ batch = self.processor.feature_extractor.pad(
424
+ input_features,
425
+ padding=self.input_padding,
426
+ return_tensors="pt",
427
+ )
428
+
429
+ labels_batch = self.processor.tokenizer.pad(
430
+ label_features,
431
+ max_length=self.max_target_length,
432
+ padding=self.target_padding,
433
+ return_tensors="pt",
434
+ )
435
+
436
+ # shift labels to the right to get decoder input ids
437
+ labels = labels_batch["input_ids"]
438
+ decoder_input_ids = labels[:, :-1]
439
+ labels = labels[:, 1:]
440
+ labels_mask = labels_batch.attention_mask[:, 1:]
441
+
442
+ # replace padding with -100 to ignore correctly when computing the loss
443
+ labels = labels.masked_fill(labels_mask.ne(1), -100)
444
+
445
+ # replace initial prompt tokens with -100 to ignore correctly when computing the loss
446
+ bos_index = torch.argmax((labels == self.decoder_start_token_id).long(), dim=1)
447
+ bos_index = torch.where(bos_index > 0, bos_index + 1, bos_index)
448
+ prompt_mask = torch.arange(labels.shape[1]) < bos_index[:, None]
449
+ labels = torch.where(prompt_mask, -100, labels)
450
+
451
+ batch["labels"] = labels
452
+ batch["decoder_input_ids"] = decoder_input_ids
453
+
454
+ return batch
455
+
456
+
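To make the collator's masking step concrete, here is a small illustrative sketch (not part of this commit) with made-up token ids, showing how the shift and prompt masking leave only the transcript tokens visible to the loss:

    # illustrative sketch with hypothetical token ids
    import torch

    decoder_start_token_id = 50258                                  # <|startoftranscript|>
    labels = torch.tensor([[50361, 7, 8, 50258, 1, 2, 3, 50257]])   # <|startofprev|>, prompt, start token, transcript, eos

    decoder_input_ids = labels[:, :-1]                 # shifted right: fed to the decoder
    labels = labels[:, 1:]                             # targets are the next tokens
    bos_index = torch.argmax((labels == decoder_start_token_id).long(), dim=1)
    bos_index = torch.where(bos_index > 0, bos_index + 1, bos_index)
    prompt_mask = torch.arange(labels.shape[1]) < bos_index[:, None]
    labels = torch.where(prompt_mask, -100, labels)    # prompt and start tokens become -100 and are ignored by the loss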
457
+ def log_metric(
458
+ accelerator,
459
+ metrics: Dict,
460
+ train_time: float,
461
+ step: int,
462
+ epoch: int,
463
+ learning_rate: float = None,
464
+ prefix: str = "train",
465
+ ):
466
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
467
+ log_metrics = {}
468
+ for k, v in metrics.items():
469
+ log_metrics[f"{prefix}/{k}"] = v
470
+ log_metrics[f"{prefix}/time"] = train_time
471
+ log_metrics[f"{prefix}/epoch"] = epoch
472
+ if learning_rate is not None:
473
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
474
+ accelerator.log(log_metrics, step=step)
475
+
476
+
477
+ def log_pred(
478
+ accelerator,
479
+ pred_str: List[str],
480
+ label_str: List[str],
481
+ norm_pred_str: List[str],
482
+ norm_label_str: List[str],
483
+ step: int,
484
+ prefix: str = "eval",
485
+ num_lines: int = 200000,
486
+ ):
487
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
488
+ if accelerator.is_main_process:
489
+ wandb_tracker = accelerator.get_tracker("wandb")
490
+ # pretty name for current step: step 50000 -> step 50k
491
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
492
+ prefix_pretty = prefix.replace("/", "-")
493
+
494
+ # convert str data to a wandb compatible format
495
+ str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
496
+ # log as a table with the appropriate headers
497
+ wandb_tracker.log_table(
498
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
499
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
500
+ data=str_data[:num_lines],
501
+ step=step,
502
+ )
503
+
504
+ # log incorrect normalised predictions
505
+ str_data = np.asarray(str_data)
506
+ str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
507
+ # log as a table with the appropriate headers
508
+ wandb_tracker.log_table(
509
+ table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
510
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
511
+ data=str_data_incorrect[:num_lines],
512
+ step=step,
513
+ )
514
+
515
+
516
+ def convert_dataset_str_to_list(
517
+ dataset_names,
518
+ dataset_config_names,
519
+ splits=None,
520
+ text_column_names=None,
521
+ dataset_samples=None,
522
+ default_split="train",
523
+ ) -> List[Dict]:
524
+ """
525
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
526
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
527
+ function returns a list of dictionaries, one for each dataset.
528
+ """
529
+ if isinstance(dataset_names, str):
530
+ dataset_names = dataset_names.split("+")
531
+ dataset_config_names = dataset_config_names.split("+") if dataset_config_names is not None else None
532
+ splits = splits.split("+") if splits is not None else None
533
+ text_column_names = text_column_names.split("+") if text_column_names is not None else None
534
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
535
+
536
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
537
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
538
+ raise ValueError(
539
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
540
+ f" {len(dataset_config_names)} configs."
541
+ )
542
+
543
+ if splits is not None and len(splits) != len(dataset_names):
544
+ raise ValueError(
545
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
546
+ )
547
+
548
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
549
+ raise ValueError(
550
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
551
+ f" {len(text_column_names)} text column names."
552
+ )
553
+
554
+ if dataset_samples is not None:
555
+ if len(dataset_samples) != len(dataset_names):
556
+ raise ValueError(
557
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
558
+ f"{len(dataset_samples)} samples."
559
+ )
560
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
561
+ else:
562
+ dataset_samples = [None] * len(dataset_names)
563
+
564
+ dataset_config_names = (
565
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
566
+ )
567
+ text_column_names = (
568
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
569
+ )
570
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
571
+
572
+ dataset_names_dict = []
573
+ for i, ds_name in enumerate(dataset_names):
574
+ dataset_names_dict.append(
575
+ {
576
+ "name": ds_name,
577
+ "config": dataset_config_names[i],
578
+ "split": splits[i],
579
+ "text_column_name": text_column_names[i],
580
+ "samples": dataset_samples[i],
581
+ }
582
+ )
583
+ return dataset_names_dict
584
+
585
+
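As an illustrative usage sketch (not part of this commit); the dataset names, configs and sample counts below are example values only:

    # illustrative sketch: parsing two '+'-separated datasets into per-dataset dicts
    parsed = convert_dataset_str_to_list(
        dataset_names="librispeech_asr+mozilla-foundation/common_voice_13_0",
        dataset_config_names="all+en",
        splits="train.clean.360+train",
        dataset_samples="363.6+100",
    )
    # parsed[0] == {"name": "librispeech_asr", "config": "all", "split": "train.clean.360",
    #               "text_column_name": "text", "samples": 363.6}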
586
+ def load_multiple_datasets(
587
+ dataset_names: Union[List, str],
588
+ dataset_config_names: Union[List, str],
589
+ splits: Optional[Union[List, str]] = None,
590
+ text_column_names: Optional[List] = None,
591
+ sampling_rate: Optional[int] = 16000,
592
+ stopping_strategy: Optional[str] = "first_exhausted",
593
+ dataset_samples: Optional[Union[List, np.array]] = None,
594
+ streaming: Optional[bool] = True,
595
+ seed: Optional[int] = None,
596
+ accelerator: Optional[Accelerator] = None,
597
+ use_pseudo_labels: Optional[bool] = None,
598
+ **kwargs,
599
+ ) -> IterableDataset:
600
+ dataset_names_dict = convert_dataset_str_to_list(
601
+ dataset_names, dataset_config_names, splits, text_column_names, dataset_samples
602
+ )
603
+
604
+ if dataset_samples is not None:
605
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
606
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
607
+ else:
608
+ probabilities = None
609
+
610
+ all_datasets = []
611
+ # iterate over the datasets we want to interleave
612
+ for dataset_dict in tqdm(
613
+ dataset_names_dict,
614
+ desc="Combining datasets...",
615
+ disable=not accelerator.is_local_main_process if accelerator is not None else False,
616
+ ):
617
+ dataset = load_dataset(
618
+ dataset_dict["name"],
619
+ dataset_dict["config"],
620
+ split=dataset_dict["split"],
621
+ streaming=streaming,
622
+ **kwargs,
623
+ )
624
+ # resample to specified sampling rate
625
+ dataset = dataset.cast_column("audio", datasets.features.Audio(sampling_rate))
626
+ dataset_features = dataset.features.keys()
627
+ columns_to_keep = {"audio", "text"}
628
+
629
+ if dataset_dict["text_column_name"] not in dataset_features:
630
+ raise ValueError(
631
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
632
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
633
+ f" correct text column - one of {', '.join(dataset_features)}."
634
+ )
635
+
636
+ # blanket renaming of all transcription columns to text
637
+ if dataset_dict["text_column_name"] != "text":
638
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
639
+
640
+ if use_pseudo_labels:
641
+ if "whisper_transcript" not in dataset_features:
642
+ raise ValueError(
643
+ f"Pseudo-label column `whisper_transcript` not found in dataset {dataset_dict['name']}. Ensure"
644
+ "pseudo-labels are present in the dataset under this column name, or train directly on the text "
645
+ "labels by setting `--use_pseudo_labels=False` and defining the appropriate `--text_column_name`."
646
+ )
647
+ columns_to_keep.add("whisper_transcript")
648
+
649
+ if "condition_on_prev" in dataset_features:
650
+ columns_to_keep.add("condition_on_prev")
651
+
652
+ dataset_features = dataset.features.keys()
653
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
654
+ all_datasets.append(dataset)
655
+
656
+ if len(all_datasets) == 1:
657
+ # we have a single dataset so just return it as is
658
+ return all_datasets[0]
659
+
660
+ if streaming:
661
+ interleaved_dataset = interleave_datasets(
662
+ all_datasets,
663
+ stopping_strategy=stopping_strategy,
664
+ probabilities=probabilities,
665
+ seed=seed,
666
+ )
667
+ else:
668
+ interleaved_dataset = concatenate_datasets(all_datasets)
669
+
670
+ return interleaved_dataset
671
+
672
+
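A short illustrative sketch (not part of this commit) of how the per-dataset sampling probabilities used above are derived from `--train_dataset_samples`; the numbers are example values:

    # illustrative sketch: sample counts -> interleaving probabilities
    import numpy as np

    dataset_samples = [363.6, 100.0]                        # e.g. per-dataset sample counts (or hours)
    probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
    # array([0.784..., 0.215...]) -> passed to `interleave_datasets` in streaming mode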
673
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
674
+ """Helper function to sort saved checkpoints from oldest to newest."""
675
+ ordering_and_checkpoint_path = []
676
+
677
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
678
+
679
+ for path in glob_checkpoints:
680
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
681
+ if regex_match is not None and regex_match.groups() is not None:
682
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
683
+
684
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
685
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
686
+ return checkpoints_sorted
687
+
688
+
689
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
690
+ """Helper function to delete old checkpoints."""
691
+ if save_total_limit is None or save_total_limit <= 0:
692
+ return
693
+ # Check if we should delete older checkpoint(s)
694
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
695
+ if len(checkpoints_sorted) <= save_total_limit:
696
+ return
697
+
698
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
699
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
700
+ for checkpoint in checkpoints_to_be_deleted:
701
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
702
+ shutil.rmtree(checkpoint, ignore_errors=True)
703
+
704
+
705
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
706
+
707
+
708
+ def get_last_checkpoint(folder):
709
+ content = os.listdir(folder)
710
+ checkpoints = [
711
+ path
712
+ for path in content
713
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
714
+ ]
715
+ if len(checkpoints) == 0:
716
+ return
717
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
718
+
719
+
720
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
721
+ """
722
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
723
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
724
+ (e.g. if the module is frozen).
725
+ """
726
+ result = []
727
+ for name, child in model.named_children():
728
+ result += [
729
+ f"{name}.{n}"
730
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
731
+ if not (
732
+ isinstance(child, tuple(forbidden_layer_types))
733
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
734
+ )
735
+ ]
736
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
737
+ result += list(model._parameters.keys())
738
+ return result
739
+
740
+
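A hedged usage sketch (not part of this commit): `get_parameter_names` is typically used to exclude LayerNorm and bias parameters from weight decay, in the style of the HF Trainer; `student_model` and the decay value are assumptions here:

    # illustrative sketch, assuming `student_model` is the loaded student Whisper model
    import torch.nn as nn

    decay_parameters = get_parameter_names(student_model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in student_model.named_parameters() if n in decay_parameters],
         "weight_decay": 0.01},   # example decay value
        {"params": [p for n, p in student_model.named_parameters() if n not in decay_parameters],
         "weight_decay": 0.0},
    ]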
741
+ def main():
742
+ # 1. Parse input arguments
743
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
744
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
745
+
746
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
747
+ # If we pass only one argument to the script and it's the path to a json file,
748
+ # let's parse it to get our arguments.
749
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
750
+ else:
751
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
752
+
753
+ # 2. Initialize the accelerator
754
+ # We will let the accelerator handle device placement for us in this example
755
+ # We simply have to specify the training precision and any trackers being used
756
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
757
+ # it to accelerate format
758
+ if training_args.dtype == "float16":
759
+ mixed_precision = "fp16"
760
+ teacher_dtype = torch.float16
761
+ elif training_args.dtype == "bfloat16":
762
+ mixed_precision = "bf16"
763
+ teacher_dtype = torch.bfloat16
764
+ else:
765
+ mixed_precision = "no"
766
+ teacher_dtype = torch.float32
767
+
768
+ accelerator = Accelerator(
769
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
770
+ mixed_precision=mixed_precision,
771
+ log_with=training_args.report_to,
772
+ project_dir=training_args.output_dir,
773
+ )
774
+
775
+ accelerator.init_trackers(project_name=data_args.wandb_project)
776
+
777
+ # 3. Set-up basic logging
778
+ # Create one log on every process with the configuration for debugging
779
+ logging.basicConfig(
780
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
781
+ datefmt="%m/%d/%Y %H:%M:%S",
782
+ level=logging.INFO,
783
+ )
784
+ # Log a small summary on each process
785
+ logger.warning(
786
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
787
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
788
+ )
789
+
790
+ # Set the verbosity of the Transformers logger to info (on the main process only)
791
+ if accelerator.is_local_main_process:
792
+ datasets.utils.logging.set_verbosity_warning()
793
+ transformers.utils.logging.set_verbosity_info()
794
+ else:
795
+ datasets.utils.logging.set_verbosity_error()
796
+ transformers.utils.logging.set_verbosity_error()
797
+ logger.info("Training/evaluation parameters %s", training_args)
798
+
799
+ # 4. Detect the last checkpoint and, if one exists, resume training from it
800
+ last_checkpoint = None
801
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
802
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
803
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
804
+ raise ValueError(
805
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
806
+ "Use --overwrite_output_dir to overcome."
807
+ )
808
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
809
+ logger.info(
810
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
811
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
812
+ )
813
+
814
+ # 5. Handle the repository creation
815
+ if accelerator.is_main_process:
816
+ if training_args.push_to_hub:
817
+ if training_args.hub_model_id is None:
818
+ repo_name = get_full_repo_name(
819
+ Path(training_args.output_dir).absolute().name,
820
+ token=training_args.hub_token,
821
+ )
822
+ else:
823
+ repo_name = training_args.hub_model_id
824
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
825
+
826
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
827
+ if "wandb" not in gitignore:
828
+ gitignore.write("wandb\n")
829
+ elif training_args.output_dir is not None:
830
+ os.makedirs(training_args.output_dir, exist_ok=True)
831
+ accelerator.wait_for_everyone()
832
+
833
+ # 6. Load dataset - either streaming or non-streaming (offline)
834
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
835
+
836
+ # set seed for determinism
837
+ set_seed(training_args.seed)
838
+
839
+ if training_args.do_train:
840
+ raw_datasets["train"] = load_multiple_datasets(
841
+ data_args.train_dataset_name,
842
+ data_args.train_dataset_config_name,
843
+ splits=data_args.train_split_name,
844
+ text_column_names=data_args.text_column_name,
845
+ use_pseudo_labels=data_args.use_pseudo_labels,
846
+ streaming=data_args.streaming,
847
+ dataset_samples=data_args.train_dataset_samples,
848
+ seed=training_args.seed,
849
+ accelerator=accelerator,
850
+ cache_dir=data_args.dataset_cache_dir,
851
+ token=model_args.token,
852
+ )
853
+ raw_datasets_train_features = list(raw_datasets["train"].features.keys())
854
+
855
+ if training_args.do_eval:
856
+ dataset_names_dict = convert_dataset_str_to_list(
857
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
858
+ (
859
+ data_args.eval_dataset_config_name
860
+ if data_args.eval_dataset_config_name
861
+ else data_args.train_dataset_config_name
862
+ ),
863
+ splits=data_args.eval_split_name,
864
+ text_column_names=data_args.eval_text_column_name,
865
+ )
866
+ all_eval_splits = []
867
+ if len(dataset_names_dict) == 1:
868
+ # load a single eval set
869
+ dataset_dict = dataset_names_dict[0]
870
+ all_eval_splits.append("eval")
871
+ raw_datasets["eval"] = load_dataset(
872
+ dataset_dict["name"],
873
+ dataset_dict["config"],
874
+ split=dataset_dict["split"],
875
+ cache_dir=data_args.dataset_cache_dir,
876
+ token=model_args.token,
877
+ streaming=data_args.streaming,
878
+ )
879
+ if data_args.eval_text_column_name != "text":
880
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
881
+ else:
882
+ # load multiple eval sets
883
+ for dataset_dict in dataset_names_dict:
884
+ if dataset_dict["name"] == "esb/diagnostic-dataset":
885
+ # for the ESB diagnostic dataset, the dataset name is effectively the config
886
+ pretty_name = f"{dataset_dict['config']}-diagnostic/{dataset_dict['split']}"
887
+ else:
888
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
889
+ all_eval_splits.append(pretty_name)
890
+ raw_datasets[pretty_name] = load_dataset(
891
+ dataset_dict["name"],
892
+ dataset_dict["config"],
893
+ split=dataset_dict["split"],
894
+ cache_dir=data_args.dataset_cache_dir,
895
+ token=model_args.token,
896
+ streaming=data_args.streaming,
897
+ )
898
+ # make column names consistent (text, audio)
899
+ if dataset_dict["text_column_name"] != "text":
900
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
901
+ dataset_dict["text_column_name"], "text"
902
+ )
903
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
904
+ set(raw_datasets[pretty_name].features.keys()) - {"audio", "text"}
905
+ )
906
+
907
+ if not training_args.do_train and not training_args.do_eval:
908
+ raise ValueError(
909
+ "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
910
+ )
911
+
912
+ # 7. Load pretrained model, tokenizer, and feature extractor
913
+ config = WhisperConfig.from_pretrained(
914
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
915
+ cache_dir=model_args.cache_dir,
916
+ revision=model_args.model_revision,
917
+ token=model_args.token,
918
+ )
919
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
920
+ (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
921
+ cache_dir=model_args.cache_dir,
922
+ revision=model_args.model_revision,
923
+ token=model_args.token,
924
+ )
925
+ tokenizer = WhisperTokenizerFast.from_pretrained(
926
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
927
+ cache_dir=model_args.cache_dir,
928
+ use_fast=model_args.use_fast_tokenizer,
929
+ revision=model_args.model_revision,
930
+ token=model_args.token,
931
+ )
932
+
933
+ # override timestamp tokens until tokenizer issues are fixed in transformers
934
+ timestamps = [AddedToken("<|%.2f|>" % (i * 0.02), lstrip=False, rstrip=False) for i in range(1500 + 1)]
935
+ tokenizer.add_tokens(timestamps)
936
+
937
+ # The teacher model can safely be cast to the dtype of training since we don't
938
+ # update the params
939
+ teacher_model = WhisperForConditionalGeneration.from_pretrained(
940
+ model_args.teacher_model_name_or_path,
941
+ cache_dir=model_args.cache_dir,
942
+ token=model_args.token,
943
+ low_cpu_mem_usage=True,
944
+ torch_dtype=teacher_dtype,
945
+ attn_implementation=model_args.attn_implementation,
946
+ )
947
+
948
+ student_model = WhisperForConditionalGeneration.from_pretrained(
949
+ model_args.model_name_or_path,
950
+ config=config,
951
+ cache_dir=model_args.cache_dir,
952
+ revision=model_args.model_revision,
953
+ subfolder=model_args.subfolder,
954
+ token=model_args.token,
955
+ low_cpu_mem_usage=True,
956
+ attn_implementation=model_args.attn_implementation,
957
+ )
958
+
959
+ if student_model.config.decoder_start_token_id is None or teacher_model.config.decoder_start_token_id is None:
960
+ raise ValueError(
961
+ f"Make sure that `config.decoder_start_token_id` is correctly defined for both the "
962
+ f"student and teacher model. Got {student_model.config.decoder_start_token_id} for the "
963
+ f"student and {teacher_model.config.decoder_start_token_id} for the teacher."
964
+ )
965
+
966
+ # enable gradient checkpointing if necessary
967
+ if training_args.gradient_checkpointing:
968
+ student_model.gradient_checkpointing_enable()
969
+
970
+ def set_trainable_parameters(module, requires_grad=False):
971
+ for param in module.parameters():
972
+ param.requires_grad = requires_grad
973
+ module._requires_grad = requires_grad
974
+
975
+ # freeze student encoder if necessary
976
+ if training_args.freeze_encoder:
977
+ set_trainable_parameters(student_model.model.encoder, requires_grad=False)
978
+ student_model.model.encoder.gradient_checkpointing = False
979
+
980
+ if training_args.freeze_embed_positions:
981
+ # set_trainable_parameters(student_model.model.decoder.embed_tokens, requires_grad=False)
982
+ set_trainable_parameters(student_model.model.decoder.embed_positions, requires_grad=False)
983
+ if student_model.model.decoder.gradient_checkpointing:
984
+ logger.info(
985
+ "Disabling gradient checkpointing in the decoder since it's incompatible with `freeze_embed_positions`."
986
+ )
987
+
988
+ share_hidden_states = training_args.freeze_encoder and student_model.config.d_model == teacher_model.config.d_model
989
+ if share_hidden_states:
990
+ # tie the weights for the teacher encoder if we're freezing the student and it's the same as the teacher
991
+ teacher_model.model.encoder = student_model.model.encoder
992
+
993
+ if hasattr(teacher_model.generation_config, "is_multilingual") and teacher_model.generation_config.is_multilingual:
994
+ # We need to set the language and task ids for previously multilingual checkpoints
995
+ is_multilingual = True
996
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task, predict_timestamps=False)
997
+ student_model.generation_config.update(
998
+ **{
999
+ "language": data_args.language,
1000
+ "task": data_args.task,
1001
+ }
1002
+ )
1003
+ elif data_args.language is not None:
1004
+ raise ValueError(
1005
+ "Setting language token for an English-only checkpoint is not permitted. The language argument should "
1006
+ "only be set for multilingual checkpoints."
1007
+ )
1008
+ else:
1009
+ is_multilingual = False
1010
+
1011
+ # 8. Create a single speech processor - make sure all processes wait until data is saved
1012
+ if accelerator.is_main_process:
1013
+ feature_extractor.save_pretrained(training_args.output_dir)
1014
+ tokenizer.save_pretrained(training_args.output_dir)
1015
+ # save the config and generation config as well
1016
+ config.save_pretrained(training_args.output_dir)
1017
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1018
+
1019
+ accelerator.wait_for_everyone()
1020
+ processor = WhisperProcessor.from_pretrained(training_args.output_dir)
1021
+
1022
+ # 9. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
1023
+ # so we just need to set the correct target sampling rate.
1024
+ sampling_rate = feature_extractor.sampling_rate
1025
+ raw_datasets = raw_datasets.cast_column(
1026
+ data_args.audio_column_name,
1027
+ datasets.features.Audio(sampling_rate=sampling_rate),
1028
+ )
1029
+
1030
+ # 10. Preprocessing the datasets: we need to read the audio files as arrays and tokenize the targets.
1031
+ # 10.1: Define the pre-processing constants
1032
+ max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
1033
+ min_input_length = int(data_args.min_duration_in_seconds * sampling_rate)
1034
+ max_label_length = (
1035
+ data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
1036
+ )
1037
+
1038
+ timestamp_probability = data_args.timestamp_probability
1039
+ condition_on_prev_probability = data_args.condition_on_prev_probability
1040
+ return_timestamps = data_args.return_timestamps if timestamp_probability > 0 else False
1041
+
1042
+ timestamp_ids = tokenizer.timestamp_ids()
1043
+ timestamp_begin = tokenizer.all_special_ids[-1]
1044
+ timestamp_position = 3 if is_multilingual else 1
1045
+
1046
+ decoder_start_token_id = student_model.config.decoder_start_token_id # <|startoftranscript|>
1047
+ decoder_prev_token_id = tokenizer.all_special_ids[-3] # <|startofprev|>
1048
+ prompt_cutoff_length = max_label_length // 2
1049
+
1050
+ num_workers = data_args.preprocessing_num_workers
1051
+ dataloader_num_workers = training_args.dataloader_num_workers
1052
+ prefetch_factor = training_args.dataloader_prefetch_factor
1053
+
1054
+ metric = evaluate.load("wer")
1055
+ normalizer = (
1056
+ BasicTextNormalizer()
1057
+ if data_args.language is not None
1058
+ else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
1059
+ )
1060
+ wer_threshold = data_args.wer_threshold
1061
+ use_pseudo_labels = data_args.use_pseudo_labels
1062
+ train_text_column_name = "whisper_transcript" if use_pseudo_labels else "text"
1063
+
1064
+ # 10.2: filter based on maximum number of training/evaluation samples
1065
+ if training_args.do_train and data_args.max_train_samples is not None:
1066
+ raw_datasets["train"] = (
1067
+ raw_datasets["train"].take(data_args.max_train_samples)
1068
+ if data_args.streaming
1069
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1070
+ )
1071
+
1072
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1073
+ for eval_split in all_eval_splits:
1074
+ raw_datasets[eval_split] = (
1075
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1076
+ if data_args.streaming
1077
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1078
+ )
1079
+
1080
+ # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
1081
+ def is_wer_in_range(ground_truth, whisper_transcript):
1082
+ norm_ground_truth = normalizer(ground_truth)
1083
+ if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1084
+ # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1085
+ return False
1086
+ elif len(norm_ground_truth) > 0 and whisper_transcript is not None:
1087
+ norm_whisper_transcript = normalizer(whisper_transcript)
1088
+ wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1089
+ return wer < wer_threshold
1090
+ else:
1091
+ # filter automatically since we can't know the WER
1092
+ return False
1093
+
1094
+ filter_by_wer_threshold = partial(
1095
+ raw_datasets["train"].filter,
1096
+ function=is_wer_in_range,
1097
+ input_columns=["text", "whisper_transcript"],
1098
+ )
1099
+
1100
+ if wer_threshold is not None and use_pseudo_labels:
1101
+ with accelerator.main_process_first():
1102
+ raw_datasets["train"] = (
1103
+ filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
1104
+ if not data_args.streaming
1105
+ else filter_by_wer_threshold()
1106
+ )
1107
+
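To illustrate the threshold check (not part of this commit), a toy example with made-up transcriptions and an example `--wer_threshold` of 10:

    # illustrative sketch: one (ground truth, pseudo-label) pair against a WER threshold of 10
    import evaluate

    metric = evaluate.load("wer")
    norm_ground_truth = "the cat sat on the mat"
    norm_whisper_transcript = "the cat sat on a mat"
    wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
    keep = wer < 10    # WER is ~16.7 here, so this pseudo-labelled example would be filtered out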
1108
+ # 10.4: pre-process training/evaluation datasets
1109
+ def prepare_train_dataset(batch):
1110
+ """
1111
+ Pre-process the raw dataset in a three stage process:
1112
+ 1. Convert the audio arrays to log-mel spectrogram inputs
1113
+ 2. Possibly filter the timestamp tokens from the token ids (depending on the timestamp probability)
1114
+ 3. Possibly add prompt tokens if conditioning on previous text (depending on the conditioning probability)
1115
+ """
1116
+ # process audio input
1117
+ audio = [sample["array"] for sample in batch["audio"]]
1118
+ inputs = feature_extractor(audio, sampling_rate=sampling_rate)
1119
+ batch["input_features"] = inputs.input_features
1120
+ batch["input_length"] = [len(sample) for sample in audio]
1121
+
1122
+ # process text targets - for training these are the Whisper-generated pseudo-labels
1123
+ input_str_batched = batch[train_text_column_name]
1124
+ condition_on_prev_batched = batch.get("condition_on_prev", len(input_str_batched) * [None])
1125
+
1126
+ all_token_ids = []
1127
+ all_token_ids_unprompted = []
1128
+ for prev_ids, input_str in zip(condition_on_prev_batched, input_str_batched):
1129
+ token_ids = tokenizer(input_str, add_special_tokens=not use_pseudo_labels).input_ids
1130
+
1131
+ # check whether we have timestamps in the PLs and filter if required
1132
+ has_timestamps = len(set(token_ids) & set(timestamp_ids)) > 0
1133
+ if has_timestamps:
1134
+ # sample from binomial distribution to get probability of training on timestamps
1135
+ predict_timestamps = bool(np.random.binomial(1, timestamp_probability))
1136
+ if not predict_timestamps:
1137
+ # filter timestamps and insert the <|notimestamps|> task token
1138
+ token_ids = [token for token in token_ids if token < timestamp_begin]
1139
+ token_ids.insert(timestamp_position, timestamp_begin)
1140
+
1141
+ all_token_ids_unprompted.append(token_ids)
1142
+ # check whether to condition on previous text - we do this with probability condition_on_prev_probability
1143
+ condition_on_prev = bool(np.random.binomial(1, condition_on_prev_probability))
1144
+ if not condition_on_prev:
1145
+ prev_ids = None
1146
+ elif "condition_on_prev" not in batch and len(all_token_ids_unprompted) > 1:
1147
+ # prompt ids are the penultimate token ids in the batch
1148
+ prev_ids = all_token_ids_unprompted[-2]
1149
+
1150
+ if prev_ids is not None:
1151
+ if has_timestamps and not predict_timestamps:
1152
+ # filter timestamp ids from prompt when not predicting timestamps
1153
+ prev_ids = [token for token in prev_ids if token < timestamp_begin]
1154
+
1155
+ # check that the length of the prompt does not exceed half the max label length (224)
1156
+ if len(prev_ids) > prompt_cutoff_length:
1157
+ prev_ids = prev_ids[-prompt_cutoff_length + 1 :]
1158
+ prev_ids = [decoder_prev_token_id] + prev_ids
1159
+
1160
+ # and that the total length of the labels does not exceed the max label length (448)
1161
+ if len(prev_ids + token_ids) > max_label_length:
1162
+ trim_length = len(prev_ids + token_ids) - max_label_length + 1
1163
+ prev_ids = prev_ids[trim_length:]
1164
+ prev_ids = [decoder_prev_token_id] + prev_ids
1165
+
1166
+ token_ids = prev_ids + token_ids
1167
+
1168
+ all_token_ids.append(token_ids)
1169
+
1170
+ batch["labels"] = all_token_ids
1171
+ return batch
1172
+
1173
+ def prepare_eval_dataset(batch):
1174
+ # process audio input
1175
+ sample = batch["audio"]
1176
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1177
+ batch["input_features"] = inputs.input_features[0]
1178
+ batch["input_length"] = len(sample["array"])
1179
+
1180
+ # process targets - for evaluation these are the ground-truth transcriptions
1181
+ input_str = batch["text"]
1182
+ batch["labels"] = tokenizer(input_str).input_ids
1183
+ return batch
1184
+
1185
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1186
+ if training_args.do_train:
1187
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1188
+ # we can use `num_workers` (which is much faster)
1189
+ # We gate the pre-processing function accordingly
1190
+ map_fn_train = partial(
1191
+ raw_datasets["train"].map,
1192
+ function=prepare_train_dataset,
1193
+ remove_columns=raw_datasets_train_features,
1194
+ batched=True,
1195
+ batch_size=data_args.preprocessing_batch_size,
1196
+ )
1197
+ with accelerator.main_process_first():
1198
+ vectorized_datasets["train"] = (
1199
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1200
+ if not data_args.streaming
1201
+ else map_fn_train()
1202
+ )
1203
+ if training_args.do_eval:
1204
+ for eval_split in all_eval_splits:
1205
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1206
+ map_fn_eval = partial(
1207
+ raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1208
+ )
1209
+ with accelerator.main_process_first():
1210
+ vectorized_datasets[eval_split] = (
1211
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1212
+ if not data_args.streaming
1213
+ else map_fn_eval()
1214
+ )
1215
+
1216
+ # 10.5: Filter training data with inputs longer than `max_input_length`
1217
+ def is_audio_in_length_range(length):
1218
+ return min_input_length < length < max_input_length
1219
+
1220
+ filter_by_audio_fn = partial(
1221
+ vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1222
+ )
1223
+ with accelerator.main_process_first():
1224
+ vectorized_datasets = (
1225
+ filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1226
+ if not data_args.streaming
1227
+ else filter_by_audio_fn()
1228
+ )
1229
+
1230
+ # 10.6: Filter training data with labels longer than `max_label_length`
1231
+ def is_labels_in_length_range(labels):
1232
+ return 0 < len(labels) <= max_label_length
1233
+
1234
+ filter_by_labels_fn = partial(
1235
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1236
+ )
1237
+ with accelerator.main_process_first():
1238
+ vectorized_datasets = (
1239
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1240
+ if not data_args.streaming
1241
+ else filter_by_labels_fn()
1242
+ )
1243
+
1244
+ # Pre-processing complete!
1245
+ # For large datasets it is advised to run the preprocessing on a
1246
+ # single machine first with `--preprocessing_only` since there will most likely
1247
+ # be a timeout when running the script in distributed mode.
1248
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1249
+ # cached dataset
1250
+ if data_args.preprocessing_only:
1251
+ if data_args.streaming:
1252
+ raise ValueError(
1253
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1254
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1255
+ "on the fly with streaming mode."
1256
+ )
1257
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1258
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1259
+ return
1260
+
1261
+ # 11. Define Evaluation Metrics
1262
+ def compute_metrics(preds, labels):
1263
+ # replace padded labels by the padding token
1264
+ for idx in range(len(labels)):
1265
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1266
+
1267
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1268
+ # we do not want to group tokens when computing the metrics
1269
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1270
+ wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1271
+
1272
+ # normalize everything and re-compute the WER
1273
+ norm_pred_str = [normalizer(pred) for pred in pred_str]
1274
+ norm_label_str = [normalizer(label) for label in label_str]
1275
+ # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1276
+ pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1277
+ label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1278
+ # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1279
+ norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1280
+ norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1281
+
1282
+ wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1283
+ return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1284
+
1285
+ # 12. Define Training Schedule
1286
+ # Store some constants
1287
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1288
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1289
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1290
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1291
+
1292
+ if not data_args.streaming and training_args.max_steps < 0:
1293
+ num_epochs = int(training_args.num_train_epochs)
1294
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1295
+ total_train_steps = steps_per_epoch * num_epochs
1296
+ elif training_args.max_steps > 0:
1297
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1298
+ total_train_steps = int(training_args.max_steps)
1299
+ if not data_args.streaming:
1300
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1301
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1302
+ else:
1303
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1304
+ num_epochs = sys.maxsize
1305
+ steps_per_epoch = total_train_steps
1306
+ else:
1307
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1308
+
1309
+ if training_args.eval_steps is None:
1310
+ logger.info(
1311
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1312
+ )
1313
+ eval_steps = steps_per_epoch
1314
+ else:
1315
+ eval_steps = training_args.eval_steps
1316
+
1317
+ # 13. Define optimizer, LR scheduler, collator
1318
+ decay_parameters = get_parameter_names(
1319
+ student_model,
1320
+ [nn.LayerNorm],
1321
+ forbidden_module=[student_model.model.encoder] if training_args.freeze_encoder else None,
1322
+ )
1323
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1324
+ optimizer_grouped_parameters = [
1325
+ {
1326
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1327
+ "weight_decay": training_args.weight_decay,
1328
+ },
1329
+ {
1330
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1331
+ "weight_decay": 0.0,
1332
+ },
1333
+ ]
1334
+ optimizer = torch.optim.AdamW(
1335
+ params=optimizer_grouped_parameters,
1336
+ lr=training_args.learning_rate,
1337
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1338
+ eps=training_args.adam_epsilon,
1339
+ )
1340
+
1341
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1342
+ lr_scheduler = get_scheduler(
1343
+ name=training_args.lr_scheduler_type,
1344
+ optimizer=optimizer,
1345
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1346
+ num_training_steps=total_train_steps * accelerator.num_processes,
1347
+ )
1348
+
1349
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1350
+ processor=processor,
1351
+ decoder_start_token_id=decoder_start_token_id,
1352
+ decoder_prev_token_id=decoder_prev_token_id,
1353
+ input_padding="longest",
1354
+ target_padding="max_length",
1355
+ max_target_length=max_label_length,
1356
+ )
1357
+
1358
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1359
+ # so that we can still access the configs
1360
+ num_beams = (
1361
+ training_args.generation_num_beams
1362
+ if training_args.generation_num_beams is not None
1363
+ else getattr(student_model.generation_config, "num_beams", 1)
1364
+ )
1365
+
1366
+ gen_kwargs = {
1367
+ "max_length": max_label_length,
1368
+ "num_beams": num_beams,
1369
+ "return_timestamps": return_timestamps,
1370
+ }
1371
+ if is_multilingual:
1372
+ # forcing the language and task tokens helps multilingual models in their generations
1373
+ gen_kwargs.update(
1374
+ {
1375
+ "language": data_args.language,
1376
+ "task": data_args.task,
1377
+ }
1378
+ )
1379
+
1380
+ # 15. Prepare everything with accelerate
1381
+ student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1382
+ student_model, teacher_model, optimizer, lr_scheduler
1383
+ )
1384
+
1385
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1386
+ kl_loss = nn.KLDivLoss(reduction="none")
1387
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1388
+ # ignore padded tokens from divergence, i.e. where labels are not set to -100
1389
+ padding_mask = labels >= 0
1390
+ padding_mask = padding_mask.unsqueeze(-1)
1391
+ divergence = divergence * padding_mask
1392
+ # take the average over the mini-batch
1393
+ divergence = divergence.sum() / padding_mask.sum()
1394
+ return divergence
1395
+
1396
+ # Define gradient update step fn
1397
+ def train_step(
1398
+ batch,
1399
+ temperature=2.0,
1400
+ ):
1401
+ student_model.train()
1402
+ teacher_model.eval()
1403
+
1404
+ student_outputs = student_model(**batch)
1405
+ with torch.no_grad():
1406
+ if share_hidden_states:
1407
+ # if the student and teacher share the same frozen encoder then we don't have to recompute the
1408
+ # encoder hidden-states for the teacher model, we can just re-use them from the student
1409
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1410
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1411
+ else:
1412
+ # do the full forward pass for the teacher model (encoder + decoder)
1413
+ teacher_outputs = teacher_model(**batch)
1414
+
1415
+ # CE (data) loss
1416
+ ce_loss = student_outputs.loss
1417
+ # rescale distribution by temperature to ensure gradients scale correctly
1418
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1419
+ # log softmax of student predictions for numerical stability
1420
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1421
+ # KL-divergence loss (scaled by temperature)
1422
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1423
+
1424
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1425
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1426
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1427
+ return loss, metrics
1428
+
1429
+ # Define eval fn
1430
+ def eval_step(batch):
1431
+ student_model.eval()
1432
+ teacher_model.eval()
1433
+
1434
+ with torch.no_grad():
1435
+ student_outputs = student_model(**batch)
1436
+ if share_hidden_states:
1437
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1438
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1439
+ else:
1440
+ teacher_outputs = teacher_model(**batch)
1441
+
1442
+ # CE (data) loss
1443
+ ce_loss = student_outputs.loss
1444
+
1445
+ # log softmax / softmax for numerical stability
1446
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1447
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1448
+ # temperature is always 1 for eval
1449
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1450
+
1451
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1452
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1453
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1454
+ return metrics
1455
+
1456
+ def generate_step(batch):
1457
+ student_model.eval()
1458
+ output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1459
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1460
+ return output_ids
1461
+
1462
+ logger.info("***** Running training *****")
1463
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1464
+ if not data_args.streaming:
1465
+ logger.info(f" Num epochs = {num_epochs}")
1466
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1467
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1468
+ logger.info(
1469
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1470
+ )
1471
+ logger.info(f" Total optimization steps = {total_train_steps}")
1472
+
1473
+ # ======================== Training ================================
1474
+ train_time = 0
1475
+ train_start = time.time()
1476
+ steps_trained_progress_bar = tqdm(
1477
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1478
+ )
1479
+ continue_training = True
1480
+ epochs_trained = 0
1481
+ cur_step = 0
1482
+
1483
+ checkpoint = None
1484
+ if training_args.resume_from_checkpoint is not None:
1485
+ checkpoint = training_args.resume_from_checkpoint
1486
+ elif last_checkpoint is not None:
1487
+ checkpoint = last_checkpoint
1488
+
1489
+ if checkpoint is not None:
1490
+ accelerator.load_state(checkpoint)
1491
+ # Find num steps and epoch from saved state string pattern
1492
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1493
+ match = re.search(pattern, checkpoint)
1494
+ cur_step = int(match.group(1))
1495
+ epochs_trained = int(match.group(2))
1496
+
1497
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1498
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1499
+ logger.info(f" Continuing training from global step {cur_step}")
1500
+
1501
+ steps_trained_progress_bar.update(cur_step)
1502
+
1503
+ for epoch in range(0, epochs_trained):
1504
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1505
+
1506
+ if not data_args.streaming and training_args.max_steps < 0:
1507
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1508
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1509
+ else:
1510
+ # Currently we don't know how many steps we've taken in the current epoch
1511
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1512
+ # This is "good enough" for our purposes but not fully correct
1513
+ resume_step = None
1514
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1515
+ else:
1516
+ resume_step = None
1517
+
1518
+ for epoch in range(epochs_trained, num_epochs):
1519
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1520
+ train_dataloader = DataLoader(
1521
+ vectorized_datasets["train"],
1522
+ collate_fn=data_collator,
1523
+ batch_size=per_device_train_batch_size,
1524
+ num_workers=dataloader_num_workers,
1525
+ prefetch_factor=prefetch_factor,
1526
+ pin_memory=training_args.dataloader_pin_memory,
1527
+ )
1528
+ train_dataloader = accelerator.prepare(train_dataloader)
1529
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1530
+ train_dataloader.dataset.set_epoch(epoch)
1531
+
1532
+ if resume_step is not None:
1533
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1534
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1535
+ resume_step = None
1536
+
1537
+ for batch in train_dataloader:
1538
+ with accelerator.accumulate(student_model):
1539
+ loss, train_metric = train_step(batch, temperature=training_args.temperature)
1540
+ accelerator.backward(loss)
1541
+ if accelerator.sync_gradients:
1542
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1543
+ optimizer.step()
1544
+ lr_scheduler.step()
1545
+ optimizer.zero_grad()
1546
+
1547
+ # Check if the accelerator has performed an optimization step behind the scenes
1548
+ if accelerator.sync_gradients:
1549
+ steps_trained_progress_bar.update(1)
1550
+ cur_step += 1
1551
+
1552
+ if cur_step % training_args.logging_steps == 0:
1553
+ steps_trained_progress_bar.write(
1554
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1555
+ f" {train_metric['loss']}, Learning Rate:"
1556
+ f" {lr_scheduler.get_last_lr()[0]})"
1557
+ )
1558
+ log_metric(
1559
+ accelerator,
1560
+ metrics=train_metric,
1561
+ learning_rate=lr_scheduler.get_last_lr()[0],
1562
+ train_time=train_time + time.time() - train_start,
1563
+ step=cur_step,
1564
+ epoch=epoch,
1565
+ prefix="train",
1566
+ )
1567
+
1568
+ # save checkpoint and weights after each save_steps and at the end of training
1569
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1570
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1571
+ accelerator.save_state(output_dir=intermediate_dir)
1572
+ accelerator.wait_for_everyone()
1573
+ if accelerator.is_main_process:
1574
+ rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1575
+
1576
+ if training_args.push_to_hub:
1577
+ upload_folder(
1578
+ folder_path=training_args.output_dir,
1579
+ repo_id=repo_name,
1580
+ repo_type="model",
1581
+ commit_message=f"Saving train state of step {cur_step}",
1582
+ )
1583
+
1584
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1585
+ train_time += time.time() - train_start
1586
+ student_model.eval()
1587
+ # ======================== Evaluating ==============================
1588
+ for eval_split in all_eval_splits:
1589
+ eval_metrics = []
1590
+ eval_preds = []
1591
+ eval_labels = []
1592
+ eval_start = time.time()
1593
+
1594
+ validation_dataloader = DataLoader(
1595
+ vectorized_datasets[eval_split],
1596
+ collate_fn=data_collator,
1597
+ batch_size=per_device_eval_batch_size,
1598
+ drop_last=False,
1599
+ num_workers=dataloader_num_workers,
1600
+ prefetch_factor=prefetch_factor,
1601
+ pin_memory=training_args.dataloader_pin_memory,
1602
+ )
1603
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1604
+
1605
+ for batch in tqdm(
1606
+ validation_dataloader,
1607
+ desc=f"Evaluating {eval_split}...",
1608
+ position=2,
1609
+ disable=not accelerator.is_local_main_process,
1610
+ ):
1611
+ # Model forward
1612
+ eval_metric = eval_step(batch)
1613
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1614
+ eval_metrics.append(eval_metric)
1615
+
1616
+ # generation
1617
+ if training_args.predict_with_generate:
1618
+ generated_ids = generate_step(batch)
1619
+ # Gather all predictions and targets
1620
+ generated_ids, labels = accelerator.gather_for_metrics(
1621
+ (generated_ids, batch["labels"])
1622
+ )
1623
+ eval_preds.extend(generated_ids)
1624
+ eval_labels.extend(labels)
1625
+
1626
+ eval_time = time.time() - eval_start
1627
+ # normalize eval metrics
1628
+ eval_metrics = {
1629
+ key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1630
+ }
1631
+
1632
+ # compute WER metric
1633
+ wer_desc = ""
1634
+ if training_args.predict_with_generate:
1635
+ wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1636
+ eval_preds, eval_labels
1637
+ )
1638
+ eval_metrics.update(wer_metric)
1639
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1640
+ log_pred(
1641
+ accelerator,
1642
+ pred_str,
1643
+ label_str,
1644
+ norm_pred_str,
1645
+ norm_label_str,
1646
+ step=cur_step,
1647
+ prefix=eval_split,
1648
+ )
1649
+
1650
+ # Print metrics and update progress bar
1651
+ steps_trained_progress_bar.write(
1652
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1653
+ f" {wer_desc})"
1654
+ )
1655
+
1656
+ log_metric(
1657
+ accelerator,
1658
+ metrics=eval_metrics,
1659
+ train_time=eval_time,
1660
+ step=cur_step,
1661
+ epoch=epoch,
1662
+ prefix=eval_split,
1663
+ )
1664
+
1665
+ # flush the train metrics
1666
+ train_start = time.time()
1667
+
1668
+ # break condition
1669
+ if cur_step == total_train_steps:
1670
+
1671
+ # un-wrap student model for save
1672
+ student_model = accelerator.unwrap_model(student_model)
1673
+ student_model.save_pretrained(training_args.output_dir)
1674
+
1675
+ if training_args.push_to_hub:
1676
+ upload_folder(
1677
+ folder_path=training_args.output_dir,
1678
+ repo_id=repo_name,
1679
+ repo_type="model",
1680
+ commit_message=f"Saving final weights of step {cur_step}",
1681
+ )
1682
+
1683
+ continue_training = False
1684
+ break
1685
+
1686
+ if not continue_training:
1687
+ break
1688
+
1689
+ accelerator.end_training()
1690
+
1691
+
1692
+ if __name__ == "__main__":
1693
+ main()
.ipynb_checkpoints/setup-checkpoint.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+
18
+ import setuptools
19
+
20
+ _deps = [
21
+ "torch>=1.10",
22
+ "transformers>=4.35.1",
23
+ "datasets[audio]>=2.14.7",
24
+ "accelerate>=0.24.1",
25
+ "jiwer",
26
+ "evaluate>=0.4.1",
27
+ "wandb",
28
+ "tensorboard",
29
+ "nltk",
30
+ ]
31
+
32
+ _extras_dev_deps = [
33
+ "ruff==0.1.5",
34
+ ]
35
+
36
+ here = os.path.abspath(os.path.dirname(__file__))
37
+
38
+ with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
39
+ long_description = f.read()
40
+
41
+ setuptools.setup(
42
+ name="distil_whisper",
43
+ description="Toolkit for distilling OpenAI's Whisper model.",
44
+ long_description=long_description,
45
+ long_description_content_type="text/markdown",
46
+ packages=setuptools.find_packages(),
47
+ install_requires=_deps,
48
+ extras_require={
49
+ "dev": [_extras_dev_deps],
50
+ },
51
+ )
52
+
Makefile ADDED
@@ -0,0 +1,9 @@
1
+ check_dirs := .
2
+
3
+ quality:
4
+ black --check $(check_dirs)
5
+ ruff $(check_dirs)
6
+
7
+ style:
8
+ black $(check_dirs)
9
+ ruff $(check_dirs) --fix
README.md ADDED
@@ -0,0 +1,563 @@
1
+ ## Training Distil-Whisper
2
+
3
+ This sub-folder contains all the scripts required to train a Distil-Whisper model in your choice of language. They are
4
+ slightly modified from the original scripts used to distill Whisper for English ASR (as per the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
5
+ The main difference is that these scripts are written in [PyTorch](https://pytorch.org), whereas the original scripts
6
+ are in [JAX](https://jax.readthedocs.io/en/latest/#)/[Flax](https://flax.readthedocs.io/en/latest/). These scripts are
7
+ also made to be easier to run end-to-end, whereas the original scripts require more steps and are somewhat hard-coded
8
+ for English ASR. Both sets of scripts achieve equivalent downstream results when the hyper-parameters are set equal.
9
+
10
+ If you are interested in reproducing the original Distil-Whisper checkpoints, we refer you to the sub-folder [Flax Training](./flax/README.md).
11
+ Otherwise, if you wish to distill Whisper on your own language/dataset, we recommend you use these scripts for ease of use
12
+ and the configurability they provide.
13
+
14
+ Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
15
+
16
+ 1. [Pseudo-labelling](#1-pseudo-labelling)
17
+ 2. [Initialisation](#2-initialisation)
18
+ 3. [Training](#3-training)
19
+ 4. [Evaluation](#4-evaluation)
20
+
21
+ This README is partitioned according to the four stages. Each section provides a minimal example for running the
22
+ scripts used in the project. We will use a running example of distilling the Whisper model for Hindi speech recognition
23
+ on the Common Voice dataset. Note that this dataset only contains ~20 hours of audio data. Thus, the example can be run extremely
24
+ quickly, but does not provide sufficient data to achieve optimal performance. We recommend training on upwards of 1000
25
+ hours of data should you want to match the performance of Whisper on high-resource languages.
26
+
27
+ ## Requirements
28
+
29
+ The Distil-Whisper training code is written in [PyTorch](https://pytorch.org) and [Accelerate](https://huggingface.co/docs/accelerate/index).
30
+ It heavily leverages the Whisper implementation in [🤗 Transformers](https://github.com/huggingface/transformers) for both
31
+ training and inference.
32
+
33
+ The instructions for installing the package are as follows:
34
+ 1. Install PyTorch from the [official instructions](https://pytorch.org/get-started/locally/), ensuring you install the correct version for your hardware and CUDA version.
35
+ 2. Fork the `distil-whisper` repository by clicking on the [fork](https://github.com/huggingface/distil-whisper/fork) button on the repository's page
36
+ 3. Clone the `distil-whisper` repository and add the base repository as a remote. This will allow you to "pull" any upstream changes that are made to the base repository:
37
+
38
+ ```bash
39
+ git clone https://github.com/<your GitHub handle>/distil-whisper.git
40
+ cd distil-whisper
41
+ git remote add upstream https://github.com/huggingface/distil-whisper.git
42
+ ```
43
+ 4. pip install the required packages from the [setup.py](./setup.py) file:
44
+ ```bash
45
+ cd training
46
+ pip install -e .
47
+ cd ../..
48
+ ```
49
+
50
+ 5. Configure Accelerate by running the following command. Note that you should set the number of GPUs you wish to use for distillation, and also the data type (dtype) to your preferred dtype for training/inference (e.g. `bfloat16` on A100 GPUs, `float16` on V100 GPUs, etc.):
51
+
52
+ ```bash
53
+ accelerate config
54
+ ```
55
+
56
+ 6. The last thing we need to do is link our Hugging Face account so that we can pull/push model repositories on the Hub. This will allow us to save our final distilled weights on the Hub so that we can share them with the community. Run the command:
57
+
58
+ ```bash
59
+ git config --global credential.helper store
60
+ huggingface-cli login
61
+ ```
62
+ And then enter an authentication token from https://huggingface.co/settings/tokens. Create a new token if you do not have one already. You should make sure that this token has "write" privileges.
63
+
64
+ To confirm that you have a working environment, first accept the terms of use of the Common Voice 16.1 dataset on the Hub: https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
65
+
66
+ You can run the following code cell to stream one sample of data from the Common Voice dataset, and check that you can
67
+ perform inference using the "tiny" Whisper model:
68
+
69
+ ```python
70
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
71
+ from datasets import load_dataset, Audio
72
+
73
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", low_cpu_mem_usage=True)
74
+ processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
75
+
76
+ model.to("cuda")
77
+
78
+ common_voice = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="validation", streaming=True)
79
+ common_voice = common_voice.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
80
+
81
+ inputs = processor(next(iter(common_voice))["audio"]["array"], sampling_rate=16000, return_tensors="pt")
82
+ input_features = inputs.input_features
83
+
84
+ generated_ids = model.generate(input_features.to("cuda"), max_new_tokens=128)
85
+ pred_text = processor.decode(generated_ids[0], skip_special_tokens=True)
86
+
87
+ print("Pred text:", pred_text)
88
+ print("Environment set up successful?", generated_ids.shape[-1] == 20)
89
+ ```
90
+
91
+ ## 1. Pseudo-Labelling
92
+
93
+ The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used
94
+ to generate pseudo-labels under a range of settings, including using both greedy and beam-search. It is also compatible
95
+ with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio
96
+ datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the
97
+ blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet).
98
+
99
+ > As of the latest Distil-Whisper release, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3), this
100
+ pseudo-labelling script also performs the added operation of concatenating (or packing) the audio inputs to 30-seconds.
101
+ Not only does this lead to a WER improvement when using the sequential long-form decoding algorithm, but concatenating audios
102
+ to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised.
103
+
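+ To make the packing operation concrete, here is a minimal sketch of how consecutive audio samples can be
+ concatenated into chunks of at most 30 seconds. It is an illustration of the idea only, not the exact implementation
+ used by the pseudo-labelling script:
+
+ ```python
+ import numpy as np
+
+ def pack_audios(audio_arrays, sampling_rate=16000, max_duration=30.0):
+     """Greedily concatenate consecutive audio arrays into chunks of at most 30 seconds."""
+     max_samples = int(max_duration * sampling_rate)
+     packed, current, current_len = [], [], 0
+     for audio in audio_arrays:
+         if current and current_len + len(audio) > max_samples:
+             packed.append(np.concatenate(current))
+             current, current_len = [], 0
+         current.append(audio)
+         current_len += len(audio)
+     if current:
+         packed.append(np.concatenate(current))
+     return packed
+
+ # three 12-second dummy samples are packed into a 24-second chunk and a 12-second chunk
+ dummy = [np.zeros(12 * 16000) for _ in range(3)]
+ print([len(chunk) / 16000 for chunk in pack_audios(dummy)])  # [24.0, 12.0]
+ ```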
104
+ The following script demonstrates how to pseudo-label the Hindi split of the Common Voice 16.1 dataset with greedy sampling:
105
+
106
+ ```bash
107
+ #!/usr/bin/env bash
108
+
109
+ accelerate launch run_pseudo_labelling.py \
110
+ --model_name_or_path "openai/whisper-large-v3" \
111
+ --dataset_name "mozilla-foundation/common_voice_16_1" \
112
+ --dataset_config_name "hi" \
113
+ --dataset_split_name "train+validation+test" \
114
+ --text_column_name "sentence" \
115
+ --id_column_name "path" \
116
+ --output_dir "./common_voice_16_1_hi_pseudo_labelled" \
117
+ --wandb_project "distil-whisper-labelling" \
118
+ --per_device_eval_batch_size 64 \
119
+ --dtype "bfloat16" \
120
+ --attn_implementation "sdpa" \
121
+ --logging_steps 500 \
122
+ --max_label_length 256 \
123
+ --concatenate_audio \
124
+ --preprocessing_batch_size 500 \
125
+ --preprocessing_num_workers 8 \
126
+ --dataloader_num_workers 8 \
127
+ --report_to "wandb" \
128
+ --language "hi" \
129
+ --task "transcribe" \
130
+ --return_timestamps \
131
+ --streaming False \
132
+ --generation_num_beams 1 \
133
+ --push_to_hub
134
+ ```
135
+
136
+ On an 80 GB A100 GPU, the above script takes approximately 5 minutes to concatenate and pre-process the 20 hours of
137
+ audio data, and a further 10 minutes to transcribe the pseudo-labels. The pseudo-labelled dataset corresponding to this
138
+ script is available on the Hugging Face Hub under [sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled](https://huggingface.co/datasets/sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled).
139
+ The WER of the pre-trained Whisper large-v3 model is 17.2% on the test split. We will compare the performance of our distilled model against this number.
140
+
141
+ There are three noteworthy arguments that configure the dataset concatenation (or packing) process:
142
+ 1. `concatenate_audio`: whether or not to concatenate (or pack) the audios to 30-second chunks. The latest Distil-Whisper model, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3#differences-with-distil-large-v2), highlights the WER improvements obtained using the sequential long-form decoding algorithm when concatenated audios are used. Concatenating audios to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised. Hence, it is highly recommended to set `--concatenate_audio=True`.
143
+ 2. `preprocessing_batch_size`: the batch size to use when concatenating (or packing) the audios. Using a larger batch size results in a greater portion of audio samples being packed to 30-seconds, at the expense of higher memory consumption. If you exceed your system's RAM when performing the concatenation operation, reduce the `preprocessing_batch_size` by a factor of 2 to 250 or even 125.
144
+ 3. `preprocessing_num_workers`: the number of multiprocessing workers to use when concatenating the audios. Using more workers will result in faster pre-processing, at the expense of higher memory consumption. Ensure you do not exceed the maximum number of CPUs on your device.
145
+
146
+ In addition, the following arguments configure the inference of the Whisper model:
147
+ 1. `language`: explicitly setting the language token during inference substantially improves the generation performance of the Whisper model, since the model is forced always to predict in the given language. We recommend you set the language to the language you wish to distil the Whisper model on. The only exception is when distilling an English-only model (i.e. where the model id is appended with an `.en`, e.g. `small.en`), the language argument should be set to None, since there is no language token used during training/inference.
148
+ 2. `return_timestamps`: whether or not to predict timestamps in the pseudo-labels. Timestamp prediction is required should you want your distilled model to be able to predict timestamps at inference time (e.g. for the original OpenAI long-form transcription algorithm). However, timestamped pseudo-labels are marginally less accurate than those without timestamps. We recommend pseudo-labelling **with** timestamps to ensure the distilled model is as general as possible.
149
+ 3. `attn_implementation`: which attention implementation to use for inference. Set to `sdpa` for [PyTorch SDPA](https://huggingface.co/docs/transformers/v4.35.2/en/perf_infer_gpu_one#bettertransformer), or `flash_attention_2` if your hardware supports Flash Attention 2 and you have the [package installed](https://github.com/Dao-AILab/flash-attention).
150
+ 4. `streaming`: whether or not to use Datasets' streaming mode. If enabled, the audio data will be streamed from the Hugging Face Hub with no disk space requirements. However, the user is then responsible for adding the pseudo-labels to the dataset script in a follow-up step (see [Using Streaming Mode](#TODO)). If set to `False`, the audio data will be downloaded and pre-processed offline. At the end of pseudo-labelling, the pseudo-labels will be automatically appended to the original dataset, meaning the dataset is ready to be used for the subsequent training step without any additional steps.
151
+ 5. `generation_num_beams`: how many beams to use while decoding. In practice, we found the distilled model to perform comparably when the data was pseudo-labelled with `generation_num_beams=1` (greedy) or `generation_num_beams>1` (beam). This is likely because the WER filter compensates for the lower quality pseudo-labels obtained using greedy search. However, using `generation_num_beams=1` gives substantially faster inference time for the pseudo-labelling step, and so we recommend this configuration.
152
+
153
+ Should you have your own audio dataset, you can first [convert it](https://huggingface.co/docs/datasets/audio_dataset) to
154
+ Hugging Face Datasets format and push it to the Hugging Face Hub. You can then pseudo-label it using the script above,
155
+ replacing the `--dataset_name` with the name of your dataset on the Hub.
156
+
157
+ Otherwise, you may wish to use an open-source dataset already available on the Hugging Face Hub. We provide a summary of
158
+ the three most popular multilingual datasets in the table below. For more details, refer to the blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#multilingual-speech-recognition).
159
+
160
+ | Dataset | Languages | Domain | Speaking Style | License | Text Column | ID Column |
161
+ |-----------------------------------------------------------------------------------------------|-----------|---------------------------------------|----------------|-----------|---------------------|--------------|
162
+ | [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) | 6 | Audiobooks | Narrated | CC-BY-4.0 | `"text"` | `"id"` |
163
+ | [Common Voice 16](https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1) | 120 | Wikipedia text & crowd-sourced speech | Narrated | CC0-1.0 | `"sentence"` | `"path"` |
164
+ | [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) | 15 | European Parliament recordings | Spontaneous | CC0 | `"normalized_text"` | `"audio_id"` |
165
+
166
+ To achieve *robustness* to different distributions of audio data, it is recommended to train on multiple datasets where possible.
167
+ For example, the above three datasets all have splits for the German language. Thus, if distilling a Whisper model for German,
168
+ it would be wise to use a combination of the three datasets during training, in order to cover at least three distinct domains
169
+ (audiobooks, crowd-sourced speech, parliament recordings). You may wish to use a combination of open-source datasets, or
170
+ a combination of open-source and individually owned datasets to cover multiple distributions and domains.
171
+
172
+ ## 2. Initialisation
173
+
174
+ The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model
175
+ from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is
176
+ initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002)
177
+ recommendations.
178
+
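+ As a point of reference, the "maximally spaced" rule can be sketched in a couple of lines of NumPy. This is an
+ illustration of the layer-selection strategy, not the initialisation script itself:
+
+ ```python
+ import numpy as np
+
+ # pick `student_layers` decoder layers evenly spaced across the teacher's decoder
+ teacher_layers, student_layers = 32, 2
+ layer_ids = np.linspace(0, teacher_layers - 1, student_layers).round().astype(int)
+ print(layer_ids + 1)  # [ 1 32] -> copy teacher decoder layers 1 and 32 (1-indexed)
+ ```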
179
+ First, we need to create a model repository on the Hugging Face Hub. This repository will contain all the required files
180
+ to reproduce the training run, alongside model weights, training logs and a README.md card. You can either create a model
181
+ repository directly on the Hugging Face Hub using the link: https://huggingface.co/new, or via the CLI, as we'll show here.
182
+
183
+ Let's pick a name for our distilled model: `distil-whisper-large-v3-hi`. We can run the following command to create a repository under this name:
184
+
185
+ ```bash
186
+ huggingface-cli repo create distil-whisper-large-v3-hi
187
+ ```
188
+
189
+ We can now see the model on the Hub, e.g. under https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
190
+
191
+ Let's clone the repository so that we can place our training script and model weights inside:
192
+
193
+ ```bash
194
+ git lfs install
195
+ git clone https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
196
+ ```
197
+
198
+ Be sure to change the repo address to `https://huggingface.co/<your-user-name>/<your-repo-name>`
199
+
200
+ We can now copy the relevant training scripts to the repository:
201
+ ```bash
202
+ cd distil-whisper-large-v3-hi
203
+
204
+ cp ../distil-whisper/training/create_student_model.py .
205
+ cp ../distil-whisper/training/run_distillation.py .
206
+ ```
207
+
208
+ The following command demonstrates how to initialise a student model from the Whisper [large-v3](https://huggingface.co/openai/whisper-large-v3)
209
+ checkpoint, with all 32 encoder layer and 2 decoder layers. The 2 student decoder layers are copied from teacher layers
210
+ 1 and 32 respectively, as the maximally spaced layers:
211
+
212
+ ```bash
213
+ #!/usr/bin/env bash
214
+
215
+ python create_student_model.py \
216
+ --teacher_checkpoint "openai/whisper-large-v3" \
217
+ --encoder_layers 32 \
218
+ --decoder_layers 2 \
219
+ --save_dir "./distil-large-v3-init"
220
+ ```
221
+
222
+ The initialised model will be saved to the sub-directory `distil-large-v3-init` in our model repository.
223
+
224
+ ## 3. Training
225
+
226
+ The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple
227
+ datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation
228
+ from the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), which is a weighted sum of the cross-entropy and
229
+ KL-divergence loss terms.
230
+
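+ For reference, the objective can be sketched as follows. This is a simplified illustration of the loss computed in
+ [`run_distillation.py`](run_distillation.py), assuming the student/teacher logits, the padded labels and the
+ cross-entropy loss have already been computed:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def distillation_loss(student_logits, teacher_logits, labels, ce_loss, kl_weight=1.0, temperature=2.0):
+     # soften both distributions with the temperature
+     teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
+     student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
+     # token-level KL divergence, masked over padded positions (labels == -100)
+     kl = F.kl_div(student_log_probs, teacher_probs, reduction="none")
+     mask = (labels >= 0).unsqueeze(-1)
+     kl = (kl * mask).sum() / mask.sum()
+     # weighted sum: fixed CE weight of 0.8, tunable KL weight, temperature**2 rescaling
+     return 0.8 * ce_loss + kl_weight * kl * temperature**2
+
+ # toy usage with random tensors
+ student_logits, teacher_logits = torch.randn(2, 4, 10), torch.randn(2, 4, 10)
+ labels = torch.tensor([[5, 2, -100, -100], [1, 3, 4, -100]])
+ print(distillation_loss(student_logits, teacher_logits, labels, ce_loss=torch.tensor(1.0)))
+ ```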
231
+ The following command takes the Common Voice dataset that was pseudo-labelled in the first stage and trains the
232
+ 2-layer decoder model initialised in the previous step. We pass the local path to the pseudo-labelled Common Voice dataset
233
+ (`../common_voice_16_1_hi_pseudo_labelled`), which you can change to the path where your local pseudo-labelled dataset is
234
+ saved.
235
+
236
+ In this example, we will combine the train and validation splits to give our training set, and evaluate on the test split
237
+ only. This is purely to demonstrate how to combine multiple pseudo-labelled datasets for training, rather than recommended
238
+ advice for defining train/validation splits. We advise that you train on the train splits of your dataset, evaluate and
239
+ tune hyper-parameters on the validation split, and only test the final checkpoint on the test split. Note how multiple
240
+ training datasets and splits can be loaded by separating the dataset arguments by `+` symbols. Thus, the script generalises
241
+ to any number of training datasets.
242
+
243
+ ```bash
244
+ #!/usr/bin/env bash
245
+
246
+ accelerate launch run_distillation.py \
247
+ --model_name_or_path "./distil-large-v3-init" \
248
+ --teacher_model_name_or_path "openai/whisper-large-v3" \
249
+ --train_dataset_name "../common_voice_16_1_hi_pseudo_labelled+../common_voice_16_1_hi_pseudo_labelled" \
250
+ --train_split_name "train+validation" \
251
+ --text_column_name "sentence+sentence" \
252
+ --train_dataset_samples "7+4" \
253
+ --eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
254
+ --eval_split_name "test" \
255
+ --eval_text_column_name "sentence" \
256
+ --eval_steps 1000 \
257
+ --save_steps 1000 \
258
+ --warmup_steps 50 \
259
+ --learning_rate 0.0001 \
260
+ --lr_scheduler_type "constant_with_warmup" \
261
+ --timestamp_probability 0.2 \
262
+ --condition_on_prev_probability 0.2 \
263
+ --language "hi" \
264
+ --task "transcribe" \
265
+ --logging_steps 25 \
266
+ --save_total_limit 1 \
267
+ --max_steps 5000 \
268
+ --wer_threshold 20 \
269
+ --per_device_train_batch_size 32 \
270
+ --per_device_eval_batch_size 32 \
271
+ --dataloader_num_workers 8 \
272
+ --preprocessing_num_workers 8 \
273
+ --ddp_timeout 7200 \
274
+ --dtype "bfloat16" \
275
+ --attn_implementation "sdpa" \
276
+ --output_dir "./" \
277
+ --do_train \
278
+ --do_eval \
279
+ --gradient_checkpointing \
280
+ --overwrite_output_dir \
281
+ --predict_with_generate \
282
+ --freeze_encoder \
283
+ --freeze_embed_positions \
284
+ --streaming False \
285
+ --push_to_hub
286
+
287
+ ```
288
+
289
+ The above training script will take approximately 3 hours to complete on an 80 GB A100 GPU and yield a final WER of 76%.
290
+ While the generations are starting to take form, there is still a 59% WER gap to the teacher model. This is hardly
291
+ surprising given we only have 15 hours of unfiltered data, and closer to just 1.5 hours with data filtering.
292
+ As mentioned above, using upwards of 1000 hours of data and training for 10k steps will likely yield
293
+ more competitive performance. For the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), we trained on 21k hours
294
+ of audio data for 80k steps. We found that upwards of 13k hours of audio data was required to reach convergence on English
295
+ ASR (see Section 9.2 of the [paper](https://arxiv.org/abs/2311.00430)), so the more data you have, the better!
296
+
297
+ Scaling to multiple GPUs using [distributed data parallelism (DDP)](https://pytorch.org/tutorials/beginner/ddp_series_theory.html)
298
+ is trivial: simply run `accelerate config` and select the multi-GPU option, specifying the IDs of the GPUs you wish to use. The
299
+ above script can then be run using DDP with no code changes.
300
+
301
+ Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a
302
+ saved checkpoint pushed to the Hugging Face Hub can be found here: [sanchit-gandhi/distil-whisper-large-v3-hi](https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi).
303
+
304
+ There are a few noteworthy data arguments:
305
+ 1. `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split; however, this might require downloading the dataset offline to compute these statistics.
306
+ 2. `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong. In our English distillation experiments, we found a WER threshold of 10% provides the optimal trade-off between ensuring high-quality transcriptions, and not filtering unnecessary amounts of training data. For multilingual distillation, the threshold should be set in accordance with the WER achieved by the pre-trained model on the test set. A minimal sketch of this filtering step is shown after this list.
307
+ 3. `streaming`: whether or not to use Datasets' streaming mode. Recommended for large datasets, where the audio data can be streamed from the Hugging Face Hub with no disk space requirements.
308
+ 4. `timestamp_probability`: the per-sample probability for retaining timestamp tokens in the labels (should they contain them). Retaining some portion of timestamp tokens in the training data is required to ensure the distilled model can predict timestamps at inference time. In our experiments, we found that training on timestamps with high probability hurts the distilled model's transcription performance. Thus, we recommend setting this to a value below 0.5. Typically, a value of 0.2 works well, giving good transcription and timestamp performance.
309
+ 5. `condition_on_prev_probability`: the per-sample probability for conditioning on previous labels. Conditioning on previous tokens is required to ensure the distilled model can be used with the "sequential" long-form transcription algorithm at inference time. We did not experiment with this parameter, but found values around 0.2 to provide adequate performance. OpenAI pre-trained Whisper with a 50% probability of conditioning on previous tokens. Thus, you might wish to try higher values.
310
+
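+ A minimal sketch of the WER-based filtering step (point 2 above) is given below. It is illustrative rather than the
+ exact implementation in the training script; the basic text normaliser and the 20% threshold are assumptions you
+ would match to your own setup:
+
+ ```python
+ import evaluate
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+
+ wer_metric = evaluate.load("wer")
+ normalizer = BasicTextNormalizer()
+
+ def is_wer_in_range(ground_truth, whisper_transcript, wer_threshold=20.0):
+     """Keep a sample only if the pseudo-label is close enough to the ground truth."""
+     norm_ground_truth = normalizer(ground_truth)
+     norm_whisper = normalizer(whisper_transcript)
+     if len(norm_ground_truth) == 0 or len(norm_whisper) == 0:
+         return False
+     wer = 100 * wer_metric.compute(predictions=[norm_whisper], references=[norm_ground_truth])
+     return wer < wer_threshold
+
+ print(is_wer_in_range("आपका दिन शुभ हो", "आपका दिन शुभ हो"))  # True: identical transcriptions pass the filter
+ ```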
311
+ As well as a few noteworthy model arguments that can be configured to give optimal training performance:
312
+ 1. `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
313
+ 2. `freeze_embed_positions`: whether to freeze the student model's decoder positional embeddings. Using the same embed positions as the teacher model, which is designed to handle context lengths up to 448 tokens, helps the student model retain its input id representation up to the full max input length.
314
+ 3. `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
315
+
316
+ And finally, a few noteworthy training arguments:
317
+ 1. `max_steps`: defines the total number of optimisation steps (forward + backward pass) during training. To reach convergence, you should use a dataset of at least 1k hours and train for a minimum of 50k steps.
318
+ 2. `lr_scheduler_type`: defines the learning rate schedule, one of `constant_with_warmup` or `linear`. When experimenting with a training set-up or training for very few steps (< 5k), using `constant_with_warmup` is typically beneficial, since the learning rate remains high over the short training run. When performing long training runs (> 5k), using a `linear` schedule generally results in superior downstream performance of the distilled model.
319
+
320
+ TODO:
321
+ - [ ] Template for model cards
322
+
323
+ ## 4. Evaluation
324
+
325
+ There are four types of evaluation performed in Distil-Whisper:
326
+ 1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
327
+ 2. Sequential long form: evaluation on audio samples longer than 30s in duration using the original "sequential" long-form algorithm. Examples include entire TED talks or earnings calls.
328
+ 3. Chunked long form: evaluation on audio samples longer than 30s in duration using the Transformers "chunked" long-form algorithm.
329
+ 4. Speculative decoding: evaluation on audio samples less than 30s in duration, where a faster, distilled model is used as the assistant to a slower, teacher model.
330
+
331
+ All four forms of evaluation are performed using the script [`run_eval.py`](run_eval.py). Unlike the pseudo-labelling
332
+ and training scripts, the evaluation script assumes that only one GPU accelerator is used. We can copy the corresponding
333
+ evaluation script to the model repository using the following command:
334
+
335
+ ```bash
336
+ cp ../distil-whisper/training/run_eval.py .
337
+ ```
338
+
339
+ Models are assessed jointly using:
340
+ 1. The *word-error rate (WER)* metric: measures the number of substitution, deletion and insertion errors relative to the total number of words. A lower WER indicates a more accurate model. A toy example of both metrics is given after this list.
341
+ 2. The *inverse real-time factor (RTFx)* metric: measures the ratio of `audio input time : model compute time`. A higher RTFx indicates a faster model.
342
+
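+ The toy example below illustrates how both metrics are computed; the strings and timings are illustrative values
+ only, not measured results:
+
+ ```python
+ import evaluate
+
+ wer_metric = evaluate.load("wer")
+
+ predictions = ["the cat sat on mat"]
+ references = ["the cat sat on the mat"]
+ wer = 100 * wer_metric.compute(predictions=predictions, references=references)
+ print(f"WER: {wer:.1f}%")  # one deletion over six reference words -> 16.7%
+
+ # inverse real-time factor: ratio of audio duration to transcription time
+ audio_seconds, compute_seconds = 3600.0, 120.0
+ print(f"RTFx: {audio_seconds / compute_seconds:.1f}")  # 30.0 -> 30x faster than real time
+ ```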
343
+ In all cases, it is particularly important to evaluate the final model on data that is *out-of-distribution (OOD)* with
344
+ the training data. Evaluating on OOD data provides insight as to how well the distilled model is likely to generalise to
345
+ different audio distributions at inference time. In our example, the Common Voice test set is *in-distribution (ID)*
346
+ with our training data, since it is taken from the same distribution as the Common Voice training set. Whereas the FLEURS
347
+ test set is OOD, since it is not used as part of the training set.
348
+
349
+ ### Short Form
350
+
351
+ The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple short-form
352
+ validation sets. The following example demonstrates how to evaluate the student model trained in the previous step on
353
+ the Common Voice `test` set (ID) and also the FLEURS `test` set (OOD). Again, it leverages streaming mode to bypass
354
+ the need to download the data offline:
355
+
356
+ ```bash
357
+ #!/usr/bin/env bash
358
+
359
+ python run_eval.py \
360
+ --model_name_or_path "./" \
361
+ --dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
362
+ --dataset_config_name "default+hi_in" \
363
+ --dataset_split_name "test+test" \
364
+ --text_column_name "sentence+transcription" \
365
+ --batch_size 16 \
366
+ --dtype "bfloat16" \
367
+ --generation_max_length 256 \
368
+ --language "hi" \
369
+ --attn_implementation "sdpa" \
370
+ --streaming
371
+
372
+ ```
373
+
374
+ The student model achieves an average WER of TODO% with an RTFx of TODO for a batch size of 16. We can easily adapt the above
375
+ script to evaluate the teacher model, simply by switching the `model_name_or_path` to `openai/whisper-large-v3`, which
376
+ achieves an average WER of TODO% with an RTFx of TODO. Therefore, for a batch size of 16, the student model is a factor of TODO
377
+ times faster than the teacher. The WER gap can be closed by training on more data (at least 1k hours) for more training
378
+ steps (at least 50k).
379
+
380
+ ### Sequential Long Form
381
+
382
+ The original Whisper paper presents a long-form transcription algorithm that sequentially transcribes 30-second segments
383
+ of audio and shifts the sliding window according to the timestamps predicted by the model. This style of sequential
384
+ inference is performed directly using the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
385
+ method in Transformers.
386
+
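+ The snippet below is a minimal sketch of how sequential long-form inference can be invoked directly from Python.
+ It assumes a recent version of Transformers with built-in long-form generation support, and uses the
+ [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3) checkpoint and a TED-LIUM sample purely
+ for illustration:
+
+ ```python
+ import torch
+ from datasets import load_dataset
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+ model = WhisperForConditionalGeneration.from_pretrained(
+     "distil-whisper/distil-large-v3", torch_dtype=torch.float16
+ ).to("cuda")
+ processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3")
+
+ dataset = load_dataset("distil-whisper/tedlium-long-form", "default", split="validation", streaming=True)
+ sample = next(iter(dataset))["audio"]
+
+ # pass the full (un-truncated) audio; `.generate` slides a 30-second window over it sequentially
+ inputs = processor(
+     sample["array"],
+     sampling_rate=sample["sampling_rate"],
+     return_tensors="pt",
+     truncation=False,
+     padding="longest",
+     return_attention_mask=True,
+ ).to("cuda", torch.float16)
+
+ pred_ids = model.generate(**inputs, return_timestamps=True, language="en", task="transcribe")
+ print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])
+ ```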
387
+ The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
388
+ long-form evaluation sets using the sequential algorithm. Since we don't have a long-form validation set for Hindi to hand,
389
+ in this example we'll evaluate the official Distil-Whisper model [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3)
390
+ on the TED-LIUM validation set:
391
+
392
+ ```bash
393
+ #!/usr/bin/env bash
394
+
395
+ accelerate launch run_eval.py \
396
+ --model_name_or_path "distil-whisper/distil-large-v3" \
397
+ --dataset_name "distil-whisper/tedlium-long-form" \
398
+ --dataset_config_name "default" \
399
+ --dataset_split_name "validation" \
400
+ --text_column_name "text" \
401
+ --batch_size 16 \
402
+ --dtype "bfloat16" \
403
+ --generation_max_length 256 \
404
+ --language "en" \
405
+ --attn_implementation "sdpa" \
406
+ --streaming
407
+
408
+ ```
409
+
410
+ ### Chunked Long Form
411
+
412
+ Chunked long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and
413
+ inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction.
414
+ A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks.
415
+
416
+ This style of chunked inference is performed using the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines)
417
+ class, which provides a wrapper around the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
418
+ function for long-form inference.
419
+
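+ As a rough illustration, chunked inference through the `pipeline` class looks as follows; the audio file name, device
+ and chunk/batch sizes are placeholder assumptions:
+
+ ```python
+ import torch
+ from transformers import pipeline
+
+ asr = pipeline(
+     "automatic-speech-recognition",
+     model="distil-whisper/distil-large-v3",
+     torch_dtype=torch.bfloat16,
+     device="cuda:0",
+ )
+
+ # chunk_length_s splits the long audio into 25-second segments that are inferred in parallel
+ # (batch_size at a time); the overlapping stride at the chunk boundaries is handled internally
+ result = asr(
+     "long_audio.wav",
+     chunk_length_s=25.0,
+     batch_size=16,
+     return_timestamps=True,
+ )
+ print(result["text"])
+ ```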
420
+ The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
421
+ long-form evaluation sets using the pipeline class. Again, in this example we'll evaluate distil-large-v3 on the
422
+ TED-LIUM validation set:
423
+
424
+ ```bash
425
+ #!/usr/bin/env bash
426
+
427
+ python run_eval.py \
428
+ --model_name_or_path "distil-whisper/distil-large-v3" \
429
+ --dataset_name "distil-whisper/tedlium-long-form" \
430
+ --dataset_config_name "default" \
431
+ --dataset_split_name "validation" \
432
+ --text_column_name "text" \
433
+ --use_pipeline \
434
+ --chunk_length_s 25.0 \
435
+ --language "en" \
436
+ --return_timestamps \
437
+ --dtype "bfloat16" \
438
+ --streaming
439
+
440
+ ```
441
+
442
+ The argument `chunk_length_s` controls the length of the chunked audio samples. It should be set to match the typical
443
+ length of audio the student model was trained on. If unsure about what value of `chunk_length_s` is optimal for your case,
444
+ it is recommended to run a *sweep* over all possible values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
445
+ can be found under [`run_chunk_length_s_sweep.yaml`](flax/long_form_transcription_scripts/run_chunk_length_s_sweep.yaml).
446
+
447
+ ### Speculative Decoding
448
+
449
+ Speculative decoding, or assisted generation, relies on the premise that a faster, assistant model can be used to speed-up
450
+ the generation of a slower, main model. Speculative decoding mathematically ensures that exactly the same outputs as
451
+ Whisper are obtained, while being ~2 times faster. This makes it the perfect drop-in replacement for existing Whisper
452
+ pipelines, since exactly the same outputs are guaranteed.
453
+
454
+ Distil-Whisper checkpoints can be designed to be efficient assistant models to Whisper for speculative decoding. More precisely,
455
+ by freezing the encoder during training, the distilled model can share the same encoder weights as Whisper during inference, since
456
+ the encoder weights are unchanged. In doing so, only the distilled 2-layer decoder has to be loaded in addition to the
457
+ original Whisper model, which is approximately an 8% increase to the total parameter count, with up to 2x faster inference
458
+ for low batch sizes. For more details on speculative decoding, the reader is advised to refer to the following blog post:
459
+ [Speculative Decoding for 2x Faster Whisper Inference](https://huggingface.co/blog/whisper-speculative-decoding).
460
+
461
+ In the example below, we use our distilled model as an assistant to the large-v3 teacher model during inference:
462
+
463
+ ```bash
464
+ #!/usr/bin/env bash
465
+
466
+ python run_eval.py \
467
+ --model_name_or_path "openai/whisper-large-v3" \
468
+ --assistant_model_name_or_path "./" \
469
+ --dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
470
+ --dataset_config_name "default+hi_in" \
471
+ --dataset_split_name "test+test" \
472
+ --text_column_name "sentence+transcription" \
473
+ --batch_size 16 \
474
+ --dtype "bfloat16" \
475
+ --generation_max_length 256 \
476
+ --language "hi" \
477
+ --attn_implementation "sdpa" \
478
+ --streaming
479
+
480
+ ```
481
+
482
+ We see that we achieve a WER of TODO%, the same as what we obtained with the large-v3 model, but with an RTFx of TODO,
483
+ a factor of TODO faster than using the large-v3 model alone. The RTFx value can be improved by training the student on
484
+ more data and for more training steps, since this will improve the number of predicted tokens that match the teacher
485
+ predictions.
486
+
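+ Outside of `run_eval.py`, the same set-up can be sketched directly with the `pipeline` class. In the snippet below,
+ the assistant path `./` is assumed to contain the distilled student trained above, and the audio file name is a
+ placeholder:
+
+ ```python
+ import torch
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+
+ device = "cuda:0"
+ dtype = torch.bfloat16
+
+ # the distilled student acts as the draft (assistant) model
+ assistant = AutoModelForSpeechSeq2Seq.from_pretrained(
+     "./", torch_dtype=dtype, low_cpu_mem_usage=True
+ ).to(device)
+
+ # the teacher is the model whose outputs are actually returned
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     "openai/whisper-large-v3", torch_dtype=dtype, low_cpu_mem_usage=True
+ ).to(device)
+ processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
+
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model=model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     generate_kwargs={"assistant_model": assistant, "language": "hi"},
+     torch_dtype=dtype,
+     device=device,
+ )
+
+ print(pipe("sample.wav")["text"])
+ ```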
487
+ ## Overview of Training Methods
488
+
489
+ ### 1. Fine-Tuning
490
+
491
+ For fine-tuning, we take the original Whisper checkpoint and train it on one or more datasets using the standard
492
+ cross-entropy loss. As such, there is no involvement from the teacher checkpoint during training, and so the fine-tuned
493
+ model is permitted to *overfit* to the distribution of the training data we provide. This makes it appealing for "low-resource"
494
+ languages where the original Whisper model performs poorly, since we can boost the performance of the model on a single
495
+ language by *overfitting* to that distribution of data. Note that this means the fine-tuned model is prone to losing
496
+ its robustness to different audio distributions, which is the trade-off with improving performance on a specified dataset.
497
+
498
+ As a rule of thumb, fine-tuning is appropriate for languages where the original Whisper model achieves > 20% WER, and we
499
+ have a relatively small quantity of training data available (< 1000 hours). With fine-tuning, we require as little as **10 hours**
500
+ of training data to significantly boost the performance of the Whisper model. For an in-depth guide to fine-tuning Whisper,
501
+ the reader is advised to refer to the blog post: [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper).
502
+
503
+ ### 2. Shrink and Fine-Tune
504
+
505
+ Shrink and fine-tune (SFT) is a knowledge distillation (KD) technique in which we first *shrink* the teacher model to a
506
+ smaller student model by copying maximally spaced layers, and then *fine-tune* the student model on the cross-entropy loss
507
+ as described above. Typically, we retain the full encoder from the Whisper model and only shrink the decoder. Retaining
508
+ the entire encoder helps significantly with maintaining Whisper's robustness to different audio distributions (_c.f._
509
+ Section 9.3 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
510
+
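+ The "maximally spaced" copy amounts to selecting evenly spread decoder layer indices from the teacher, as sketched
+ below for the 32-layer large-v3 decoder and a 2-layer student (see [`create_student_model.py`](create_student_model.py)
+ for the full initialisation script):
+
+ ```python
+ import numpy as np
+
+ teacher_decoder_layers = 32  # whisper-large-v3
+ student_decoder_layers = 2
+
+ # evenly spaced indices into the teacher decoder; for 2 layers this picks the first and last
+ layer_map = np.linspace(0, teacher_decoder_layers - 1, student_decoder_layers, dtype=int)
+ print(layer_map)  # [ 0 31]
+ ```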
511
+ We can either train the student model on a dataset of (audio, text) pairs as above, or we can use the pre-trained
512
+ Whisper model to generate *pseudo-labels* for our audio data and train on the (audio, pseudo-label) pairs.
513
+
514
+ Pseudo-labels can be used when either:
515
+ 1. The original text transcriptions are normalised (lower-cased or no punctuation): the Whisper-generated pseudo-labels contain both punctuation and casing, and so can be used as a substitute for the normalised transcriptions
516
+ 2. The pre-trained Whisper model achieves < 20% WER on the target language: we then know the majority of the pseudo-labels will be accurate enough for us to train on.
517
+
518
+ They are not recommended when both of the following are true:
519
+ 1. The original text is punctuated and cased
520
+ 2. The pre-trained Whisper model achieves > 20% WER on the target language: in this case, we want to overfit to the particular distribution of the language, and so train directly on the original text data
521
+
522
+ To discard inaccurate pseudo-labels during training, we employ a simple WER heuristic to filter our pseudo-labelled
523
+ training data. We first normalise the original text and the pseudo-labelled text using the Whisper normaliser. If the
524
+ WER between the two normalised texts exceeds a 10% threshold, we discard the training sample. Otherwise, we retain it for training.
525
+ Section 9.1 of the Distil-Whisper [paper](https://arxiv.org/abs/2311.00430) demonstrates the importance of using this
526
+ threshold for training.
527
+
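+ In code, the filter boils down to something like the sketch below; the exact implementation lives in
+ [`run_distillation.py`](run_distillation.py), and the choice of normaliser and metric here is an assumption:
+
+ ```python
+ import evaluate
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+
+ normalizer = BasicTextNormalizer()  # language-agnostic normaliser, suitable for Hindi
+ wer_metric = evaluate.load("wer")
+
+ def keep_sample(reference: str, pseudo_label: str, threshold: float = 10.0) -> bool:
+     """Keep a training sample only if its pseudo-label is close enough to the original transcription."""
+     norm_ref = normalizer(reference)
+     norm_pred = normalizer(pseudo_label)
+     if not norm_ref or not norm_pred:
+         return False
+     wer = 100 * wer_metric.compute(predictions=[norm_pred], references=[norm_ref])
+     return wer < threshold
+ ```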
528
+ ### 3. KL Divergence
529
+
530
+ In the KL Divergence setting, the student model is initialised by shrinking the teacher as before, and then trained to
531
+ match the probability distribution predicted by the teacher, by minimising the KL divergence between the student and teacher predictions during training.
532
+
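+ A minimal sketch of such a combined objective is given below; the temperature and loss weights are illustrative
+ assumptions rather than the exact values used in [`run_distillation.py`](run_distillation.py):
+
+ ```python
+ import torch.nn.functional as F
+
+ def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha_ce=1.0, alpha_kl=1.0):
+     """Cross-entropy on the (pseudo-)labels plus KL divergence to the teacher distribution.
+
+     Logits are (batch, seq_len, vocab); labels are (batch, seq_len) with -100 marking padding.
+     Note: this sketch does not mask the padded positions in the KL term.
+     """
+     ce_loss = F.cross_entropy(student_logits.transpose(1, 2), labels, ignore_index=-100)
+     kl_loss = F.kl_div(
+         F.log_softmax(student_logits / temperature, dim=-1),
+         F.softmax(teacher_logits / temperature, dim=-1),
+         reduction="batchmean",
+     ) * temperature**2
+     return alpha_ce * ce_loss + alpha_kl * kl_loss
+ ```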
533
+ ### Summary of Methods
534
+
535
+ The following table summarises the two training paradigms: fine-tuning and knowledge distillation (KD). It suggests
536
+ minimum values for the pre-trained WER / training data to achieve reasonable performance:
537
+
538
+ | Method | Pre-Trained WER / % | Training Data / h |
539
+ |-------------|---------------------|-------------------|
540
+ | Fine-tuning | > 20 | < 1000 |
541
+ | KD | < 20 | > 1000 |
542
+
543
+ ## Acknowledgements
544
+
545
+ * OpenAI for the Whisper [model](https://huggingface.co/openai/whisper-large-v3) and [original codebase](https://github.com/openai/whisper)
546
+ * Hugging Face 🤗 [Transformers](https://github.com/huggingface/transformers) for the Whisper model implementation
547
+ * Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program for Cloud TPU v4s used to train the official Distil-Whisper models
548
+ * The Hugging Face 🤗 cluster for enabling experimentation with the PyTorch scripts
549
+
550
+ ## Citation
551
+
552
+ If you use this code-base, please consider citing the Distil-Whisper paper:
553
+
554
+ ```
555
+ @misc{gandhi2023distilwhisper,
556
+ title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
557
+ author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
558
+ year={2023},
559
+ eprint={2311.00430},
560
+ archivePrefix={arXiv},
561
+ primaryClass={cs.CL}
562
+ }
563
+ ```
added_tokens.json ADDED
@@ -0,0 +1,1611 @@
1
+ {
2
+ "<|0.00|>": 50365,
3
+ "<|0.02|>": 50366,
4
+ "<|0.04|>": 50367,
5
+ "<|0.06|>": 50368,
6
+ "<|0.08|>": 50369,
7
+ "<|0.10|>": 50370,
8
+ "<|0.12|>": 50371,
9
+ "<|0.14|>": 50372,
10
+ "<|0.16|>": 50373,
11
+ "<|0.18|>": 50374,
12
+ "<|0.20|>": 50375,
13
+ "<|0.22|>": 50376,
14
+ "<|0.24|>": 50377,
15
+ "<|0.26|>": 50378,
16
+ "<|0.28|>": 50379,
17
+ "<|0.30|>": 50380,
18
+ "<|0.32|>": 50381,
19
+ "<|0.34|>": 50382,
20
+ "<|0.36|>": 50383,
21
+ "<|0.38|>": 50384,
22
+ "<|0.40|>": 50385,
23
+ "<|0.42|>": 50386,
24
+ "<|0.44|>": 50387,
25
+ "<|0.46|>": 50388,
26
+ "<|0.48|>": 50389,
27
+ "<|0.50|>": 50390,
28
+ "<|0.52|>": 50391,
29
+ "<|0.54|>": 50392,
30
+ "<|0.56|>": 50393,
31
+ "<|0.58|>": 50394,
32
+ "<|0.60|>": 50395,
33
+ "<|0.62|>": 50396,
34
+ "<|0.64|>": 50397,
35
+ "<|0.66|>": 50398,
36
+ "<|0.68|>": 50399,
37
+ "<|0.70|>": 50400,
38
+ "<|0.72|>": 50401,
39
+ "<|0.74|>": 50402,
40
+ "<|0.76|>": 50403,
41
+ "<|0.78|>": 50404,
42
+ "<|0.80|>": 50405,
43
+ "<|0.82|>": 50406,
44
+ "<|0.84|>": 50407,
45
+ "<|0.86|>": 50408,
46
+ "<|0.88|>": 50409,
47
+ "<|0.90|>": 50410,
48
+ "<|0.92|>": 50411,
49
+ "<|0.94|>": 50412,
50
+ "<|0.96|>": 50413,
51
+ "<|0.98|>": 50414,
52
+ "<|1.00|>": 50415,
53
+ "<|1.02|>": 50416,
54
+ "<|1.04|>": 50417,
55
+ "<|1.06|>": 50418,
56
+ "<|1.08|>": 50419,
57
+ "<|1.10|>": 50420,
58
+ "<|1.12|>": 50421,
59
+ "<|1.14|>": 50422,
60
+ "<|1.16|>": 50423,
61
+ "<|1.18|>": 50424,
62
+ "<|1.20|>": 50425,
63
+ "<|1.22|>": 50426,
64
+ "<|1.24|>": 50427,
65
+ "<|1.26|>": 50428,
66
+ "<|1.28|>": 50429,
67
+ "<|1.30|>": 50430,
68
+ "<|1.32|>": 50431,
69
+ "<|1.34|>": 50432,
70
+ "<|1.36|>": 50433,
71
+ "<|1.38|>": 50434,
72
+ "<|1.40|>": 50435,
73
+ "<|1.42|>": 50436,
74
+ "<|1.44|>": 50437,
75
+ "<|1.46|>": 50438,
76
+ "<|1.48|>": 50439,
77
+ "<|1.50|>": 50440,
78
+ "<|1.52|>": 50441,
79
+ "<|1.54|>": 50442,
80
+ "<|1.56|>": 50443,
81
+ "<|1.58|>": 50444,
82
+ "<|1.60|>": 50445,
83
+ "<|1.62|>": 50446,
84
+ "<|1.64|>": 50447,
85
+ "<|1.66|>": 50448,
86
+ "<|1.68|>": 50449,
87
+ "<|1.70|>": 50450,
88
+ "<|1.72|>": 50451,
89
+ "<|1.74|>": 50452,
90
+ "<|1.76|>": 50453,
91
+ "<|1.78|>": 50454,
92
+ "<|1.80|>": 50455,
93
+ "<|1.82|>": 50456,
94
+ "<|1.84|>": 50457,
95
+ "<|1.86|>": 50458,
96
+ "<|1.88|>": 50459,
97
+ "<|1.90|>": 50460,
98
+ "<|1.92|>": 50461,
99
+ "<|1.94|>": 50462,
100
+ "<|1.96|>": 50463,
101
+ "<|1.98|>": 50464,
102
+ "<|10.00|>": 50865,
103
+ "<|10.02|>": 50866,
104
+ "<|10.04|>": 50867,
105
+ "<|10.06|>": 50868,
106
+ "<|10.08|>": 50869,
107
+ "<|10.10|>": 50870,
108
+ "<|10.12|>": 50871,
109
+ "<|10.14|>": 50872,
110
+ "<|10.16|>": 50873,
111
+ "<|10.18|>": 50874,
112
+ "<|10.20|>": 50875,
113
+ "<|10.22|>": 50876,
114
+ "<|10.24|>": 50877,
115
+ "<|10.26|>": 50878,
116
+ "<|10.28|>": 50879,
117
+ "<|10.30|>": 50880,
118
+ "<|10.32|>": 50881,
119
+ "<|10.34|>": 50882,
120
+ "<|10.36|>": 50883,
121
+ "<|10.38|>": 50884,
122
+ "<|10.40|>": 50885,
123
+ "<|10.42|>": 50886,
124
+ "<|10.44|>": 50887,
125
+ "<|10.46|>": 50888,
126
+ "<|10.48|>": 50889,
127
+ "<|10.50|>": 50890,
128
+ "<|10.52|>": 50891,
129
+ "<|10.54|>": 50892,
130
+ "<|10.56|>": 50893,
131
+ "<|10.58|>": 50894,
132
+ "<|10.60|>": 50895,
133
+ "<|10.62|>": 50896,
134
+ "<|10.64|>": 50897,
135
+ "<|10.66|>": 50898,
136
+ "<|10.68|>": 50899,
137
+ "<|10.70|>": 50900,
138
+ "<|10.72|>": 50901,
139
+ "<|10.74|>": 50902,
140
+ "<|10.76|>": 50903,
141
+ "<|10.78|>": 50904,
142
+ "<|10.80|>": 50905,
143
+ "<|10.82|>": 50906,
144
+ "<|10.84|>": 50907,
145
+ "<|10.86|>": 50908,
146
+ "<|10.88|>": 50909,
147
+ "<|10.90|>": 50910,
148
+ "<|10.92|>": 50911,
149
+ "<|10.94|>": 50912,
150
+ "<|10.96|>": 50913,
151
+ "<|10.98|>": 50914,
152
+ "<|11.00|>": 50915,
153
+ "<|11.02|>": 50916,
154
+ "<|11.04|>": 50917,
155
+ "<|11.06|>": 50918,
156
+ "<|11.08|>": 50919,
157
+ "<|11.10|>": 50920,
158
+ "<|11.12|>": 50921,
159
+ "<|11.14|>": 50922,
160
+ "<|11.16|>": 50923,
161
+ "<|11.18|>": 50924,
162
+ "<|11.20|>": 50925,
163
+ "<|11.22|>": 50926,
164
+ "<|11.24|>": 50927,
165
+ "<|11.26|>": 50928,
166
+ "<|11.28|>": 50929,
167
+ "<|11.30|>": 50930,
168
+ "<|11.32|>": 50931,
169
+ "<|11.34|>": 50932,
170
+ "<|11.36|>": 50933,
171
+ "<|11.38|>": 50934,
172
+ "<|11.40|>": 50935,
173
+ "<|11.42|>": 50936,
174
+ "<|11.44|>": 50937,
175
+ "<|11.46|>": 50938,
176
+ "<|11.48|>": 50939,
177
+ "<|11.50|>": 50940,
178
+ "<|11.52|>": 50941,
179
+ "<|11.54|>": 50942,
180
+ "<|11.56|>": 50943,
181
+ "<|11.58|>": 50944,
182
+ "<|11.60|>": 50945,
183
+ "<|11.62|>": 50946,
184
+ "<|11.64|>": 50947,
185
+ "<|11.66|>": 50948,
186
+ "<|11.68|>": 50949,
187
+ "<|11.70|>": 50950,
188
+ "<|11.72|>": 50951,
189
+ "<|11.74|>": 50952,
190
+ "<|11.76|>": 50953,
191
+ "<|11.78|>": 50954,
192
+ "<|11.80|>": 50955,
193
+ "<|11.82|>": 50956,
194
+ "<|11.84|>": 50957,
195
+ "<|11.86|>": 50958,
196
+ "<|11.88|>": 50959,
197
+ "<|11.90|>": 50960,
198
+ "<|11.92|>": 50961,
199
+ "<|11.94|>": 50962,
200
+ "<|11.96|>": 50963,
201
+ "<|11.98|>": 50964,
202
+ "<|12.00|>": 50965,
203
+ "<|12.02|>": 50966,
204
+ "<|12.04|>": 50967,
205
+ "<|12.06|>": 50968,
206
+ "<|12.08|>": 50969,
207
+ "<|12.10|>": 50970,
208
+ "<|12.12|>": 50971,
209
+ "<|12.14|>": 50972,
210
+ "<|12.16|>": 50973,
211
+ "<|12.18|>": 50974,
212
+ "<|12.20|>": 50975,
213
+ "<|12.22|>": 50976,
214
+ "<|12.24|>": 50977,
215
+ "<|12.26|>": 50978,
216
+ "<|12.28|>": 50979,
217
+ "<|12.30|>": 50980,
218
+ "<|12.32|>": 50981,
219
+ "<|12.34|>": 50982,
220
+ "<|12.36|>": 50983,
221
+ "<|12.38|>": 50984,
222
+ "<|12.40|>": 50985,
223
+ "<|12.42|>": 50986,
224
+ "<|12.44|>": 50987,
225
+ "<|12.46|>": 50988,
226
+ "<|12.48|>": 50989,
227
+ "<|12.50|>": 50990,
228
+ "<|12.52|>": 50991,
229
+ "<|12.54|>": 50992,
230
+ "<|12.56|>": 50993,
231
+ "<|12.58|>": 50994,
232
+ "<|12.60|>": 50995,
233
+ "<|12.62|>": 50996,
234
+ "<|12.64|>": 50997,
235
+ "<|12.66|>": 50998,
236
+ "<|12.68|>": 50999,
237
+ "<|12.70|>": 51000,
238
+ "<|12.72|>": 51001,
239
+ "<|12.74|>": 51002,
240
+ "<|12.76|>": 51003,
241
+ "<|12.78|>": 51004,
242
+ "<|12.80|>": 51005,
243
+ "<|12.82|>": 51006,
244
+ "<|12.84|>": 51007,
245
+ "<|12.86|>": 51008,
246
+ "<|12.88|>": 51009,
247
+ "<|12.90|>": 51010,
248
+ "<|12.92|>": 51011,
249
+ "<|12.94|>": 51012,
250
+ "<|12.96|>": 51013,
251
+ "<|12.98|>": 51014,
252
+ "<|13.00|>": 51015,
253
+ "<|13.02|>": 51016,
254
+ "<|13.04|>": 51017,
255
+ "<|13.06|>": 51018,
256
+ "<|13.08|>": 51019,
257
+ "<|13.10|>": 51020,
258
+ "<|13.12|>": 51021,
259
+ "<|13.14|>": 51022,
260
+ "<|13.16|>": 51023,
261
+ "<|13.18|>": 51024,
262
+ "<|13.20|>": 51025,
263
+ "<|13.22|>": 51026,
264
+ "<|13.24|>": 51027,
265
+ "<|13.26|>": 51028,
266
+ "<|13.28|>": 51029,
267
+ "<|13.30|>": 51030,
268
+ "<|13.32|>": 51031,
269
+ "<|13.34|>": 51032,
270
+ "<|13.36|>": 51033,
271
+ "<|13.38|>": 51034,
272
+ "<|13.40|>": 51035,
273
+ "<|13.42|>": 51036,
274
+ "<|13.44|>": 51037,
275
+ "<|13.46|>": 51038,
276
+ "<|13.48|>": 51039,
277
+ "<|13.50|>": 51040,
278
+ "<|13.52|>": 51041,
279
+ "<|13.54|>": 51042,
280
+ "<|13.56|>": 51043,
281
+ "<|13.58|>": 51044,
282
+ "<|13.60|>": 51045,
283
+ "<|13.62|>": 51046,
284
+ "<|13.64|>": 51047,
285
+ "<|13.66|>": 51048,
286
+ "<|13.68|>": 51049,
287
+ "<|13.70|>": 51050,
288
+ "<|13.72|>": 51051,
289
+ "<|13.74|>": 51052,
290
+ "<|13.76|>": 51053,
291
+ "<|13.78|>": 51054,
292
+ "<|13.80|>": 51055,
293
+ "<|13.82|>": 51056,
294
+ "<|13.84|>": 51057,
295
+ "<|13.86|>": 51058,
296
+ "<|13.88|>": 51059,
297
+ "<|13.90|>": 51060,
298
+ "<|13.92|>": 51061,
299
+ "<|13.94|>": 51062,
300
+ "<|13.96|>": 51063,
301
+ "<|13.98|>": 51064,
302
+ "<|14.00|>": 51065,
303
+ "<|14.02|>": 51066,
304
+ "<|14.04|>": 51067,
305
+ "<|14.06|>": 51068,
306
+ "<|14.08|>": 51069,
307
+ "<|14.10|>": 51070,
308
+ "<|14.12|>": 51071,
309
+ "<|14.14|>": 51072,
310
+ "<|14.16|>": 51073,
311
+ "<|14.18|>": 51074,
312
+ "<|14.20|>": 51075,
313
+ "<|14.22|>": 51076,
314
+ "<|14.24|>": 51077,
315
+ "<|14.26|>": 51078,
316
+ "<|14.28|>": 51079,
317
+ "<|14.30|>": 51080,
318
+ "<|14.32|>": 51081,
319
+ "<|14.34|>": 51082,
320
+ "<|14.36|>": 51083,
321
+ "<|14.38|>": 51084,
322
+ "<|14.40|>": 51085,
323
+ "<|14.42|>": 51086,
324
+ "<|14.44|>": 51087,
325
+ "<|14.46|>": 51088,
326
+ "<|14.48|>": 51089,
327
+ "<|14.50|>": 51090,
328
+ "<|14.52|>": 51091,
329
+ "<|14.54|>": 51092,
330
+ "<|14.56|>": 51093,
331
+ "<|14.58|>": 51094,
332
+ "<|14.60|>": 51095,
333
+ "<|14.62|>": 51096,
334
+ "<|14.64|>": 51097,
335
+ "<|14.66|>": 51098,
336
+ "<|14.68|>": 51099,
337
+ "<|14.70|>": 51100,
338
+ "<|14.72|>": 51101,
339
+ "<|14.74|>": 51102,
340
+ "<|14.76|>": 51103,
341
+ "<|14.78|>": 51104,
342
+ "<|14.80|>": 51105,
343
+ "<|14.82|>": 51106,
344
+ "<|14.84|>": 51107,
345
+ "<|14.86|>": 51108,
346
+ "<|14.88|>": 51109,
347
+ "<|14.90|>": 51110,
348
+ "<|14.92|>": 51111,
349
+ "<|14.94|>": 51112,
350
+ "<|14.96|>": 51113,
351
+ "<|14.98|>": 51114,
352
+ "<|15.00|>": 51115,
353
+ "<|15.02|>": 51116,
354
+ "<|15.04|>": 51117,
355
+ "<|15.06|>": 51118,
356
+ "<|15.08|>": 51119,
357
+ "<|15.10|>": 51120,
358
+ "<|15.12|>": 51121,
359
+ "<|15.14|>": 51122,
360
+ "<|15.16|>": 51123,
361
+ "<|15.18|>": 51124,
362
+ "<|15.20|>": 51125,
363
+ "<|15.22|>": 51126,
364
+ "<|15.24|>": 51127,
365
+ "<|15.26|>": 51128,
366
+ "<|15.28|>": 51129,
367
+ "<|15.30|>": 51130,
368
+ "<|15.32|>": 51131,
369
+ "<|15.34|>": 51132,
370
+ "<|15.36|>": 51133,
371
+ "<|15.38|>": 51134,
372
+ "<|15.40|>": 51135,
373
+ "<|15.42|>": 51136,
374
+ "<|15.44|>": 51137,
375
+ "<|15.46|>": 51138,
376
+ "<|15.48|>": 51139,
377
+ "<|15.50|>": 51140,
378
+ "<|15.52|>": 51141,
379
+ "<|15.54|>": 51142,
380
+ "<|15.56|>": 51143,
381
+ "<|15.58|>": 51144,
382
+ "<|15.60|>": 51145,
383
+ "<|15.62|>": 51146,
384
+ "<|15.64|>": 51147,
385
+ "<|15.66|>": 51148,
386
+ "<|15.68|>": 51149,
387
+ "<|15.70|>": 51150,
388
+ "<|15.72|>": 51151,
389
+ "<|15.74|>": 51152,
390
+ "<|15.76|>": 51153,
391
+ "<|15.78|>": 51154,
392
+ "<|15.80|>": 51155,
393
+ "<|15.82|>": 51156,
394
+ "<|15.84|>": 51157,
395
+ "<|15.86|>": 51158,
396
+ "<|15.88|>": 51159,
397
+ "<|15.90|>": 51160,
398
+ "<|15.92|>": 51161,
399
+ "<|15.94|>": 51162,
400
+ "<|15.96|>": 51163,
401
+ "<|15.98|>": 51164,
402
+ "<|16.00|>": 51165,
403
+ "<|16.02|>": 51166,
404
+ "<|16.04|>": 51167,
405
+ "<|16.06|>": 51168,
406
+ "<|16.08|>": 51169,
407
+ "<|16.10|>": 51170,
408
+ "<|16.12|>": 51171,
409
+ "<|16.14|>": 51172,
410
+ "<|16.16|>": 51173,
411
+ "<|16.18|>": 51174,
412
+ "<|16.20|>": 51175,
413
+ "<|16.22|>": 51176,
414
+ "<|16.24|>": 51177,
415
+ "<|16.26|>": 51178,
416
+ "<|16.28|>": 51179,
417
+ "<|16.30|>": 51180,
418
+ "<|16.32|>": 51181,
419
+ "<|16.34|>": 51182,
420
+ "<|16.36|>": 51183,
421
+ "<|16.38|>": 51184,
422
+ "<|16.40|>": 51185,
423
+ "<|16.42|>": 51186,
424
+ "<|16.44|>": 51187,
425
+ "<|16.46|>": 51188,
426
+ "<|16.48|>": 51189,
427
+ "<|16.50|>": 51190,
428
+ "<|16.52|>": 51191,
429
+ "<|16.54|>": 51192,
430
+ "<|16.56|>": 51193,
431
+ "<|16.58|>": 51194,
432
+ "<|16.60|>": 51195,
433
+ "<|16.62|>": 51196,
434
+ "<|16.64|>": 51197,
435
+ "<|16.66|>": 51198,
436
+ "<|16.68|>": 51199,
437
+ "<|16.70|>": 51200,
438
+ "<|16.72|>": 51201,
439
+ "<|16.74|>": 51202,
440
+ "<|16.76|>": 51203,
441
+ "<|16.78|>": 51204,
442
+ "<|16.80|>": 51205,
443
+ "<|16.82|>": 51206,
444
+ "<|16.84|>": 51207,
445
+ "<|16.86|>": 51208,
446
+ "<|16.88|>": 51209,
447
+ "<|16.90|>": 51210,
448
+ "<|16.92|>": 51211,
449
+ "<|16.94|>": 51212,
450
+ "<|16.96|>": 51213,
451
+ "<|16.98|>": 51214,
452
+ "<|17.00|>": 51215,
453
+ "<|17.02|>": 51216,
454
+ "<|17.04|>": 51217,
455
+ "<|17.06|>": 51218,
456
+ "<|17.08|>": 51219,
457
+ "<|17.10|>": 51220,
458
+ "<|17.12|>": 51221,
459
+ "<|17.14|>": 51222,
460
+ "<|17.16|>": 51223,
461
+ "<|17.18|>": 51224,
462
+ "<|17.20|>": 51225,
463
+ "<|17.22|>": 51226,
464
+ "<|17.24|>": 51227,
465
+ "<|17.26|>": 51228,
466
+ "<|17.28|>": 51229,
467
+ "<|17.30|>": 51230,
468
+ "<|17.32|>": 51231,
469
+ "<|17.34|>": 51232,
470
+ "<|17.36|>": 51233,
471
+ "<|17.38|>": 51234,
472
+ "<|17.40|>": 51235,
473
+ "<|17.42|>": 51236,
474
+ "<|17.44|>": 51237,
475
+ "<|17.46|>": 51238,
476
+ "<|17.48|>": 51239,
477
+ "<|17.50|>": 51240,
478
+ "<|17.52|>": 51241,
479
+ "<|17.54|>": 51242,
480
+ "<|17.56|>": 51243,
481
+ "<|17.58|>": 51244,
482
+ "<|17.60|>": 51245,
483
+ "<|17.62|>": 51246,
484
+ "<|17.64|>": 51247,
485
+ "<|17.66|>": 51248,
486
+ "<|17.68|>": 51249,
487
+ "<|17.70|>": 51250,
488
+ "<|17.72|>": 51251,
489
+ "<|17.74|>": 51252,
490
+ "<|17.76|>": 51253,
491
+ "<|17.78|>": 51254,
492
+ "<|17.80|>": 51255,
493
+ "<|17.82|>": 51256,
494
+ "<|17.84|>": 51257,
495
+ "<|17.86|>": 51258,
496
+ "<|17.88|>": 51259,
497
+ "<|17.90|>": 51260,
498
+ "<|17.92|>": 51261,
499
+ "<|17.94|>": 51262,
500
+ "<|17.96|>": 51263,
501
+ "<|17.98|>": 51264,
502
+ "<|18.00|>": 51265,
503
+ "<|18.02|>": 51266,
504
+ "<|18.04|>": 51267,
505
+ "<|18.06|>": 51268,
506
+ "<|18.08|>": 51269,
507
+ "<|18.10|>": 51270,
508
+ "<|18.12|>": 51271,
509
+ "<|18.14|>": 51272,
510
+ "<|18.16|>": 51273,
511
+ "<|18.18|>": 51274,
512
+ "<|18.20|>": 51275,
513
+ "<|18.22|>": 51276,
514
+ "<|18.24|>": 51277,
515
+ "<|18.26|>": 51278,
516
+ "<|18.28|>": 51279,
517
+ "<|18.30|>": 51280,
518
+ "<|18.32|>": 51281,
519
+ "<|18.34|>": 51282,
520
+ "<|18.36|>": 51283,
521
+ "<|18.38|>": 51284,
522
+ "<|18.40|>": 51285,
523
+ "<|18.42|>": 51286,
524
+ "<|18.44|>": 51287,
525
+ "<|18.46|>": 51288,
526
+ "<|18.48|>": 51289,
527
+ "<|18.50|>": 51290,
528
+ "<|18.52|>": 51291,
529
+ "<|18.54|>": 51292,
530
+ "<|18.56|>": 51293,
531
+ "<|18.58|>": 51294,
532
+ "<|18.60|>": 51295,
533
+ "<|18.62|>": 51296,
534
+ "<|18.64|>": 51297,
535
+ "<|18.66|>": 51298,
536
+ "<|18.68|>": 51299,
537
+ "<|18.70|>": 51300,
538
+ "<|18.72|>": 51301,
539
+ "<|18.74|>": 51302,
540
+ "<|18.76|>": 51303,
541
+ "<|18.78|>": 51304,
542
+ "<|18.80|>": 51305,
543
+ "<|18.82|>": 51306,
544
+ "<|18.84|>": 51307,
545
+ "<|18.86|>": 51308,
546
+ "<|18.88|>": 51309,
547
+ "<|18.90|>": 51310,
548
+ "<|18.92|>": 51311,
549
+ "<|18.94|>": 51312,
550
+ "<|18.96|>": 51313,
551
+ "<|18.98|>": 51314,
552
+ "<|19.00|>": 51315,
553
+ "<|19.02|>": 51316,
554
+ "<|19.04|>": 51317,
555
+ "<|19.06|>": 51318,
556
+ "<|19.08|>": 51319,
557
+ "<|19.10|>": 51320,
558
+ "<|19.12|>": 51321,
559
+ "<|19.14|>": 51322,
560
+ "<|19.16|>": 51323,
561
+ "<|19.18|>": 51324,
562
+ "<|19.20|>": 51325,
563
+ "<|19.22|>": 51326,
564
+ "<|19.24|>": 51327,
565
+ "<|19.26|>": 51328,
566
+ "<|19.28|>": 51329,
567
+ "<|19.30|>": 51330,
568
+ "<|19.32|>": 51331,
569
+ "<|19.34|>": 51332,
570
+ "<|19.36|>": 51333,
571
+ "<|19.38|>": 51334,
572
+ "<|19.40|>": 51335,
573
+ "<|19.42|>": 51336,
574
+ "<|19.44|>": 51337,
575
+ "<|19.46|>": 51338,
576
+ "<|19.48|>": 51339,
577
+ "<|19.50|>": 51340,
578
+ "<|19.52|>": 51341,
579
+ "<|19.54|>": 51342,
580
+ "<|19.56|>": 51343,
581
+ "<|19.58|>": 51344,
582
+ "<|19.60|>": 51345,
583
+ "<|19.62|>": 51346,
584
+ "<|19.64|>": 51347,
585
+ "<|19.66|>": 51348,
586
+ "<|19.68|>": 51349,
587
+ "<|19.70|>": 51350,
588
+ "<|19.72|>": 51351,
589
+ "<|19.74|>": 51352,
590
+ "<|19.76|>": 51353,
591
+ "<|19.78|>": 51354,
592
+ "<|19.80|>": 51355,
593
+ "<|19.82|>": 51356,
594
+ "<|19.84|>": 51357,
595
+ "<|19.86|>": 51358,
596
+ "<|19.88|>": 51359,
597
+ "<|19.90|>": 51360,
598
+ "<|19.92|>": 51361,
599
+ "<|19.94|>": 51362,
600
+ "<|19.96|>": 51363,
601
+ "<|19.98|>": 51364,
602
+ "<|2.00|>": 50465,
603
+ "<|2.02|>": 50466,
604
+ "<|2.04|>": 50467,
605
+ "<|2.06|>": 50468,
606
+ "<|2.08|>": 50469,
607
+ "<|2.10|>": 50470,
608
+ "<|2.12|>": 50471,
609
+ "<|2.14|>": 50472,
610
+ "<|2.16|>": 50473,
611
+ "<|2.18|>": 50474,
612
+ "<|2.20|>": 50475,
613
+ "<|2.22|>": 50476,
614
+ "<|2.24|>": 50477,
615
+ "<|2.26|>": 50478,
616
+ "<|2.28|>": 50479,
617
+ "<|2.30|>": 50480,
618
+ "<|2.32|>": 50481,
619
+ "<|2.34|>": 50482,
620
+ "<|2.36|>": 50483,
621
+ "<|2.38|>": 50484,
622
+ "<|2.40|>": 50485,
623
+ "<|2.42|>": 50486,
624
+ "<|2.44|>": 50487,
625
+ "<|2.46|>": 50488,
626
+ "<|2.48|>": 50489,
627
+ "<|2.50|>": 50490,
628
+ "<|2.52|>": 50491,
629
+ "<|2.54|>": 50492,
630
+ "<|2.56|>": 50493,
631
+ "<|2.58|>": 50494,
632
+ "<|2.60|>": 50495,
633
+ "<|2.62|>": 50496,
634
+ "<|2.64|>": 50497,
635
+ "<|2.66|>": 50498,
636
+ "<|2.68|>": 50499,
637
+ "<|2.70|>": 50500,
638
+ "<|2.72|>": 50501,
639
+ "<|2.74|>": 50502,
640
+ "<|2.76|>": 50503,
641
+ "<|2.78|>": 50504,
642
+ "<|2.80|>": 50505,
643
+ "<|2.82|>": 50506,
644
+ "<|2.84|>": 50507,
645
+ "<|2.86|>": 50508,
646
+ "<|2.88|>": 50509,
647
+ "<|2.90|>": 50510,
648
+ "<|2.92|>": 50511,
649
+ "<|2.94|>": 50512,
650
+ "<|2.96|>": 50513,
651
+ "<|2.98|>": 50514,
652
+ "<|20.00|>": 51365,
653
+ "<|20.02|>": 51366,
654
+ "<|20.04|>": 51367,
655
+ "<|20.06|>": 51368,
656
+ "<|20.08|>": 51369,
657
+ "<|20.10|>": 51370,
658
+ "<|20.12|>": 51371,
659
+ "<|20.14|>": 51372,
660
+ "<|20.16|>": 51373,
661
+ "<|20.18|>": 51374,
662
+ "<|20.20|>": 51375,
663
+ "<|20.22|>": 51376,
664
+ "<|20.24|>": 51377,
665
+ "<|20.26|>": 51378,
666
+ "<|20.28|>": 51379,
667
+ "<|20.30|>": 51380,
668
+ "<|20.32|>": 51381,
669
+ "<|20.34|>": 51382,
670
+ "<|20.36|>": 51383,
671
+ "<|20.38|>": 51384,
672
+ "<|20.40|>": 51385,
673
+ "<|20.42|>": 51386,
674
+ "<|20.44|>": 51387,
675
+ "<|20.46|>": 51388,
676
+ "<|20.48|>": 51389,
677
+ "<|20.50|>": 51390,
678
+ "<|20.52|>": 51391,
679
+ "<|20.54|>": 51392,
680
+ "<|20.56|>": 51393,
681
+ "<|20.58|>": 51394,
682
+ "<|20.60|>": 51395,
683
+ "<|20.62|>": 51396,
684
+ "<|20.64|>": 51397,
685
+ "<|20.66|>": 51398,
686
+ "<|20.68|>": 51399,
687
+ "<|20.70|>": 51400,
688
+ "<|20.72|>": 51401,
689
+ "<|20.74|>": 51402,
690
+ "<|20.76|>": 51403,
691
+ "<|20.78|>": 51404,
692
+ "<|20.80|>": 51405,
693
+ "<|20.82|>": 51406,
694
+ "<|20.84|>": 51407,
695
+ "<|20.86|>": 51408,
696
+ "<|20.88|>": 51409,
697
+ "<|20.90|>": 51410,
698
+ "<|20.92|>": 51411,
699
+ "<|20.94|>": 51412,
700
+ "<|20.96|>": 51413,
701
+ "<|20.98|>": 51414,
702
+ "<|21.00|>": 51415,
703
+ "<|21.02|>": 51416,
704
+ "<|21.04|>": 51417,
705
+ "<|21.06|>": 51418,
706
+ "<|21.08|>": 51419,
707
+ "<|21.10|>": 51420,
708
+ "<|21.12|>": 51421,
709
+ "<|21.14|>": 51422,
710
+ "<|21.16|>": 51423,
711
+ "<|21.18|>": 51424,
712
+ "<|21.20|>": 51425,
713
+ "<|21.22|>": 51426,
714
+ "<|21.24|>": 51427,
715
+ "<|21.26|>": 51428,
716
+ "<|21.28|>": 51429,
717
+ "<|21.30|>": 51430,
718
+ "<|21.32|>": 51431,
719
+ "<|21.34|>": 51432,
720
+ "<|21.36|>": 51433,
721
+ "<|21.38|>": 51434,
722
+ "<|21.40|>": 51435,
723
+ "<|21.42|>": 51436,
724
+ "<|21.44|>": 51437,
725
+ "<|21.46|>": 51438,
726
+ "<|21.48|>": 51439,
727
+ "<|21.50|>": 51440,
728
+ "<|21.52|>": 51441,
729
+ "<|21.54|>": 51442,
730
+ "<|21.56|>": 51443,
731
+ "<|21.58|>": 51444,
732
+ "<|21.60|>": 51445,
733
+ "<|21.62|>": 51446,
734
+ "<|21.64|>": 51447,
735
+ "<|21.66|>": 51448,
736
+ "<|21.68|>": 51449,
737
+ "<|21.70|>": 51450,
738
+ "<|21.72|>": 51451,
739
+ "<|21.74|>": 51452,
740
+ "<|21.76|>": 51453,
741
+ "<|21.78|>": 51454,
742
+ "<|21.80|>": 51455,
743
+ "<|21.82|>": 51456,
744
+ "<|21.84|>": 51457,
745
+ "<|21.86|>": 51458,
746
+ "<|21.88|>": 51459,
747
+ "<|21.90|>": 51460,
748
+ "<|21.92|>": 51461,
749
+ "<|21.94|>": 51462,
750
+ "<|21.96|>": 51463,
751
+ "<|21.98|>": 51464,
752
+ "<|22.00|>": 51465,
753
+ "<|22.02|>": 51466,
754
+ "<|22.04|>": 51467,
755
+ "<|22.06|>": 51468,
756
+ "<|22.08|>": 51469,
757
+ "<|22.10|>": 51470,
758
+ "<|22.12|>": 51471,
759
+ "<|22.14|>": 51472,
760
+ "<|22.16|>": 51473,
761
+ "<|22.18|>": 51474,
762
+ "<|22.20|>": 51475,
763
+ "<|22.22|>": 51476,
764
+ "<|22.24|>": 51477,
765
+ "<|22.26|>": 51478,
766
+ "<|22.28|>": 51479,
767
+ "<|22.30|>": 51480,
768
+ "<|22.32|>": 51481,
769
+ "<|22.34|>": 51482,
770
+ "<|22.36|>": 51483,
771
+ "<|22.38|>": 51484,
772
+ "<|22.40|>": 51485,
773
+ "<|22.42|>": 51486,
774
+ "<|22.44|>": 51487,
775
+ "<|22.46|>": 51488,
776
+ "<|22.48|>": 51489,
777
+ "<|22.50|>": 51490,
778
+ "<|22.52|>": 51491,
779
+ "<|22.54|>": 51492,
780
+ "<|22.56|>": 51493,
781
+ "<|22.58|>": 51494,
782
+ "<|22.60|>": 51495,
783
+ "<|22.62|>": 51496,
784
+ "<|22.64|>": 51497,
785
+ "<|22.66|>": 51498,
786
+ "<|22.68|>": 51499,
787
+ "<|22.70|>": 51500,
788
+ "<|22.72|>": 51501,
789
+ "<|22.74|>": 51502,
790
+ "<|22.76|>": 51503,
791
+ "<|22.78|>": 51504,
792
+ "<|22.80|>": 51505,
793
+ "<|22.82|>": 51506,
794
+ "<|22.84|>": 51507,
795
+ "<|22.86|>": 51508,
796
+ "<|22.88|>": 51509,
797
+ "<|22.90|>": 51510,
798
+ "<|22.92|>": 51511,
799
+ "<|22.94|>": 51512,
800
+ "<|22.96|>": 51513,
801
+ "<|22.98|>": 51514,
802
+ "<|23.00|>": 51515,
803
+ "<|23.02|>": 51516,
804
+ "<|23.04|>": 51517,
805
+ "<|23.06|>": 51518,
806
+ "<|23.08|>": 51519,
807
+ "<|23.10|>": 51520,
808
+ "<|23.12|>": 51521,
809
+ "<|23.14|>": 51522,
810
+ "<|23.16|>": 51523,
811
+ "<|23.18|>": 51524,
812
+ "<|23.20|>": 51525,
813
+ "<|23.22|>": 51526,
814
+ "<|23.24|>": 51527,
815
+ "<|23.26|>": 51528,
816
+ "<|23.28|>": 51529,
817
+ "<|23.30|>": 51530,
818
+ "<|23.32|>": 51531,
819
+ "<|23.34|>": 51532,
820
+ "<|23.36|>": 51533,
821
+ "<|23.38|>": 51534,
822
+ "<|23.40|>": 51535,
823
+ "<|23.42|>": 51536,
824
+ "<|23.44|>": 51537,
825
+ "<|23.46|>": 51538,
826
+ "<|23.48|>": 51539,
827
+ "<|23.50|>": 51540,
828
+ "<|23.52|>": 51541,
829
+ "<|23.54|>": 51542,
830
+ "<|23.56|>": 51543,
831
+ "<|23.58|>": 51544,
832
+ "<|23.60|>": 51545,
833
+ "<|23.62|>": 51546,
834
+ "<|23.64|>": 51547,
835
+ "<|23.66|>": 51548,
836
+ "<|23.68|>": 51549,
837
+ "<|23.70|>": 51550,
838
+ "<|23.72|>": 51551,
839
+ "<|23.74|>": 51552,
840
+ "<|23.76|>": 51553,
841
+ "<|23.78|>": 51554,
842
+ "<|23.80|>": 51555,
843
+ "<|23.82|>": 51556,
844
+ "<|23.84|>": 51557,
845
+ "<|23.86|>": 51558,
846
+ "<|23.88|>": 51559,
847
+ "<|23.90|>": 51560,
848
+ "<|23.92|>": 51561,
849
+ "<|23.94|>": 51562,
850
+ "<|23.96|>": 51563,
851
+ "<|23.98|>": 51564,
852
+ "<|24.00|>": 51565,
853
+ "<|24.02|>": 51566,
854
+ "<|24.04|>": 51567,
855
+ "<|24.06|>": 51568,
856
+ "<|24.08|>": 51569,
857
+ "<|24.10|>": 51570,
858
+ "<|24.12|>": 51571,
859
+ "<|24.14|>": 51572,
860
+ "<|24.16|>": 51573,
861
+ "<|24.18|>": 51574,
862
+ "<|24.20|>": 51575,
863
+ "<|24.22|>": 51576,
864
+ "<|24.24|>": 51577,
865
+ "<|24.26|>": 51578,
866
+ "<|24.28|>": 51579,
867
+ "<|24.30|>": 51580,
868
+ "<|24.32|>": 51581,
869
+ "<|24.34|>": 51582,
870
+ "<|24.36|>": 51583,
871
+ "<|24.38|>": 51584,
872
+ "<|24.40|>": 51585,
873
+ "<|24.42|>": 51586,
874
+ "<|24.44|>": 51587,
875
+ "<|24.46|>": 51588,
876
+ "<|24.48|>": 51589,
877
+ "<|24.50|>": 51590,
878
+ "<|24.52|>": 51591,
879
+ "<|24.54|>": 51592,
880
+ "<|24.56|>": 51593,
881
+ "<|24.58|>": 51594,
882
+ "<|24.60|>": 51595,
883
+ "<|24.62|>": 51596,
884
+ "<|24.64|>": 51597,
885
+ "<|24.66|>": 51598,
886
+ "<|24.68|>": 51599,
887
+ "<|24.70|>": 51600,
888
+ "<|24.72|>": 51601,
889
+ "<|24.74|>": 51602,
890
+ "<|24.76|>": 51603,
891
+ "<|24.78|>": 51604,
892
+ "<|24.80|>": 51605,
893
+ "<|24.82|>": 51606,
894
+ "<|24.84|>": 51607,
895
+ "<|24.86|>": 51608,
896
+ "<|24.88|>": 51609,
897
+ "<|24.90|>": 51610,
898
+ "<|24.92|>": 51611,
899
+ "<|24.94|>": 51612,
900
+ "<|24.96|>": 51613,
901
+ "<|24.98|>": 51614,
902
+ "<|25.00|>": 51615,
903
+ "<|25.02|>": 51616,
904
+ "<|25.04|>": 51617,
905
+ "<|25.06|>": 51618,
906
+ "<|25.08|>": 51619,
907
+ "<|25.10|>": 51620,
908
+ "<|25.12|>": 51621,
909
+ "<|25.14|>": 51622,
910
+ "<|25.16|>": 51623,
911
+ "<|25.18|>": 51624,
912
+ "<|25.20|>": 51625,
913
+ "<|25.22|>": 51626,
914
+ "<|25.24|>": 51627,
915
+ "<|25.26|>": 51628,
916
+ "<|25.28|>": 51629,
917
+ "<|25.30|>": 51630,
918
+ "<|25.32|>": 51631,
919
+ "<|25.34|>": 51632,
920
+ "<|25.36|>": 51633,
921
+ "<|25.38|>": 51634,
922
+ "<|25.40|>": 51635,
923
+ "<|25.42|>": 51636,
924
+ "<|25.44|>": 51637,
925
+ "<|25.46|>": 51638,
926
+ "<|25.48|>": 51639,
927
+ "<|25.50|>": 51640,
928
+ "<|25.52|>": 51641,
929
+ "<|25.54|>": 51642,
930
+ "<|25.56|>": 51643,
931
+ "<|25.58|>": 51644,
932
+ "<|25.60|>": 51645,
933
+ "<|25.62|>": 51646,
934
+ "<|25.64|>": 51647,
935
+ "<|25.66|>": 51648,
936
+ "<|25.68|>": 51649,
937
+ "<|25.70|>": 51650,
938
+ "<|25.72|>": 51651,
939
+ "<|25.74|>": 51652,
940
+ "<|25.76|>": 51653,
941
+ "<|25.78|>": 51654,
942
+ "<|25.80|>": 51655,
943
+ "<|25.82|>": 51656,
944
+ "<|25.84|>": 51657,
945
+ "<|25.86|>": 51658,
946
+ "<|25.88|>": 51659,
947
+ "<|25.90|>": 51660,
948
+ "<|25.92|>": 51661,
949
+ "<|25.94|>": 51662,
950
+ "<|25.96|>": 51663,
951
+ "<|25.98|>": 51664,
952
+ "<|26.00|>": 51665,
953
+ "<|26.02|>": 51666,
954
+ "<|26.04|>": 51667,
955
+ "<|26.06|>": 51668,
956
+ "<|26.08|>": 51669,
957
+ "<|26.10|>": 51670,
958
+ "<|26.12|>": 51671,
959
+ "<|26.14|>": 51672,
960
+ "<|26.16|>": 51673,
961
+ "<|26.18|>": 51674,
962
+ "<|26.20|>": 51675,
963
+ "<|26.22|>": 51676,
964
+ "<|26.24|>": 51677,
965
+ "<|26.26|>": 51678,
966
+ "<|26.28|>": 51679,
967
+ "<|26.30|>": 51680,
968
+ "<|26.32|>": 51681,
969
+ "<|26.34|>": 51682,
970
+ "<|26.36|>": 51683,
971
+ "<|26.38|>": 51684,
972
+ "<|26.40|>": 51685,
973
+ "<|26.42|>": 51686,
974
+ "<|26.44|>": 51687,
975
+ "<|26.46|>": 51688,
976
+ "<|26.48|>": 51689,
977
+ "<|26.50|>": 51690,
978
+ "<|26.52|>": 51691,
979
+ "<|26.54|>": 51692,
980
+ "<|26.56|>": 51693,
981
+ "<|26.58|>": 51694,
982
+ "<|26.60|>": 51695,
983
+ "<|26.62|>": 51696,
984
+ "<|26.64|>": 51697,
985
+ "<|26.66|>": 51698,
986
+ "<|26.68|>": 51699,
987
+ "<|26.70|>": 51700,
988
+ "<|26.72|>": 51701,
989
+ "<|26.74|>": 51702,
990
+ "<|26.76|>": 51703,
991
+ "<|26.78|>": 51704,
992
+ "<|26.80|>": 51705,
993
+ "<|26.82|>": 51706,
994
+ "<|26.84|>": 51707,
995
+ "<|26.86|>": 51708,
996
+ "<|26.88|>": 51709,
997
+ "<|26.90|>": 51710,
998
+ "<|26.92|>": 51711,
999
+ "<|26.94|>": 51712,
1000
+ "<|26.96|>": 51713,
1001
+ "<|26.98|>": 51714,
1002
+ "<|27.00|>": 51715,
1003
+ "<|27.02|>": 51716,
1004
+ "<|27.04|>": 51717,
1005
+ "<|27.06|>": 51718,
1006
+ "<|27.08|>": 51719,
1007
+ "<|27.10|>": 51720,
1008
+ "<|27.12|>": 51721,
1009
+ "<|27.14|>": 51722,
1010
+ "<|27.16|>": 51723,
1011
+ "<|27.18|>": 51724,
1012
+ "<|27.20|>": 51725,
1013
+ "<|27.22|>": 51726,
1014
+ "<|27.24|>": 51727,
1015
+ "<|27.26|>": 51728,
1016
+ "<|27.28|>": 51729,
1017
+ "<|27.30|>": 51730,
1018
+ "<|27.32|>": 51731,
1019
+ "<|27.34|>": 51732,
1020
+ "<|27.36|>": 51733,
1021
+ "<|27.38|>": 51734,
1022
+ "<|27.40|>": 51735,
1023
+ "<|27.42|>": 51736,
1024
+ "<|27.44|>": 51737,
1025
+ "<|27.46|>": 51738,
1026
+ "<|27.48|>": 51739,
1027
+ "<|27.50|>": 51740,
1028
+ "<|27.52|>": 51741,
1029
+ "<|27.54|>": 51742,
1030
+ "<|27.56|>": 51743,
1031
+ "<|27.58|>": 51744,
1032
+ "<|27.60|>": 51745,
1033
+ "<|27.62|>": 51746,
1034
+ "<|27.64|>": 51747,
1035
+ "<|27.66|>": 51748,
1036
+ "<|27.68|>": 51749,
1037
+ "<|27.70|>": 51750,
1038
+ "<|27.72|>": 51751,
1039
+ "<|27.74|>": 51752,
1040
+ "<|27.76|>": 51753,
1041
+ "<|27.78|>": 51754,
1042
+ "<|27.80|>": 51755,
1043
+ "<|27.82|>": 51756,
1044
+ "<|27.84|>": 51757,
1045
+ "<|27.86|>": 51758,
1046
+ "<|27.88|>": 51759,
1047
+ "<|27.90|>": 51760,
1048
+ "<|27.92|>": 51761,
1049
+ "<|27.94|>": 51762,
1050
+ "<|27.96|>": 51763,
1051
+ "<|27.98|>": 51764,
1052
+ "<|28.00|>": 51765,
1053
+ "<|28.02|>": 51766,
1054
+ "<|28.04|>": 51767,
1055
+ "<|28.06|>": 51768,
1056
+ "<|28.08|>": 51769,
1057
+ "<|28.10|>": 51770,
1058
+ "<|28.12|>": 51771,
1059
+ "<|28.14|>": 51772,
1060
+ "<|28.16|>": 51773,
1061
+ "<|28.18|>": 51774,
1062
+ "<|28.20|>": 51775,
1063
+ "<|28.22|>": 51776,
1064
+ "<|28.24|>": 51777,
1065
+ "<|28.26|>": 51778,
1066
+ "<|28.28|>": 51779,
1067
+ "<|28.30|>": 51780,
1068
+ "<|28.32|>": 51781,
1069
+ "<|28.34|>": 51782,
1070
+ "<|28.36|>": 51783,
1071
+ "<|28.38|>": 51784,
1072
+ "<|28.40|>": 51785,
1073
+ "<|28.42|>": 51786,
1074
+ "<|28.44|>": 51787,
1075
+ "<|28.46|>": 51788,
1076
+ "<|28.48|>": 51789,
1077
+ "<|28.50|>": 51790,
1078
+ "<|28.52|>": 51791,
1079
+ "<|28.54|>": 51792,
1080
+ "<|28.56|>": 51793,
1081
+ "<|28.58|>": 51794,
1082
+ "<|28.60|>": 51795,
1083
+ "<|28.62|>": 51796,
1084
+ "<|28.64|>": 51797,
1085
+ "<|28.66|>": 51798,
1086
+ "<|28.68|>": 51799,
1087
+ "<|28.70|>": 51800,
1088
+ "<|28.72|>": 51801,
1089
+ "<|28.74|>": 51802,
1090
+ "<|28.76|>": 51803,
1091
+ "<|28.78|>": 51804,
1092
+ "<|28.80|>": 51805,
1093
+ "<|28.82|>": 51806,
1094
+ "<|28.84|>": 51807,
1095
+ "<|28.86|>": 51808,
1096
+ "<|28.88|>": 51809,
1097
+ "<|28.90|>": 51810,
1098
+ "<|28.92|>": 51811,
1099
+ "<|28.94|>": 51812,
1100
+ "<|28.96|>": 51813,
1101
+ "<|28.98|>": 51814,
1102
+ "<|29.00|>": 51815,
1103
+ "<|29.02|>": 51816,
1104
+ "<|29.04|>": 51817,
1105
+ "<|29.06|>": 51818,
1106
+ "<|29.08|>": 51819,
1107
+ "<|29.10|>": 51820,
1108
+ "<|29.12|>": 51821,
1109
+ "<|29.14|>": 51822,
1110
+ "<|29.16|>": 51823,
1111
+ "<|29.18|>": 51824,
1112
+ "<|29.20|>": 51825,
1113
+ "<|29.22|>": 51826,
1114
+ "<|29.24|>": 51827,
1115
+ "<|29.26|>": 51828,
1116
+ "<|29.28|>": 51829,
1117
+ "<|29.30|>": 51830,
1118
+ "<|29.32|>": 51831,
1119
+ "<|29.34|>": 51832,
1120
+ "<|29.36|>": 51833,
1121
+ "<|29.38|>": 51834,
1122
+ "<|29.40|>": 51835,
1123
+ "<|29.42|>": 51836,
1124
+ "<|29.44|>": 51837,
1125
+ "<|29.46|>": 51838,
1126
+ "<|29.48|>": 51839,
1127
+ "<|29.50|>": 51840,
1128
+ "<|29.52|>": 51841,
1129
+ "<|29.54|>": 51842,
1130
+ "<|29.56|>": 51843,
1131
+ "<|29.58|>": 51844,
1132
+ "<|29.60|>": 51845,
1133
+ "<|29.62|>": 51846,
1134
+ "<|29.64|>": 51847,
1135
+ "<|29.66|>": 51848,
1136
+ "<|29.68|>": 51849,
1137
+ "<|29.70|>": 51850,
1138
+ "<|29.72|>": 51851,
1139
+ "<|29.74|>": 51852,
1140
+ "<|29.76|>": 51853,
1141
+ "<|29.78|>": 51854,
1142
+ "<|29.80|>": 51855,
1143
+ "<|29.82|>": 51856,
1144
+ "<|29.84|>": 51857,
1145
+ "<|29.86|>": 51858,
1146
+ "<|29.88|>": 51859,
1147
+ "<|29.90|>": 51860,
1148
+ "<|29.92|>": 51861,
1149
+ "<|29.94|>": 51862,
1150
+ "<|29.96|>": 51863,
1151
+ "<|29.98|>": 51864,
1152
+ "<|3.00|>": 50515,
1153
+ "<|3.02|>": 50516,
1154
+ "<|3.04|>": 50517,
1155
+ "<|3.06|>": 50518,
1156
+ "<|3.08|>": 50519,
1157
+ "<|3.10|>": 50520,
1158
+ "<|3.12|>": 50521,
1159
+ "<|3.14|>": 50522,
1160
+ "<|3.16|>": 50523,
1161
+ "<|3.18|>": 50524,
1162
+ "<|3.20|>": 50525,
1163
+ "<|3.22|>": 50526,
1164
+ "<|3.24|>": 50527,
1165
+ "<|3.26|>": 50528,
1166
+ "<|3.28|>": 50529,
1167
+ "<|3.30|>": 50530,
1168
+ "<|3.32|>": 50531,
1169
+ "<|3.34|>": 50532,
1170
+ "<|3.36|>": 50533,
1171
+ "<|3.38|>": 50534,
1172
+ "<|3.40|>": 50535,
1173
+ "<|3.42|>": 50536,
1174
+ "<|3.44|>": 50537,
1175
+ "<|3.46|>": 50538,
1176
+ "<|3.48|>": 50539,
1177
+ "<|3.50|>": 50540,
1178
+ "<|3.52|>": 50541,
1179
+ "<|3.54|>": 50542,
1180
+ "<|3.56|>": 50543,
1181
+ "<|3.58|>": 50544,
1182
+ "<|3.60|>": 50545,
1183
+ "<|3.62|>": 50546,
1184
+ "<|3.64|>": 50547,
1185
+ "<|3.66|>": 50548,
1186
+ "<|3.68|>": 50549,
1187
+ "<|3.70|>": 50550,
1188
+ "<|3.72|>": 50551,
1189
+ "<|3.74|>": 50552,
1190
+ "<|3.76|>": 50553,
1191
+ "<|3.78|>": 50554,
1192
+ "<|3.80|>": 50555,
1193
+ "<|3.82|>": 50556,
1194
+ "<|3.84|>": 50557,
1195
+ "<|3.86|>": 50558,
1196
+ "<|3.88|>": 50559,
1197
+ "<|3.90|>": 50560,
1198
+ "<|3.92|>": 50561,
1199
+ "<|3.94|>": 50562,
1200
+ "<|3.96|>": 50563,
1201
+ "<|3.98|>": 50564,
1202
+ "<|30.00|>": 51865,
1203
+ "<|4.00|>": 50565,
1204
+ "<|4.02|>": 50566,
1205
+ "<|4.04|>": 50567,
1206
+ "<|4.06|>": 50568,
1207
+ "<|4.08|>": 50569,
1208
+ "<|4.10|>": 50570,
1209
+ "<|4.12|>": 50571,
1210
+ "<|4.14|>": 50572,
1211
+ "<|4.16|>": 50573,
1212
+ "<|4.18|>": 50574,
1213
+ "<|4.20|>": 50575,
1214
+ "<|4.22|>": 50576,
1215
+ "<|4.24|>": 50577,
1216
+ "<|4.26|>": 50578,
1217
+ "<|4.28|>": 50579,
1218
+ "<|4.30|>": 50580,
1219
+ "<|4.32|>": 50581,
1220
+ "<|4.34|>": 50582,
1221
+ "<|4.36|>": 50583,
1222
+ "<|4.38|>": 50584,
1223
+ "<|4.40|>": 50585,
1224
+ "<|4.42|>": 50586,
1225
+ "<|4.44|>": 50587,
1226
+ "<|4.46|>": 50588,
1227
+ "<|4.48|>": 50589,
1228
+ "<|4.50|>": 50590,
1229
+ "<|4.52|>": 50591,
1230
+ "<|4.54|>": 50592,
1231
+ "<|4.56|>": 50593,
1232
+ "<|4.58|>": 50594,
1233
+ "<|4.60|>": 50595,
1234
+ "<|4.62|>": 50596,
1235
+ "<|4.64|>": 50597,
1236
+ "<|4.66|>": 50598,
1237
+ "<|4.68|>": 50599,
1238
+ "<|4.70|>": 50600,
1239
+ "<|4.72|>": 50601,
1240
+ "<|4.74|>": 50602,
1241
+ "<|4.76|>": 50603,
1242
+ "<|4.78|>": 50604,
1243
+ "<|4.80|>": 50605,
1244
+ "<|4.82|>": 50606,
1245
+ "<|4.84|>": 50607,
1246
+ "<|4.86|>": 50608,
1247
+ "<|4.88|>": 50609,
1248
+ "<|4.90|>": 50610,
1249
+ "<|4.92|>": 50611,
1250
+ "<|4.94|>": 50612,
1251
+ "<|4.96|>": 50613,
1252
+ "<|4.98|>": 50614,
1253
+ "<|5.00|>": 50615,
1254
+ "<|5.02|>": 50616,
1255
+ "<|5.04|>": 50617,
1256
+ "<|5.06|>": 50618,
1257
+ "<|5.08|>": 50619,
1258
+ "<|5.10|>": 50620,
1259
+ "<|5.12|>": 50621,
1260
+ "<|5.14|>": 50622,
1261
+ "<|5.16|>": 50623,
1262
+ "<|5.18|>": 50624,
1263
+ "<|5.20|>": 50625,
1264
+ "<|5.22|>": 50626,
1265
+ "<|5.24|>": 50627,
1266
+ "<|5.26|>": 50628,
1267
+ "<|5.28|>": 50629,
1268
+ "<|5.30|>": 50630,
1269
+ "<|5.32|>": 50631,
1270
+ "<|5.34|>": 50632,
1271
+ "<|5.36|>": 50633,
1272
+ "<|5.38|>": 50634,
1273
+ "<|5.40|>": 50635,
1274
+ "<|5.42|>": 50636,
1275
+ "<|5.44|>": 50637,
1276
+ "<|5.46|>": 50638,
1277
+ "<|5.48|>": 50639,
1278
+ "<|5.50|>": 50640,
1279
+ "<|5.52|>": 50641,
1280
+ "<|5.54|>": 50642,
1281
+ "<|5.56|>": 50643,
1282
+ "<|5.58|>": 50644,
1283
+ "<|5.60|>": 50645,
1284
+ "<|5.62|>": 50646,
1285
+ "<|5.64|>": 50647,
1286
+ "<|5.66|>": 50648,
1287
+ "<|5.68|>": 50649,
1288
+ "<|5.70|>": 50650,
1289
+ "<|5.72|>": 50651,
1290
+ "<|5.74|>": 50652,
1291
+ "<|5.76|>": 50653,
1292
+ "<|5.78|>": 50654,
1293
+ "<|5.80|>": 50655,
1294
+ "<|5.82|>": 50656,
1295
+ "<|5.84|>": 50657,
1296
+ "<|5.86|>": 50658,
1297
+ "<|5.88|>": 50659,
1298
+ "<|5.90|>": 50660,
1299
+ "<|5.92|>": 50661,
1300
+ "<|5.94|>": 50662,
1301
+ "<|5.96|>": 50663,
1302
+ "<|5.98|>": 50664,
1303
+ "<|6.00|>": 50665,
1304
+ "<|6.02|>": 50666,
1305
+ "<|6.04|>": 50667,
1306
+ "<|6.06|>": 50668,
1307
+ "<|6.08|>": 50669,
1308
+ "<|6.10|>": 50670,
1309
+ "<|6.12|>": 50671,
1310
+ "<|6.14|>": 50672,
1311
+ "<|6.16|>": 50673,
1312
+ "<|6.18|>": 50674,
1313
+ "<|6.20|>": 50675,
1314
+ "<|6.22|>": 50676,
1315
+ "<|6.24|>": 50677,
1316
+ "<|6.26|>": 50678,
1317
+ "<|6.28|>": 50679,
1318
+ "<|6.30|>": 50680,
1319
+ "<|6.32|>": 50681,
1320
+ "<|6.34|>": 50682,
1321
+ "<|6.36|>": 50683,
1322
+ "<|6.38|>": 50684,
1323
+ "<|6.40|>": 50685,
1324
+ "<|6.42|>": 50686,
1325
+ "<|6.44|>": 50687,
1326
+ "<|6.46|>": 50688,
1327
+ "<|6.48|>": 50689,
1328
+ "<|6.50|>": 50690,
1329
+ "<|6.52|>": 50691,
1330
+ "<|6.54|>": 50692,
1331
+ "<|6.56|>": 50693,
1332
+ "<|6.58|>": 50694,
1333
+ "<|6.60|>": 50695,
1334
+ "<|6.62|>": 50696,
1335
+ "<|6.64|>": 50697,
1336
+ "<|6.66|>": 50698,
1337
+ "<|6.68|>": 50699,
1338
+ "<|6.70|>": 50700,
1339
+ "<|6.72|>": 50701,
1340
+ "<|6.74|>": 50702,
1341
+ "<|6.76|>": 50703,
1342
+ "<|6.78|>": 50704,
1343
+ "<|6.80|>": 50705,
1344
+ "<|6.82|>": 50706,
1345
+ "<|6.84|>": 50707,
1346
+ "<|6.86|>": 50708,
1347
+ "<|6.88|>": 50709,
1348
+ "<|6.90|>": 50710,
1349
+ "<|6.92|>": 50711,
1350
+ "<|6.94|>": 50712,
1351
+ "<|6.96|>": 50713,
1352
+ "<|6.98|>": 50714,
1353
+ "<|7.00|>": 50715,
1354
+ "<|7.02|>": 50716,
1355
+ "<|7.04|>": 50717,
1356
+ "<|7.06|>": 50718,
1357
+ "<|7.08|>": 50719,
1358
+ "<|7.10|>": 50720,
1359
+ "<|7.12|>": 50721,
1360
+ "<|7.14|>": 50722,
1361
+ "<|7.16|>": 50723,
1362
+ "<|7.18|>": 50724,
1363
+ "<|7.20|>": 50725,
1364
+ "<|7.22|>": 50726,
1365
+ "<|7.24|>": 50727,
1366
+ "<|7.26|>": 50728,
1367
+ "<|7.28|>": 50729,
1368
+ "<|7.30|>": 50730,
1369
+ "<|7.32|>": 50731,
1370
+ "<|7.34|>": 50732,
1371
+ "<|7.36|>": 50733,
1372
+ "<|7.38|>": 50734,
1373
+ "<|7.40|>": 50735,
1374
+ "<|7.42|>": 50736,
1375
+ "<|7.44|>": 50737,
1376
+ "<|7.46|>": 50738,
1377
+ "<|7.48|>": 50739,
1378
+ "<|7.50|>": 50740,
1379
+ "<|7.52|>": 50741,
1380
+ "<|7.54|>": 50742,
1381
+ "<|7.56|>": 50743,
1382
+ "<|7.58|>": 50744,
1383
+ "<|7.60|>": 50745,
1384
+ "<|7.62|>": 50746,
1385
+ "<|7.64|>": 50747,
1386
+ "<|7.66|>": 50748,
1387
+ "<|7.68|>": 50749,
1388
+ "<|7.70|>": 50750,
1389
+ "<|7.72|>": 50751,
1390
+ "<|7.74|>": 50752,
1391
+ "<|7.76|>": 50753,
1392
+ "<|7.78|>": 50754,
1393
+ "<|7.80|>": 50755,
1394
+ "<|7.82|>": 50756,
1395
+ "<|7.84|>": 50757,
1396
+ "<|7.86|>": 50758,
1397
+ "<|7.88|>": 50759,
1398
+ "<|7.90|>": 50760,
1399
+ "<|7.92|>": 50761,
1400
+ "<|7.94|>": 50762,
1401
+ "<|7.96|>": 50763,
1402
+ "<|7.98|>": 50764,
1403
+ "<|8.00|>": 50765,
1404
+ "<|8.02|>": 50766,
1405
+ "<|8.04|>": 50767,
1406
+ "<|8.06|>": 50768,
1407
+ "<|8.08|>": 50769,
1408
+ "<|8.10|>": 50770,
1409
+ "<|8.12|>": 50771,
1410
+ "<|8.14|>": 50772,
1411
+ "<|8.16|>": 50773,
1412
+ "<|8.18|>": 50774,
1413
+ "<|8.20|>": 50775,
1414
+ "<|8.22|>": 50776,
1415
+ "<|8.24|>": 50777,
1416
+ "<|8.26|>": 50778,
1417
+ "<|8.28|>": 50779,
1418
+ "<|8.30|>": 50780,
1419
+ "<|8.32|>": 50781,
1420
+ "<|8.34|>": 50782,
1421
+ "<|8.36|>": 50783,
1422
+ "<|8.38|>": 50784,
1423
+ "<|8.40|>": 50785,
1424
+ "<|8.42|>": 50786,
1425
+ "<|8.44|>": 50787,
1426
+ "<|8.46|>": 50788,
1427
+ "<|8.48|>": 50789,
1428
+ "<|8.50|>": 50790,
1429
+ "<|8.52|>": 50791,
1430
+ "<|8.54|>": 50792,
1431
+ "<|8.56|>": 50793,
1432
+ "<|8.58|>": 50794,
1433
+ "<|8.60|>": 50795,
1434
+ "<|8.62|>": 50796,
1435
+ "<|8.64|>": 50797,
1436
+ "<|8.66|>": 50798,
1437
+ "<|8.68|>": 50799,
1438
+ "<|8.70|>": 50800,
1439
+ "<|8.72|>": 50801,
1440
+ "<|8.74|>": 50802,
1441
+ "<|8.76|>": 50803,
1442
+ "<|8.78|>": 50804,
1443
+ "<|8.80|>": 50805,
1444
+ "<|8.82|>": 50806,
1445
+ "<|8.84|>": 50807,
1446
+ "<|8.86|>": 50808,
1447
+ "<|8.88|>": 50809,
1448
+ "<|8.90|>": 50810,
1449
+ "<|8.92|>": 50811,
1450
+ "<|8.94|>": 50812,
1451
+ "<|8.96|>": 50813,
1452
+ "<|8.98|>": 50814,
1453
+ "<|9.00|>": 50815,
1454
+ "<|9.02|>": 50816,
1455
+ "<|9.04|>": 50817,
1456
+ "<|9.06|>": 50818,
1457
+ "<|9.08|>": 50819,
1458
+ "<|9.10|>": 50820,
1459
+ "<|9.12|>": 50821,
1460
+ "<|9.14|>": 50822,
1461
+ "<|9.16|>": 50823,
1462
+ "<|9.18|>": 50824,
1463
+ "<|9.20|>": 50825,
1464
+ "<|9.22|>": 50826,
1465
+ "<|9.24|>": 50827,
1466
+ "<|9.26|>": 50828,
1467
+ "<|9.28|>": 50829,
1468
+ "<|9.30|>": 50830,
1469
+ "<|9.32|>": 50831,
1470
+ "<|9.34|>": 50832,
1471
+ "<|9.36|>": 50833,
1472
+ "<|9.38|>": 50834,
1473
+ "<|9.40|>": 50835,
1474
+ "<|9.42|>": 50836,
1475
+ "<|9.44|>": 50837,
1476
+ "<|9.46|>": 50838,
1477
+ "<|9.48|>": 50839,
1478
+ "<|9.50|>": 50840,
1479
+ "<|9.52|>": 50841,
1480
+ "<|9.54|>": 50842,
1481
+ "<|9.56|>": 50843,
1482
+ "<|9.58|>": 50844,
1483
+ "<|9.60|>": 50845,
1484
+ "<|9.62|>": 50846,
1485
+ "<|9.64|>": 50847,
1486
+ "<|9.66|>": 50848,
1487
+ "<|9.68|>": 50849,
1488
+ "<|9.70|>": 50850,
1489
+ "<|9.72|>": 50851,
1490
+ "<|9.74|>": 50852,
1491
+ "<|9.76|>": 50853,
1492
+ "<|9.78|>": 50854,
1493
+ "<|9.80|>": 50855,
1494
+ "<|9.82|>": 50856,
1495
+ "<|9.84|>": 50857,
1496
+ "<|9.86|>": 50858,
1497
+ "<|9.88|>": 50859,
1498
+ "<|9.90|>": 50860,
1499
+ "<|9.92|>": 50861,
1500
+ "<|9.94|>": 50862,
1501
+ "<|9.96|>": 50863,
1502
+ "<|9.98|>": 50864,
1503
+ "<|af|>": 50327,
1504
+ "<|am|>": 50334,
1505
+ "<|ar|>": 50272,
1506
+ "<|as|>": 50350,
1507
+ "<|az|>": 50304,
1508
+ "<|ba|>": 50355,
1509
+ "<|be|>": 50330,
1510
+ "<|bg|>": 50292,
1511
+ "<|bn|>": 50302,
1512
+ "<|bo|>": 50347,
1513
+ "<|br|>": 50309,
1514
+ "<|bs|>": 50315,
1515
+ "<|ca|>": 50270,
1516
+ "<|cs|>": 50283,
1517
+ "<|cy|>": 50297,
1518
+ "<|da|>": 50285,
1519
+ "<|de|>": 50261,
1520
+ "<|el|>": 50281,
1521
+ "<|endoftext|>": 50257,
1522
+ "<|en|>": 50259,
1523
+ "<|es|>": 50262,
1524
+ "<|et|>": 50307,
1525
+ "<|eu|>": 50310,
1526
+ "<|fa|>": 50300,
1527
+ "<|fi|>": 50277,
1528
+ "<|fo|>": 50338,
1529
+ "<|fr|>": 50265,
1530
+ "<|gl|>": 50319,
1531
+ "<|gu|>": 50333,
1532
+ "<|haw|>": 50352,
1533
+ "<|ha|>": 50354,
1534
+ "<|he|>": 50279,
1535
+ "<|hi|>": 50276,
1536
+ "<|hr|>": 50291,
1537
+ "<|ht|>": 50339,
1538
+ "<|hu|>": 50286,
1539
+ "<|hy|>": 50312,
1540
+ "<|id|>": 50275,
1541
+ "<|is|>": 50311,
1542
+ "<|it|>": 50274,
1543
+ "<|ja|>": 50266,
1544
+ "<|jw|>": 50356,
1545
+ "<|ka|>": 50329,
1546
+ "<|kk|>": 50316,
1547
+ "<|km|>": 50323,
1548
+ "<|kn|>": 50306,
1549
+ "<|ko|>": 50264,
1550
+ "<|la|>": 50294,
1551
+ "<|lb|>": 50345,
1552
+ "<|ln|>": 50353,
1553
+ "<|lo|>": 50336,
1554
+ "<|lt|>": 50293,
1555
+ "<|lv|>": 50301,
1556
+ "<|mg|>": 50349,
1557
+ "<|mi|>": 50295,
1558
+ "<|mk|>": 50308,
1559
+ "<|ml|>": 50296,
1560
+ "<|mn|>": 50314,
1561
+ "<|mr|>": 50320,
1562
+ "<|ms|>": 50282,
1563
+ "<|mt|>": 50343,
1564
+ "<|my|>": 50346,
1565
+ "<|ne|>": 50313,
1566
+ "<|nl|>": 50271,
1567
+ "<|nn|>": 50342,
1568
+ "<|nospeech|>": 50363,
1569
+ "<|notimestamps|>": 50364,
1570
+ "<|no|>": 50288,
1571
+ "<|oc|>": 50328,
1572
+ "<|pa|>": 50321,
1573
+ "<|pl|>": 50269,
1574
+ "<|ps|>": 50340,
1575
+ "<|pt|>": 50267,
1576
+ "<|ro|>": 50284,
1577
+ "<|ru|>": 50263,
1578
+ "<|sa|>": 50344,
1579
+ "<|sd|>": 50332,
1580
+ "<|si|>": 50322,
1581
+ "<|sk|>": 50298,
1582
+ "<|sl|>": 50305,
1583
+ "<|sn|>": 50324,
1584
+ "<|so|>": 50326,
1585
+ "<|sq|>": 50317,
1586
+ "<|sr|>": 50303,
1587
+ "<|startoflm|>": 50361,
1588
+ "<|startofprev|>": 50362,
1589
+ "<|startoftranscript|>": 50258,
1590
+ "<|su|>": 50357,
1591
+ "<|sv|>": 50273,
1592
+ "<|sw|>": 50318,
1593
+ "<|ta|>": 50287,
1594
+ "<|te|>": 50299,
1595
+ "<|tg|>": 50331,
1596
+ "<|th|>": 50289,
1597
+ "<|tk|>": 50341,
1598
+ "<|tl|>": 50348,
1599
+ "<|transcribe|>": 50360,
1600
+ "<|translate|>": 50359,
1601
+ "<|tr|>": 50268,
1602
+ "<|tt|>": 50351,
1603
+ "<|uk|>": 50280,
1604
+ "<|ur|>": 50290,
1605
+ "<|uz|>": 50337,
1606
+ "<|vi|>": 50278,
1607
+ "<|yi|>": 50335,
1608
+ "<|yo|>": 50325,
1609
+ "<|yue|>": 50358,
1610
+ "<|zh|>": 50260
1611
+ }
checkpoint-50-epoch-0/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:663f44309c7b1ec405df5e5a462de1c283b3ca905e6bf171d632717871aedaca
3
+ size 3025686376
checkpoint-50-epoch-0/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b395c8a7e2bda655c415580106288d0387c227efd641bf4e11c1cd735fdb37a
3
+ size 4361070048
checkpoint-50-epoch-0/optimizer.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9a02fc70602c6dc05d04dd56ca90df52a0c919c689a1c76bd0cfbf453173e87
3
+ size 955539578
checkpoint-50-epoch-0/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd2a733977ad85c9c935ee727f71e29775400be043213b5438f84bcf87a179e8
3
+ size 14344
checkpoint-50-epoch-0/scheduler.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5607f6de446164d9d9adb8b91c44cec55b14aa391e24ba5637c08b834eedda2a
3
+ size 1064
config.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "_name_or_path": "openai/whisper-large-v3",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "apply_spec_augment": false,
6
+ "architectures": [
7
+ "WhisperForConditionalGeneration"
8
+ ],
9
+ "attention_dropout": 0.0,
10
+ "begin_suppress_tokens": [
11
+ 220,
12
+ 50257
13
+ ],
14
+ "bos_token_id": 50257,
15
+ "classifier_proj_size": 256,
16
+ "d_model": 1280,
17
+ "decoder_attention_heads": 20,
18
+ "decoder_ffn_dim": 5120,
19
+ "decoder_layerdrop": 0.0,
20
+ "decoder_layers": 2,
21
+ "decoder_start_token_id": 50258,
22
+ "dropout": 0.0,
23
+ "encoder_attention_heads": 20,
24
+ "encoder_ffn_dim": 5120,
25
+ "encoder_layerdrop": 0.0,
26
+ "encoder_layers": 32,
27
+ "eos_token_id": 50257,
28
+ "init_std": 0.02,
29
+ "is_encoder_decoder": true,
30
+ "mask_feature_length": 10,
31
+ "mask_feature_min_masks": 0,
32
+ "mask_feature_prob": 0.0,
33
+ "mask_time_length": 10,
34
+ "mask_time_min_masks": 2,
35
+ "mask_time_prob": 0.05,
36
+ "max_length": 448,
37
+ "max_source_positions": 1500,
38
+ "max_target_positions": 448,
39
+ "median_filter_width": 7,
40
+ "model_type": "whisper",
41
+ "num_hidden_layers": 32,
42
+ "num_mel_bins": 128,
43
+ "pad_token_id": 50256,
44
+ "scale_embedding": false,
45
+ "torch_dtype": "float32",
46
+ "transformers_version": "4.40.1",
47
+ "use_cache": true,
48
+ "use_weighted_layer_sum": false,
49
+ "vocab_size": 51866
50
+ }
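The student config above keeps the full 32-layer, 1280-dim encoder of whisper-large-v3 but reduces the decoder to 2 layers. A minimal sketch of instantiating a (randomly initialised) model from this config to inspect its size, assuming the file is available locally as config.json:

```python
from transformers import WhisperConfig, WhisperForConditionalGeneration

config = WhisperConfig.from_json_file("config.json")
print(config.encoder_layers, config.decoder_layers)  # 32 2

# Random init only; the actual student weights are copied over from the teacher separately.
model = WhisperForConditionalGeneration(config)
print(f"~{model.num_parameters() / 1e6:.0f}M parameters")
```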
create_student_model.py ADDED
@@ -0,0 +1,215 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Initialise a student Whisper model from a pre-trained teacher model for
+ teacher-student distillation.
+ """
+
+ import argparse
+ import copy
+ import logging
+
+ import numpy as np
+ import torch
+ from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
+     )
+     parser.add_argument(
+         "--teacher_checkpoint",
+         type=str,
+         required=True,
+         help="The HF Hub ID of the teacher checkpoint.",
+     )
+     parser.add_argument(
+         "--subfolder",
+         type=str,
+         default="",
+         help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
+         "can specify the folder name here.",
+     )
+     parser.add_argument(
+         "--encoder_layers",
+         type=int,
+         default=None,
+         help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
+     )
+     parser.add_argument(
+         "--decoder_layers",
+         type=int,
+         default=2,
+         help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
+     )
+     parser.add_argument(
+         "--save_dir",
+         type=str,
+         required=True,
+         help="Where to save the student weights and processor.",
+     )
+     parser.add_argument(
+         "--push_to_hub",
+         type=bool,
+         required=False,
+         default=False,
+         help="Whether to push the student weights and processor to the Hub.",
+     )
+     parser.add_argument(
+         "--cache_dir",
+         type=str,
+         default=None,
+         help="Where to store the pretrained models downloaded from huggingface.co",
+     )
+
+     args = parser.parse_args()
+     return args
+
+
+ def init_student_model_from_teacher(
+     teacher_checkpoint,
+     encoder_layers=None,
+     decoder_layers=2,
+     save_dir=None,
+     push_to_hub=None,
+     cache_dir=None,
+     subfolder="",
+ ):
+     teacher_model = WhisperForConditionalGeneration.from_pretrained(
+         teacher_checkpoint,
+         cache_dir=cache_dir,
+         subfolder=subfolder,
+         low_cpu_mem_usage=True,
+     )
+     processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
+     generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)
+     generation_config.forced_decoder_ids = None
+
+     teacher_config = teacher_model.config
+     teacher_encoder_layers = teacher_config.encoder_layers
+     teacher_decoder_layers = teacher_config.decoder_layers
+
+     student_config = copy.deepcopy(teacher_config)
+     student_config.update(
+         {
+             "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
+             "decoder_layers": decoder_layers,
+         }
+     )
+
+     encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
+     encoder_mapping[-1] = teacher_encoder_layers - 1
+
+     encoder_map = {}
+     for student_layer, teacher_layer in enumerate(encoder_mapping):
+         encoder_map[teacher_layer] = student_layer
+
+     decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
+     decoder_mapping[-1] = teacher_decoder_layers - 1
+
+     decoder_map = {}
+     for student_layer, teacher_layer in enumerate(decoder_mapping):
+         decoder_map[teacher_layer] = student_layer
+
+     # init the student params from the teacher model
+     student_model = WhisperForConditionalGeneration(student_config)
+     missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
+     if len(missing_keys) > 0:
+         raise RuntimeError(
+             "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
+             f"Missing key(s) in state_dict: {missing_keys}"
+         )
+     if decoder_layers == teacher_decoder_layers:
+         decoder_keys = [key for key in unexpected_keys if "model.decoder.layers" in key]
+         if len(decoder_keys) > 0:
+             raise RuntimeError(
+                 "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
+                 f"Unexpected key(s) in state_dict: {decoder_keys}"
+             )
+     if encoder_layers == teacher_encoder_layers:
+         encoder_keys = [key for key in unexpected_keys if "model.encoder.layers" in key]
+         if len(encoder_keys) > 0:
+             raise RuntimeError(
+                 "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
+                 f"Unexpected key(s) in state_dict: {encoder_keys}"
+             )
+
+     for layer in range(teacher_decoder_layers):
+         if layer in decoder_map:
+             # re-introduce pre-defined layers from the teacher
+             student_model.model.decoder.layers[decoder_map[layer]].load_state_dict(
+                 teacher_model.model.decoder.layers[layer].state_dict()
+             )
+
+     if encoder_layers is not None:
+         for layer in range(teacher_encoder_layers):
+             if layer in encoder_map:
+                 # re-introduce pre-defined layers from the teacher
+                 student_model.model.encoder.layers[encoder_map[layer]].load_state_dict(
+                     teacher_model.model.encoder.layers[layer].state_dict()
+                 )
+
+     # remove the teacher params and model
+     del teacher_model
+
+     # save the converted weights and model
+     if save_dir is not None:
+         student_model.save_pretrained(save_dir)
+         # we also need to correctly save the processor and generation config
+         processor.save_pretrained(save_dir)
+         generation_config.save_pretrained(save_dir)
+
+     # check we can do a forward pass with the saved model - first load the weights and processor
+     logger.info("Checking we can load the saved model...")
+     student_model = WhisperForConditionalGeneration.from_pretrained(
+         save_dir,
+         low_cpu_mem_usage=True,
+     )
+     processor = WhisperProcessor.from_pretrained(save_dir)
+
+     # define some random inputs
+     input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="pt").input_features
+     decoder_start_token_id = student_model.config.decoder_start_token_id
+     decoder_input_ids = torch.ones((input_features.shape[0], 1), dtype=torch.long) * decoder_start_token_id
+
+     # do a forward pass - outputs will be gibberish for the initialised model so we can't check them
+     # but we can make sure the model runs as expected
+     logger.info("Checking we can run the converted model forward...")
+     _ = student_model(input_features, decoder_input_ids=decoder_input_ids).logits
+     logger.info("Conversion successful!")
+
+     if push_to_hub:
+         student_model.push_to_hub(save_dir)
+         processor.push_to_hub(save_dir)
+         generation_config.push_to_hub(save_dir)
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     init_student_model_from_teacher(
+         teacher_checkpoint=args.teacher_checkpoint,
+         encoder_layers=args.encoder_layers,
+         decoder_layers=args.decoder_layers,
+         save_dir=args.save_dir,
+         push_to_hub=args.push_to_hub,
+         cache_dir=args.cache_dir,
+         subfolder=args.subfolder,
+     )
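For reference, the layer-copying initialisation performed by this script can also be invoked directly from Python. A hypothetical call that mirrors the configuration committed in this repo (full 32-layer encoder, 2-layer decoder, output written to distil-large-v3-init); the paths are illustrative:

```python
from create_student_model import init_student_model_from_teacher

init_student_model_from_teacher(
    teacher_checkpoint="openai/whisper-large-v3",
    encoder_layers=None,       # keep all 32 teacher encoder layers
    decoder_layers=2,          # student keeps the first and last teacher decoder layers
    save_dir="./distil-large-v3-init",
    push_to_hub=False,
)
```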
distil-large-v3-init/added_tokens.json ADDED
@@ -0,0 +1,1611 @@
1
+ {
2
+ "<|0.00|>": 50365,
3
+ "<|0.02|>": 50366,
4
+ "<|0.04|>": 50367,
5
+ "<|0.06|>": 50368,
6
+ "<|0.08|>": 50369,
7
+ "<|0.10|>": 50370,
8
+ "<|0.12|>": 50371,
9
+ "<|0.14|>": 50372,
10
+ "<|0.16|>": 50373,
11
+ "<|0.18|>": 50374,
12
+ "<|0.20|>": 50375,
13
+ "<|0.22|>": 50376,
14
+ "<|0.24|>": 50377,
15
+ "<|0.26|>": 50378,
16
+ "<|0.28|>": 50379,
17
+ "<|0.30|>": 50380,
18
+ "<|0.32|>": 50381,
19
+ "<|0.34|>": 50382,
20
+ "<|0.36|>": 50383,
21
+ "<|0.38|>": 50384,
22
+ "<|0.40|>": 50385,
23
+ "<|0.42|>": 50386,
24
+ "<|0.44|>": 50387,
25
+ "<|0.46|>": 50388,
26
+ "<|0.48|>": 50389,
27
+ "<|0.50|>": 50390,
28
+ "<|0.52|>": 50391,
29
+ "<|0.54|>": 50392,
30
+ "<|0.56|>": 50393,
31
+ "<|0.58|>": 50394,
32
+ "<|0.60|>": 50395,
33
+ "<|0.62|>": 50396,
34
+ "<|0.64|>": 50397,
35
+ "<|0.66|>": 50398,
36
+ "<|0.68|>": 50399,
37
+ "<|0.70|>": 50400,
38
+ "<|0.72|>": 50401,
39
+ "<|0.74|>": 50402,
40
+ "<|0.76|>": 50403,
41
+ "<|0.78|>": 50404,
42
+ "<|0.80|>": 50405,
43
+ "<|0.82|>": 50406,
44
+ "<|0.84|>": 50407,
45
+ "<|0.86|>": 50408,
46
+ "<|0.88|>": 50409,
47
+ "<|0.90|>": 50410,
48
+ "<|0.92|>": 50411,
49
+ "<|0.94|>": 50412,
50
+ "<|0.96|>": 50413,
51
+ "<|0.98|>": 50414,
52
+ "<|1.00|>": 50415,
53
+ "<|1.02|>": 50416,
54
+ "<|1.04|>": 50417,
55
+ "<|1.06|>": 50418,
56
+ "<|1.08|>": 50419,
57
+ "<|1.10|>": 50420,
58
+ "<|1.12|>": 50421,
59
+ "<|1.14|>": 50422,
60
+ "<|1.16|>": 50423,
61
+ "<|1.18|>": 50424,
62
+ "<|1.20|>": 50425,
63
+ "<|1.22|>": 50426,
64
+ "<|1.24|>": 50427,
65
+ "<|1.26|>": 50428,
66
+ "<|1.28|>": 50429,
67
+ "<|1.30|>": 50430,
68
+ "<|1.32|>": 50431,
69
+ "<|1.34|>": 50432,
70
+ "<|1.36|>": 50433,
71
+ "<|1.38|>": 50434,
72
+ "<|1.40|>": 50435,
73
+ "<|1.42|>": 50436,
74
+ "<|1.44|>": 50437,
75
+ "<|1.46|>": 50438,
76
+ "<|1.48|>": 50439,
77
+ "<|1.50|>": 50440,
78
+ "<|1.52|>": 50441,
79
+ "<|1.54|>": 50442,
80
+ "<|1.56|>": 50443,
81
+ "<|1.58|>": 50444,
82
+ "<|1.60|>": 50445,
83
+ "<|1.62|>": 50446,
84
+ "<|1.64|>": 50447,
85
+ "<|1.66|>": 50448,
86
+ "<|1.68|>": 50449,
87
+ "<|1.70|>": 50450,
88
+ "<|1.72|>": 50451,
89
+ "<|1.74|>": 50452,
90
+ "<|1.76|>": 50453,
91
+ "<|1.78|>": 50454,
92
+ "<|1.80|>": 50455,
93
+ "<|1.82|>": 50456,
94
+ "<|1.84|>": 50457,
95
+ "<|1.86|>": 50458,
96
+ "<|1.88|>": 50459,
97
+ "<|1.90|>": 50460,
98
+ "<|1.92|>": 50461,
99
+ "<|1.94|>": 50462,
100
+ "<|1.96|>": 50463,
101
+ "<|1.98|>": 50464,
102
+ "<|10.00|>": 50865,
103
+ "<|10.02|>": 50866,
104
+ "<|10.04|>": 50867,
105
+ "<|10.06|>": 50868,
106
+ "<|10.08|>": 50869,
107
+ "<|10.10|>": 50870,
108
+ "<|10.12|>": 50871,
109
+ "<|10.14|>": 50872,
110
+ "<|10.16|>": 50873,
111
+ "<|10.18|>": 50874,
112
+ "<|10.20|>": 50875,
113
+ "<|10.22|>": 50876,
114
+ "<|10.24|>": 50877,
115
+ "<|10.26|>": 50878,
116
+ "<|10.28|>": 50879,
117
+ "<|10.30|>": 50880,
118
+ "<|10.32|>": 50881,
119
+ "<|10.34|>": 50882,
120
+ "<|10.36|>": 50883,
121
+ "<|10.38|>": 50884,
122
+ "<|10.40|>": 50885,
123
+ "<|10.42|>": 50886,
124
+ "<|10.44|>": 50887,
125
+ "<|10.46|>": 50888,
126
+ "<|10.48|>": 50889,
127
+ "<|10.50|>": 50890,
128
+ "<|10.52|>": 50891,
129
+ "<|10.54|>": 50892,
130
+ "<|10.56|>": 50893,
131
+ "<|10.58|>": 50894,
132
+ "<|10.60|>": 50895,
133
+ "<|10.62|>": 50896,
134
+ "<|10.64|>": 50897,
135
+ "<|10.66|>": 50898,
136
+ "<|10.68|>": 50899,
137
+ "<|10.70|>": 50900,
138
+ "<|10.72|>": 50901,
139
+ "<|10.74|>": 50902,
140
+ "<|10.76|>": 50903,
141
+ "<|10.78|>": 50904,
142
+ "<|10.80|>": 50905,
143
+ "<|10.82|>": 50906,
144
+ "<|10.84|>": 50907,
145
+ "<|10.86|>": 50908,
146
+ "<|10.88|>": 50909,
147
+ "<|10.90|>": 50910,
148
+ "<|10.92|>": 50911,
149
+ "<|10.94|>": 50912,
150
+ "<|10.96|>": 50913,
151
+ "<|10.98|>": 50914,
152
+ "<|11.00|>": 50915,
153
+ "<|11.02|>": 50916,
154
+ "<|11.04|>": 50917,
155
+ "<|11.06|>": 50918,
156
+ "<|11.08|>": 50919,
157
+ "<|11.10|>": 50920,
158
+ "<|11.12|>": 50921,
159
+ "<|11.14|>": 50922,
160
+ "<|11.16|>": 50923,
161
+ "<|11.18|>": 50924,
162
+ "<|11.20|>": 50925,
163
+ "<|11.22|>": 50926,
164
+ "<|11.24|>": 50927,
165
+ "<|11.26|>": 50928,
166
+ "<|11.28|>": 50929,
167
+ "<|11.30|>": 50930,
168
+ "<|11.32|>": 50931,
169
+ "<|11.34|>": 50932,
170
+ "<|11.36|>": 50933,
171
+ "<|11.38|>": 50934,
172
+ "<|11.40|>": 50935,
173
+ "<|11.42|>": 50936,
174
+ "<|11.44|>": 50937,
175
+ "<|11.46|>": 50938,
176
+ "<|11.48|>": 50939,
177
+ "<|11.50|>": 50940,
178
+ "<|11.52|>": 50941,
179
+ "<|11.54|>": 50942,
180
+ "<|11.56|>": 50943,
181
+ "<|11.58|>": 50944,
182
+ "<|11.60|>": 50945,
183
+ "<|11.62|>": 50946,
184
+ "<|11.64|>": 50947,
185
+ "<|11.66|>": 50948,
186
+ "<|11.68|>": 50949,
187
+ "<|11.70|>": 50950,
188
+ "<|11.72|>": 50951,
189
+ "<|11.74|>": 50952,
190
+ "<|11.76|>": 50953,
191
+ "<|11.78|>": 50954,
192
+ "<|11.80|>": 50955,
193
+ "<|11.82|>": 50956,
194
+ "<|11.84|>": 50957,
195
+ "<|11.86|>": 50958,
196
+ "<|11.88|>": 50959,
197
+ "<|11.90|>": 50960,
198
+ "<|11.92|>": 50961,
199
+ "<|11.94|>": 50962,
200
+ "<|11.96|>": 50963,
201
+ "<|11.98|>": 50964,
202
+ "<|12.00|>": 50965,
203
+ "<|12.02|>": 50966,
204
+ "<|12.04|>": 50967,
205
+ "<|12.06|>": 50968,
206
+ "<|12.08|>": 50969,
207
+ "<|12.10|>": 50970,
208
+ "<|12.12|>": 50971,
209
+ "<|12.14|>": 50972,
210
+ "<|12.16|>": 50973,
211
+ "<|12.18|>": 50974,
212
+ "<|12.20|>": 50975,
213
+ "<|12.22|>": 50976,
214
+ "<|12.24|>": 50977,
215
+ "<|12.26|>": 50978,
216
+ "<|12.28|>": 50979,
217
+ "<|12.30|>": 50980,
218
+ "<|12.32|>": 50981,
219
+ "<|12.34|>": 50982,
220
+ "<|12.36|>": 50983,
221
+ "<|12.38|>": 50984,
222
+ "<|12.40|>": 50985,
223
+ "<|12.42|>": 50986,
224
+ "<|12.44|>": 50987,
225
+ "<|12.46|>": 50988,
226
+ "<|12.48|>": 50989,
227
+ "<|12.50|>": 50990,
228
+ "<|12.52|>": 50991,
229
+ "<|12.54|>": 50992,
230
+ "<|12.56|>": 50993,
231
+ "<|12.58|>": 50994,
232
+ "<|12.60|>": 50995,
233
+ "<|12.62|>": 50996,
234
+ "<|12.64|>": 50997,
235
+ "<|12.66|>": 50998,
236
+ "<|12.68|>": 50999,
237
+ "<|12.70|>": 51000,
238
+ "<|12.72|>": 51001,
239
+ "<|12.74|>": 51002,
240
+ "<|12.76|>": 51003,
241
+ "<|12.78|>": 51004,
242
+ "<|12.80|>": 51005,
243
+ "<|12.82|>": 51006,
244
+ "<|12.84|>": 51007,
245
+ "<|12.86|>": 51008,
246
+ "<|12.88|>": 51009,
247
+ "<|12.90|>": 51010,
248
+ "<|12.92|>": 51011,
249
+ "<|12.94|>": 51012,
250
+ "<|12.96|>": 51013,
251
+ "<|12.98|>": 51014,
252
+ "<|13.00|>": 51015,
253
+ "<|13.02|>": 51016,
254
+ "<|13.04|>": 51017,
255
+ "<|13.06|>": 51018,
256
+ "<|13.08|>": 51019,
257
+ "<|13.10|>": 51020,
258
+ "<|13.12|>": 51021,
259
+ "<|13.14|>": 51022,
260
+ "<|13.16|>": 51023,
261
+ "<|13.18|>": 51024,
262
+ "<|13.20|>": 51025,
263
+ "<|13.22|>": 51026,
264
+ "<|13.24|>": 51027,
265
+ "<|13.26|>": 51028,
266
+ "<|13.28|>": 51029,
267
+ "<|13.30|>": 51030,
268
+ "<|13.32|>": 51031,
269
+ "<|13.34|>": 51032,
270
+ "<|13.36|>": 51033,
271
+ "<|13.38|>": 51034,
272
+ "<|13.40|>": 51035,
273
+ "<|13.42|>": 51036,
274
+ "<|13.44|>": 51037,
275
+ "<|13.46|>": 51038,
276
+ "<|13.48|>": 51039,
277
+ "<|13.50|>": 51040,
278
+ "<|13.52|>": 51041,
279
+ "<|13.54|>": 51042,
280
+ "<|13.56|>": 51043,
281
+ "<|13.58|>": 51044,
282
+ "<|13.60|>": 51045,
283
+ "<|13.62|>": 51046,
284
+ "<|13.64|>": 51047,
285
+ "<|13.66|>": 51048,
286
+ "<|13.68|>": 51049,
287
+ "<|13.70|>": 51050,
288
+ "<|13.72|>": 51051,
289
+ "<|13.74|>": 51052,
290
+ "<|13.76|>": 51053,
291
+ "<|13.78|>": 51054,
292
+ "<|13.80|>": 51055,
293
+ "<|13.82|>": 51056,
294
+ "<|13.84|>": 51057,
295
+ "<|13.86|>": 51058,
296
+ "<|13.88|>": 51059,
297
+ "<|13.90|>": 51060,
298
+ "<|13.92|>": 51061,
299
+ "<|13.94|>": 51062,
300
+ "<|13.96|>": 51063,
301
+ "<|13.98|>": 51064,
302
+ "<|14.00|>": 51065,
303
+ "<|14.02|>": 51066,
304
+ "<|14.04|>": 51067,
305
+ "<|14.06|>": 51068,
306
+ "<|14.08|>": 51069,
307
+ "<|14.10|>": 51070,
308
+ "<|14.12|>": 51071,
309
+ "<|14.14|>": 51072,
310
+ "<|14.16|>": 51073,
311
+ "<|14.18|>": 51074,
312
+ "<|14.20|>": 51075,
313
+ "<|14.22|>": 51076,
314
+ "<|14.24|>": 51077,
315
+ "<|14.26|>": 51078,
316
+ "<|14.28|>": 51079,
317
+ "<|14.30|>": 51080,
318
+ "<|14.32|>": 51081,
319
+ "<|14.34|>": 51082,
320
+ "<|14.36|>": 51083,
321
+ "<|14.38|>": 51084,
322
+ "<|14.40|>": 51085,
323
+ "<|14.42|>": 51086,
324
+ "<|14.44|>": 51087,
325
+ "<|14.46|>": 51088,
326
+ "<|14.48|>": 51089,
327
+ "<|14.50|>": 51090,
328
+ "<|14.52|>": 51091,
329
+ "<|14.54|>": 51092,
330
+ "<|14.56|>": 51093,
331
+ "<|14.58|>": 51094,
332
+ "<|14.60|>": 51095,
333
+ "<|14.62|>": 51096,
334
+ "<|14.64|>": 51097,
335
+ "<|14.66|>": 51098,
336
+ "<|14.68|>": 51099,
337
+ "<|14.70|>": 51100,
338
+ "<|14.72|>": 51101,
339
+ "<|14.74|>": 51102,
340
+ "<|14.76|>": 51103,
341
+ "<|14.78|>": 51104,
342
+ "<|14.80|>": 51105,
343
+ "<|14.82|>": 51106,
344
+ "<|14.84|>": 51107,
345
+ "<|14.86|>": 51108,
346
+ "<|14.88|>": 51109,
347
+ "<|14.90|>": 51110,
348
+ "<|14.92|>": 51111,
349
+ "<|14.94|>": 51112,
350
+ "<|14.96|>": 51113,
351
+ "<|14.98|>": 51114,
352
+ "<|15.00|>": 51115,
353
+ "<|15.02|>": 51116,
354
+ "<|15.04|>": 51117,
355
+ "<|15.06|>": 51118,
356
+ "<|15.08|>": 51119,
357
+ "<|15.10|>": 51120,
358
+ "<|15.12|>": 51121,
359
+ "<|15.14|>": 51122,
360
+ "<|15.16|>": 51123,
361
+ "<|15.18|>": 51124,
362
+ "<|15.20|>": 51125,
363
+ "<|15.22|>": 51126,
364
+ "<|15.24|>": 51127,
365
+ "<|15.26|>": 51128,
366
+ "<|15.28|>": 51129,
367
+ "<|15.30|>": 51130,
368
+ "<|15.32|>": 51131,
369
+ "<|15.34|>": 51132,
370
+ "<|15.36|>": 51133,
371
+ "<|15.38|>": 51134,
372
+ "<|15.40|>": 51135,
373
+ "<|15.42|>": 51136,
374
+ "<|15.44|>": 51137,
375
+ "<|15.46|>": 51138,
376
+ "<|15.48|>": 51139,
377
+ "<|15.50|>": 51140,
378
+ "<|15.52|>": 51141,
379
+ "<|15.54|>": 51142,
380
+ "<|15.56|>": 51143,
381
+ "<|15.58|>": 51144,
382
+ "<|15.60|>": 51145,
383
+ "<|15.62|>": 51146,
384
+ "<|15.64|>": 51147,
385
+ "<|15.66|>": 51148,
386
+ "<|15.68|>": 51149,
387
+ "<|15.70|>": 51150,
388
+ "<|15.72|>": 51151,
389
+ "<|15.74|>": 51152,
390
+ "<|15.76|>": 51153,
391
+ "<|15.78|>": 51154,
392
+ "<|15.80|>": 51155,
393
+ "<|15.82|>": 51156,
394
+ "<|15.84|>": 51157,
395
+ "<|15.86|>": 51158,
396
+ "<|15.88|>": 51159,
397
+ "<|15.90|>": 51160,
398
+ "<|15.92|>": 51161,
399
+ "<|15.94|>": 51162,
400
+ "<|15.96|>": 51163,
401
+ "<|15.98|>": 51164,
402
+ "<|16.00|>": 51165,
403
+ "<|16.02|>": 51166,
404
+ "<|16.04|>": 51167,
405
+ "<|16.06|>": 51168,
406
+ "<|16.08|>": 51169,
407
+ "<|16.10|>": 51170,
408
+ "<|16.12|>": 51171,
409
+ "<|16.14|>": 51172,
410
+ "<|16.16|>": 51173,
411
+ "<|16.18|>": 51174,
412
+ "<|16.20|>": 51175,
413
+ "<|16.22|>": 51176,
414
+ "<|16.24|>": 51177,
415
+ "<|16.26|>": 51178,
416
+ "<|16.28|>": 51179,
417
+ "<|16.30|>": 51180,
418
+ "<|16.32|>": 51181,
419
+ "<|16.34|>": 51182,
420
+ "<|16.36|>": 51183,
421
+ "<|16.38|>": 51184,
422
+ "<|16.40|>": 51185,
423
+ "<|16.42|>": 51186,
424
+ "<|16.44|>": 51187,
425
+ "<|16.46|>": 51188,
426
+ "<|16.48|>": 51189,
427
+ "<|16.50|>": 51190,
428
+ "<|16.52|>": 51191,
429
+ "<|16.54|>": 51192,
430
+ "<|16.56|>": 51193,
431
+ "<|16.58|>": 51194,
432
+ "<|16.60|>": 51195,
433
+ "<|16.62|>": 51196,
434
+ "<|16.64|>": 51197,
435
+ "<|16.66|>": 51198,
436
+ "<|16.68|>": 51199,
437
+ "<|16.70|>": 51200,
438
+ "<|16.72|>": 51201,
439
+ "<|16.74|>": 51202,
440
+ "<|16.76|>": 51203,
441
+ "<|16.78|>": 51204,
442
+ "<|16.80|>": 51205,
443
+ "<|16.82|>": 51206,
444
+ "<|16.84|>": 51207,
445
+ "<|16.86|>": 51208,
446
+ "<|16.88|>": 51209,
447
+ "<|16.90|>": 51210,
448
+ "<|16.92|>": 51211,
449
+ "<|16.94|>": 51212,
450
+ "<|16.96|>": 51213,
451
+ "<|16.98|>": 51214,
452
+ "<|17.00|>": 51215,
453
+ "<|17.02|>": 51216,
454
+ "<|17.04|>": 51217,
455
+ "<|17.06|>": 51218,
456
+ "<|17.08|>": 51219,
457
+ "<|17.10|>": 51220,
458
+ "<|17.12|>": 51221,
459
+ "<|17.14|>": 51222,
460
+ "<|17.16|>": 51223,
461
+ "<|17.18|>": 51224,
462
+ "<|17.20|>": 51225,
463
+ "<|17.22|>": 51226,
464
+ "<|17.24|>": 51227,
465
+ "<|17.26|>": 51228,
466
+ "<|17.28|>": 51229,
467
+ "<|17.30|>": 51230,
468
+ "<|17.32|>": 51231,
469
+ "<|17.34|>": 51232,
470
+ "<|17.36|>": 51233,
471
+ "<|17.38|>": 51234,
472
+ "<|17.40|>": 51235,
473
+ "<|17.42|>": 51236,
474
+ "<|17.44|>": 51237,
475
+ "<|17.46|>": 51238,
476
+ "<|17.48|>": 51239,
477
+ "<|17.50|>": 51240,
478
+ "<|17.52|>": 51241,
479
+ "<|17.54|>": 51242,
480
+ "<|17.56|>": 51243,
481
+ "<|17.58|>": 51244,
482
+ "<|17.60|>": 51245,
483
+ "<|17.62|>": 51246,
484
+ "<|17.64|>": 51247,
485
+ "<|17.66|>": 51248,
486
+ "<|17.68|>": 51249,
487
+ "<|17.70|>": 51250,
488
+ "<|17.72|>": 51251,
489
+ "<|17.74|>": 51252,
490
+ "<|17.76|>": 51253,
491
+ "<|17.78|>": 51254,
492
+ "<|17.80|>": 51255,
493
+ "<|17.82|>": 51256,
494
+ "<|17.84|>": 51257,
495
+ "<|17.86|>": 51258,
496
+ "<|17.88|>": 51259,
497
+ "<|17.90|>": 51260,
498
+ "<|17.92|>": 51261,
499
+ "<|17.94|>": 51262,
500
+ "<|17.96|>": 51263,
501
+ "<|17.98|>": 51264,
502
+ "<|18.00|>": 51265,
503
+ "<|18.02|>": 51266,
504
+ "<|18.04|>": 51267,
505
+ "<|18.06|>": 51268,
506
+ "<|18.08|>": 51269,
507
+ "<|18.10|>": 51270,
508
+ "<|18.12|>": 51271,
509
+ "<|18.14|>": 51272,
510
+ "<|18.16|>": 51273,
511
+ "<|18.18|>": 51274,
512
+ "<|18.20|>": 51275,
513
+ "<|18.22|>": 51276,
514
+ "<|18.24|>": 51277,
515
+ "<|18.26|>": 51278,
516
+ "<|18.28|>": 51279,
517
+ "<|18.30|>": 51280,
518
+ "<|18.32|>": 51281,
519
+ "<|18.34|>": 51282,
520
+ "<|18.36|>": 51283,
521
+ "<|18.38|>": 51284,
522
+ "<|18.40|>": 51285,
523
+ "<|18.42|>": 51286,
524
+ "<|18.44|>": 51287,
525
+ "<|18.46|>": 51288,
526
+ "<|18.48|>": 51289,
527
+ "<|18.50|>": 51290,
528
+ "<|18.52|>": 51291,
529
+ "<|18.54|>": 51292,
530
+ "<|18.56|>": 51293,
531
+ "<|18.58|>": 51294,
532
+ "<|18.60|>": 51295,
533
+ "<|18.62|>": 51296,
534
+ "<|18.64|>": 51297,
535
+ "<|18.66|>": 51298,
536
+ "<|18.68|>": 51299,
537
+ "<|18.70|>": 51300,
538
+ "<|18.72|>": 51301,
539
+ "<|18.74|>": 51302,
540
+ "<|18.76|>": 51303,
541
+ "<|18.78|>": 51304,
542
+ "<|18.80|>": 51305,
543
+ "<|18.82|>": 51306,
544
+ "<|18.84|>": 51307,
545
+ "<|18.86|>": 51308,
546
+ "<|18.88|>": 51309,
547
+ "<|18.90|>": 51310,
548
+ "<|18.92|>": 51311,
549
+ "<|18.94|>": 51312,
550
+ "<|18.96|>": 51313,
551
+ "<|18.98|>": 51314,
552
+ "<|19.00|>": 51315,
553
+ "<|19.02|>": 51316,
554
+ "<|19.04|>": 51317,
555
+ "<|19.06|>": 51318,
556
+ "<|19.08|>": 51319,
557
+ "<|19.10|>": 51320,
558
+ "<|19.12|>": 51321,
559
+ "<|19.14|>": 51322,
560
+ "<|19.16|>": 51323,
561
+ "<|19.18|>": 51324,
562
+ "<|19.20|>": 51325,
563
+ "<|19.22|>": 51326,
564
+ "<|19.24|>": 51327,
565
+ "<|19.26|>": 51328,
566
+ "<|19.28|>": 51329,
567
+ "<|19.30|>": 51330,
568
+ "<|19.32|>": 51331,
569
+ "<|19.34|>": 51332,
570
+ "<|19.36|>": 51333,
571
+ "<|19.38|>": 51334,
572
+ "<|19.40|>": 51335,
573
+ "<|19.42|>": 51336,
574
+ "<|19.44|>": 51337,
575
+ "<|19.46|>": 51338,
576
+ "<|19.48|>": 51339,
577
+ "<|19.50|>": 51340,
578
+ "<|19.52|>": 51341,
579
+ "<|19.54|>": 51342,
580
+ "<|19.56|>": 51343,
581
+ "<|19.58|>": 51344,
582
+ "<|19.60|>": 51345,
583
+ "<|19.62|>": 51346,
584
+ "<|19.64|>": 51347,
585
+ "<|19.66|>": 51348,
586
+ "<|19.68|>": 51349,
587
+ "<|19.70|>": 51350,
588
+ "<|19.72|>": 51351,
589
+ "<|19.74|>": 51352,
590
+ "<|19.76|>": 51353,
591
+ "<|19.78|>": 51354,
592
+ "<|19.80|>": 51355,
593
+ "<|19.82|>": 51356,
594
+ "<|19.84|>": 51357,
595
+ "<|19.86|>": 51358,
596
+ "<|19.88|>": 51359,
597
+ "<|19.90|>": 51360,
598
+ "<|19.92|>": 51361,
599
+ "<|19.94|>": 51362,
600
+ "<|19.96|>": 51363,
601
+ "<|19.98|>": 51364,
602
+ "<|2.00|>": 50465,
603
+ "<|2.02|>": 50466,
604
+ "<|2.04|>": 50467,
605
+ "<|2.06|>": 50468,
606
+ "<|2.08|>": 50469,
607
+ "<|2.10|>": 50470,
608
+ "<|2.12|>": 50471,
609
+ "<|2.14|>": 50472,
610
+ "<|2.16|>": 50473,
611
+ "<|2.18|>": 50474,
612
+ "<|2.20|>": 50475,
613
+ "<|2.22|>": 50476,
614
+ "<|2.24|>": 50477,
615
+ "<|2.26|>": 50478,
616
+ "<|2.28|>": 50479,
617
+ "<|2.30|>": 50480,
618
+ "<|2.32|>": 50481,
619
+ "<|2.34|>": 50482,
620
+ "<|2.36|>": 50483,
621
+ "<|2.38|>": 50484,
622
+ "<|2.40|>": 50485,
623
+ "<|2.42|>": 50486,
624
+ "<|2.44|>": 50487,
625
+ "<|2.46|>": 50488,
626
+ "<|2.48|>": 50489,
627
+ "<|2.50|>": 50490,
628
+ "<|2.52|>": 50491,
629
+ "<|2.54|>": 50492,
630
+ "<|2.56|>": 50493,
631
+ "<|2.58|>": 50494,
632
+ "<|2.60|>": 50495,
633
+ "<|2.62|>": 50496,
634
+ "<|2.64|>": 50497,
635
+ "<|2.66|>": 50498,
636
+ "<|2.68|>": 50499,
637
+ "<|2.70|>": 50500,
638
+ "<|2.72|>": 50501,
639
+ "<|2.74|>": 50502,
640
+ "<|2.76|>": 50503,
641
+ "<|2.78|>": 50504,
642
+ "<|2.80|>": 50505,
643
+ "<|2.82|>": 50506,
644
+ "<|2.84|>": 50507,
645
+ "<|2.86|>": 50508,
646
+ "<|2.88|>": 50509,
647
+ "<|2.90|>": 50510,
648
+ "<|2.92|>": 50511,
649
+ "<|2.94|>": 50512,
650
+ "<|2.96|>": 50513,
651
+ "<|2.98|>": 50514,
652
+ "<|20.00|>": 51365,
653
+ "<|20.02|>": 51366,
654
+ "<|20.04|>": 51367,
655
+ "<|20.06|>": 51368,
656
+ "<|20.08|>": 51369,
657
+ "<|20.10|>": 51370,
658
+ "<|20.12|>": 51371,
659
+ "<|20.14|>": 51372,
660
+ "<|20.16|>": 51373,
661
+ "<|20.18|>": 51374,
662
+ "<|20.20|>": 51375,
663
+ "<|20.22|>": 51376,
664
+ "<|20.24|>": 51377,
665
+ "<|20.26|>": 51378,
666
+ "<|20.28|>": 51379,
667
+ "<|20.30|>": 51380,
668
+ "<|20.32|>": 51381,
669
+ "<|20.34|>": 51382,
670
+ "<|20.36|>": 51383,
671
+ "<|20.38|>": 51384,
672
+ "<|20.40|>": 51385,
673
+ "<|20.42|>": 51386,
674
+ "<|20.44|>": 51387,
675
+ "<|20.46|>": 51388,
676
+ "<|20.48|>": 51389,
677
+ "<|20.50|>": 51390,
678
+ "<|20.52|>": 51391,
679
+ "<|20.54|>": 51392,
680
+ "<|20.56|>": 51393,
681
+ "<|20.58|>": 51394,
682
+ "<|20.60|>": 51395,
683
+ "<|20.62|>": 51396,
684
+ "<|20.64|>": 51397,
685
+ "<|20.66|>": 51398,
686
+ "<|20.68|>": 51399,
687
+ "<|20.70|>": 51400,
688
+ "<|20.72|>": 51401,
689
+ "<|20.74|>": 51402,
690
+ "<|20.76|>": 51403,
691
+ "<|20.78|>": 51404,
692
+ "<|20.80|>": 51405,
693
+ "<|20.82|>": 51406,
694
+ "<|20.84|>": 51407,
695
+ "<|20.86|>": 51408,
696
+ "<|20.88|>": 51409,
697
+ "<|20.90|>": 51410,
698
+ "<|20.92|>": 51411,
699
+ "<|20.94|>": 51412,
700
+ "<|20.96|>": 51413,
701
+ "<|20.98|>": 51414,
702
+ "<|21.00|>": 51415,
703
+ "<|21.02|>": 51416,
704
+ "<|21.04|>": 51417,
705
+ "<|21.06|>": 51418,
706
+ "<|21.08|>": 51419,
707
+ "<|21.10|>": 51420,
708
+ "<|21.12|>": 51421,
709
+ "<|21.14|>": 51422,
710
+ "<|21.16|>": 51423,
711
+ "<|21.18|>": 51424,
712
+ "<|21.20|>": 51425,
713
+ "<|21.22|>": 51426,
714
+ "<|21.24|>": 51427,
715
+ "<|21.26|>": 51428,
716
+ "<|21.28|>": 51429,
717
+ "<|21.30|>": 51430,
718
+ "<|21.32|>": 51431,
719
+ "<|21.34|>": 51432,
720
+ "<|21.36|>": 51433,
721
+ "<|21.38|>": 51434,
722
+ "<|21.40|>": 51435,
723
+ "<|21.42|>": 51436,
724
+ "<|21.44|>": 51437,
725
+ "<|21.46|>": 51438,
726
+ "<|21.48|>": 51439,
727
+ "<|21.50|>": 51440,
728
+ "<|21.52|>": 51441,
729
+ "<|21.54|>": 51442,
730
+ "<|21.56|>": 51443,
731
+ "<|21.58|>": 51444,
732
+ "<|21.60|>": 51445,
733
+ "<|21.62|>": 51446,
734
+ "<|21.64|>": 51447,
735
+ "<|21.66|>": 51448,
736
+ "<|21.68|>": 51449,
737
+ "<|21.70|>": 51450,
738
+ "<|21.72|>": 51451,
739
+ "<|21.74|>": 51452,
740
+ "<|21.76|>": 51453,
741
+ "<|21.78|>": 51454,
742
+ "<|21.80|>": 51455,
743
+ "<|21.82|>": 51456,
744
+ "<|21.84|>": 51457,
745
+ "<|21.86|>": 51458,
746
+ "<|21.88|>": 51459,
747
+ "<|21.90|>": 51460,
748
+ "<|21.92|>": 51461,
749
+ "<|21.94|>": 51462,
750
+ "<|21.96|>": 51463,
751
+ "<|21.98|>": 51464,
752
+ "<|22.00|>": 51465,
753
+ "<|22.02|>": 51466,
754
+ "<|22.04|>": 51467,
755
+ "<|22.06|>": 51468,
756
+ "<|22.08|>": 51469,
757
+ "<|22.10|>": 51470,
758
+ "<|22.12|>": 51471,
759
+ "<|22.14|>": 51472,
760
+ "<|22.16|>": 51473,
761
+ "<|22.18|>": 51474,
762
+ "<|22.20|>": 51475,
763
+ "<|22.22|>": 51476,
764
+ "<|22.24|>": 51477,
765
+ "<|22.26|>": 51478,
766
+ "<|22.28|>": 51479,
767
+ "<|22.30|>": 51480,
768
+ "<|22.32|>": 51481,
769
+ "<|22.34|>": 51482,
770
+ "<|22.36|>": 51483,
771
+ "<|22.38|>": 51484,
772
+ "<|22.40|>": 51485,
773
+ "<|22.42|>": 51486,
774
+ "<|22.44|>": 51487,
775
+ "<|22.46|>": 51488,
776
+ "<|22.48|>": 51489,
777
+ "<|22.50|>": 51490,
778
+ "<|22.52|>": 51491,
779
+ "<|22.54|>": 51492,
780
+ "<|22.56|>": 51493,
781
+ "<|22.58|>": 51494,
782
+ "<|22.60|>": 51495,
783
+ "<|22.62|>": 51496,
784
+ "<|22.64|>": 51497,
785
+ "<|22.66|>": 51498,
786
+ "<|22.68|>": 51499,
787
+ "<|22.70|>": 51500,
788
+ "<|22.72|>": 51501,
789
+ "<|22.74|>": 51502,
790
+ "<|22.76|>": 51503,
791
+ "<|22.78|>": 51504,
792
+ "<|22.80|>": 51505,
793
+ "<|22.82|>": 51506,
794
+ "<|22.84|>": 51507,
795
+ "<|22.86|>": 51508,
796
+ "<|22.88|>": 51509,
797
+ "<|22.90|>": 51510,
798
+ "<|22.92|>": 51511,
799
+ "<|22.94|>": 51512,
800
+ "<|22.96|>": 51513,
801
+ "<|22.98|>": 51514,
802
+ "<|23.00|>": 51515,
803
+ "<|23.02|>": 51516,
804
+ "<|23.04|>": 51517,
805
+ "<|23.06|>": 51518,
806
+ "<|23.08|>": 51519,
807
+ "<|23.10|>": 51520,
808
+ "<|23.12|>": 51521,
809
+ "<|23.14|>": 51522,
810
+ "<|23.16|>": 51523,
811
+ "<|23.18|>": 51524,
812
+ "<|23.20|>": 51525,
813
+ "<|23.22|>": 51526,
814
+ "<|23.24|>": 51527,
815
+ "<|23.26|>": 51528,
816
+ "<|23.28|>": 51529,
817
+ "<|23.30|>": 51530,
818
+ "<|23.32|>": 51531,
819
+ "<|23.34|>": 51532,
820
+ "<|23.36|>": 51533,
821
+ "<|23.38|>": 51534,
822
+ "<|23.40|>": 51535,
823
+ "<|23.42|>": 51536,
824
+ "<|23.44|>": 51537,
825
+ "<|23.46|>": 51538,
826
+ "<|23.48|>": 51539,
827
+ "<|23.50|>": 51540,
828
+ "<|23.52|>": 51541,
829
+ "<|23.54|>": 51542,
830
+ "<|23.56|>": 51543,
831
+ "<|23.58|>": 51544,
832
+ "<|23.60|>": 51545,
833
+ "<|23.62|>": 51546,
834
+ "<|23.64|>": 51547,
835
+ "<|23.66|>": 51548,
836
+ "<|23.68|>": 51549,
837
+ "<|23.70|>": 51550,
838
+ "<|23.72|>": 51551,
839
+ "<|23.74|>": 51552,
840
+ "<|23.76|>": 51553,
841
+ "<|23.78|>": 51554,
842
+ "<|23.80|>": 51555,
843
+ "<|23.82|>": 51556,
844
+ "<|23.84|>": 51557,
845
+ "<|23.86|>": 51558,
846
+ "<|23.88|>": 51559,
847
+ "<|23.90|>": 51560,
848
+ "<|23.92|>": 51561,
849
+ "<|23.94|>": 51562,
850
+ "<|23.96|>": 51563,
851
+ "<|23.98|>": 51564,
852
+ "<|24.00|>": 51565,
853
+ "<|24.02|>": 51566,
854
+ "<|24.04|>": 51567,
855
+ "<|24.06|>": 51568,
856
+ "<|24.08|>": 51569,
857
+ "<|24.10|>": 51570,
858
+ "<|24.12|>": 51571,
859
+ "<|24.14|>": 51572,
860
+ "<|24.16|>": 51573,
861
+ "<|24.18|>": 51574,
862
+ "<|24.20|>": 51575,
863
+ "<|24.22|>": 51576,
864
+ "<|24.24|>": 51577,
865
+ "<|24.26|>": 51578,
866
+ "<|24.28|>": 51579,
867
+ "<|24.30|>": 51580,
868
+ "<|24.32|>": 51581,
869
+ "<|24.34|>": 51582,
870
+ "<|24.36|>": 51583,
871
+ "<|24.38|>": 51584,
872
+ "<|24.40|>": 51585,
873
+ "<|24.42|>": 51586,
874
+ "<|24.44|>": 51587,
875
+ "<|24.46|>": 51588,
876
+ "<|24.48|>": 51589,
877
+ "<|24.50|>": 51590,
878
+ "<|24.52|>": 51591,
879
+ "<|24.54|>": 51592,
880
+ "<|24.56|>": 51593,
881
+ "<|24.58|>": 51594,
882
+ "<|24.60|>": 51595,
883
+ "<|24.62|>": 51596,
884
+ "<|24.64|>": 51597,
885
+ "<|24.66|>": 51598,
886
+ "<|24.68|>": 51599,
887
+ "<|24.70|>": 51600,
888
+ "<|24.72|>": 51601,
889
+ "<|24.74|>": 51602,
890
+ "<|24.76|>": 51603,
891
+ "<|24.78|>": 51604,
892
+ "<|24.80|>": 51605,
893
+ "<|24.82|>": 51606,
894
+ "<|24.84|>": 51607,
895
+ "<|24.86|>": 51608,
896
+ "<|24.88|>": 51609,
897
+ "<|24.90|>": 51610,
898
+ "<|24.92|>": 51611,
899
+ "<|24.94|>": 51612,
900
+ "<|24.96|>": 51613,
901
+ "<|24.98|>": 51614,
902
+ "<|25.00|>": 51615,
903
+ "<|25.02|>": 51616,
904
+ "<|25.04|>": 51617,
905
+ "<|25.06|>": 51618,
906
+ "<|25.08|>": 51619,
907
+ "<|25.10|>": 51620,
908
+ "<|25.12|>": 51621,
909
+ "<|25.14|>": 51622,
910
+ "<|25.16|>": 51623,
911
+ "<|25.18|>": 51624,
912
+ "<|25.20|>": 51625,
913
+ "<|25.22|>": 51626,
914
+ "<|25.24|>": 51627,
915
+ "<|25.26|>": 51628,
916
+ "<|25.28|>": 51629,
917
+ "<|25.30|>": 51630,
918
+ "<|25.32|>": 51631,
919
+ "<|25.34|>": 51632,
920
+ "<|25.36|>": 51633,
921
+ "<|25.38|>": 51634,
922
+ "<|25.40|>": 51635,
923
+ "<|25.42|>": 51636,
924
+ "<|25.44|>": 51637,
925
+ "<|25.46|>": 51638,
926
+ "<|25.48|>": 51639,
927
+ "<|25.50|>": 51640,
928
+ "<|25.52|>": 51641,
929
+ "<|25.54|>": 51642,
930
+ "<|25.56|>": 51643,
931
+ "<|25.58|>": 51644,
932
+ "<|25.60|>": 51645,
933
+ "<|25.62|>": 51646,
934
+ "<|25.64|>": 51647,
935
+ "<|25.66|>": 51648,
936
+ "<|25.68|>": 51649,
937
+ "<|25.70|>": 51650,
938
+ "<|25.72|>": 51651,
939
+ "<|25.74|>": 51652,
940
+ "<|25.76|>": 51653,
941
+ "<|25.78|>": 51654,
942
+ "<|25.80|>": 51655,
943
+ "<|25.82|>": 51656,
944
+ "<|25.84|>": 51657,
945
+ "<|25.86|>": 51658,
946
+ "<|25.88|>": 51659,
947
+ "<|25.90|>": 51660,
948
+ "<|25.92|>": 51661,
949
+ "<|25.94|>": 51662,
950
+ "<|25.96|>": 51663,
951
+ "<|25.98|>": 51664,
952
+ "<|26.00|>": 51665,
953
+ "<|26.02|>": 51666,
954
+ "<|26.04|>": 51667,
955
+ "<|26.06|>": 51668,
956
+ "<|26.08|>": 51669,
957
+ "<|26.10|>": 51670,
958
+ "<|26.12|>": 51671,
959
+ "<|26.14|>": 51672,
960
+ "<|26.16|>": 51673,
961
+ "<|26.18|>": 51674,
962
+ "<|26.20|>": 51675,
963
+ "<|26.22|>": 51676,
964
+ "<|26.24|>": 51677,
965
+ "<|26.26|>": 51678,
966
+ "<|26.28|>": 51679,
967
+ "<|26.30|>": 51680,
968
+ "<|26.32|>": 51681,
969
+ "<|26.34|>": 51682,
970
+ "<|26.36|>": 51683,
971
+ "<|26.38|>": 51684,
972
+ "<|26.40|>": 51685,
973
+ "<|26.42|>": 51686,
974
+ "<|26.44|>": 51687,
975
+ "<|26.46|>": 51688,
976
+ "<|26.48|>": 51689,
977
+ "<|26.50|>": 51690,
978
+ "<|26.52|>": 51691,
979
+ "<|26.54|>": 51692,
980
+ "<|26.56|>": 51693,
981
+ "<|26.58|>": 51694,
982
+ "<|26.60|>": 51695,
983
+ "<|26.62|>": 51696,
984
+ "<|26.64|>": 51697,
985
+ "<|26.66|>": 51698,
986
+ "<|26.68|>": 51699,
987
+ "<|26.70|>": 51700,
988
+ "<|26.72|>": 51701,
989
+ "<|26.74|>": 51702,
990
+ "<|26.76|>": 51703,
991
+ "<|26.78|>": 51704,
992
+ "<|26.80|>": 51705,
993
+ "<|26.82|>": 51706,
994
+ "<|26.84|>": 51707,
995
+ "<|26.86|>": 51708,
996
+ "<|26.88|>": 51709,
997
+ "<|26.90|>": 51710,
998
+ "<|26.92|>": 51711,
999
+ "<|26.94|>": 51712,
1000
+ "<|26.96|>": 51713,
1001
+ "<|26.98|>": 51714,
1002
+ "<|27.00|>": 51715,
1003
+ "<|27.02|>": 51716,
1004
+ "<|27.04|>": 51717,
1005
+ "<|27.06|>": 51718,
1006
+ "<|27.08|>": 51719,
1007
+ "<|27.10|>": 51720,
1008
+ "<|27.12|>": 51721,
1009
+ "<|27.14|>": 51722,
1010
+ "<|27.16|>": 51723,
1011
+ "<|27.18|>": 51724,
1012
+ "<|27.20|>": 51725,
1013
+ "<|27.22|>": 51726,
1014
+ "<|27.24|>": 51727,
1015
+ "<|27.26|>": 51728,
1016
+ "<|27.28|>": 51729,
1017
+ "<|27.30|>": 51730,
1018
+ "<|27.32|>": 51731,
1019
+ "<|27.34|>": 51732,
1020
+ "<|27.36|>": 51733,
1021
+ "<|27.38|>": 51734,
1022
+ "<|27.40|>": 51735,
1023
+ "<|27.42|>": 51736,
1024
+ "<|27.44|>": 51737,
1025
+ "<|27.46|>": 51738,
1026
+ "<|27.48|>": 51739,
1027
+ "<|27.50|>": 51740,
1028
+ "<|27.52|>": 51741,
1029
+ "<|27.54|>": 51742,
1030
+ "<|27.56|>": 51743,
1031
+ "<|27.58|>": 51744,
1032
+ "<|27.60|>": 51745,
1033
+ "<|27.62|>": 51746,
1034
+ "<|27.64|>": 51747,
1035
+ "<|27.66|>": 51748,
1036
+ "<|27.68|>": 51749,
1037
+ "<|27.70|>": 51750,
1038
+ "<|27.72|>": 51751,
1039
+ "<|27.74|>": 51752,
1040
+ "<|27.76|>": 51753,
1041
+ "<|27.78|>": 51754,
1042
+ "<|27.80|>": 51755,
1043
+ "<|27.82|>": 51756,
1044
+ "<|27.84|>": 51757,
1045
+ "<|27.86|>": 51758,
1046
+ "<|27.88|>": 51759,
1047
+ "<|27.90|>": 51760,
1048
+ "<|27.92|>": 51761,
1049
+ "<|27.94|>": 51762,
1050
+ "<|27.96|>": 51763,
1051
+ "<|27.98|>": 51764,
1052
+ "<|28.00|>": 51765,
1053
+ "<|28.02|>": 51766,
1054
+ "<|28.04|>": 51767,
1055
+ "<|28.06|>": 51768,
1056
+ "<|28.08|>": 51769,
1057
+ "<|28.10|>": 51770,
1058
+ "<|28.12|>": 51771,
1059
+ "<|28.14|>": 51772,
1060
+ "<|28.16|>": 51773,
1061
+ "<|28.18|>": 51774,
1062
+ "<|28.20|>": 51775,
1063
+ "<|28.22|>": 51776,
1064
+ "<|28.24|>": 51777,
1065
+ "<|28.26|>": 51778,
1066
+ "<|28.28|>": 51779,
1067
+ "<|28.30|>": 51780,
1068
+ "<|28.32|>": 51781,
1069
+ "<|28.34|>": 51782,
1070
+ "<|28.36|>": 51783,
1071
+ "<|28.38|>": 51784,
1072
+ "<|28.40|>": 51785,
1073
+ "<|28.42|>": 51786,
1074
+ "<|28.44|>": 51787,
1075
+ "<|28.46|>": 51788,
1076
+ "<|28.48|>": 51789,
1077
+ "<|28.50|>": 51790,
1078
+ "<|28.52|>": 51791,
1079
+ "<|28.54|>": 51792,
1080
+ "<|28.56|>": 51793,
1081
+ "<|28.58|>": 51794,
1082
+ "<|28.60|>": 51795,
1083
+ "<|28.62|>": 51796,
1084
+ "<|28.64|>": 51797,
1085
+ "<|28.66|>": 51798,
1086
+ "<|28.68|>": 51799,
1087
+ "<|28.70|>": 51800,
1088
+ "<|28.72|>": 51801,
1089
+ "<|28.74|>": 51802,
1090
+ "<|28.76|>": 51803,
1091
+ "<|28.78|>": 51804,
1092
+ "<|28.80|>": 51805,
1093
+ "<|28.82|>": 51806,
1094
+ "<|28.84|>": 51807,
1095
+ "<|28.86|>": 51808,
1096
+ "<|28.88|>": 51809,
1097
+ "<|28.90|>": 51810,
1098
+ "<|28.92|>": 51811,
1099
+ "<|28.94|>": 51812,
1100
+ "<|28.96|>": 51813,
1101
+ "<|28.98|>": 51814,
1102
+ "<|29.00|>": 51815,
1103
+ "<|29.02|>": 51816,
1104
+ "<|29.04|>": 51817,
1105
+ "<|29.06|>": 51818,
1106
+ "<|29.08|>": 51819,
1107
+ "<|29.10|>": 51820,
1108
+ "<|29.12|>": 51821,
1109
+ "<|29.14|>": 51822,
1110
+ "<|29.16|>": 51823,
1111
+ "<|29.18|>": 51824,
1112
+ "<|29.20|>": 51825,
1113
+ "<|29.22|>": 51826,
1114
+ "<|29.24|>": 51827,
1115
+ "<|29.26|>": 51828,
1116
+ "<|29.28|>": 51829,
1117
+ "<|29.30|>": 51830,
1118
+ "<|29.32|>": 51831,
1119
+ "<|29.34|>": 51832,
1120
+ "<|29.36|>": 51833,
1121
+ "<|29.38|>": 51834,
1122
+ "<|29.40|>": 51835,
1123
+ "<|29.42|>": 51836,
1124
+ "<|29.44|>": 51837,
1125
+ "<|29.46|>": 51838,
1126
+ "<|29.48|>": 51839,
1127
+ "<|29.50|>": 51840,
1128
+ "<|29.52|>": 51841,
1129
+ "<|29.54|>": 51842,
1130
+ "<|29.56|>": 51843,
1131
+ "<|29.58|>": 51844,
1132
+ "<|29.60|>": 51845,
1133
+ "<|29.62|>": 51846,
1134
+ "<|29.64|>": 51847,
1135
+ "<|29.66|>": 51848,
1136
+ "<|29.68|>": 51849,
1137
+ "<|29.70|>": 51850,
1138
+ "<|29.72|>": 51851,
1139
+ "<|29.74|>": 51852,
1140
+ "<|29.76|>": 51853,
1141
+ "<|29.78|>": 51854,
1142
+ "<|29.80|>": 51855,
1143
+ "<|29.82|>": 51856,
1144
+ "<|29.84|>": 51857,
1145
+ "<|29.86|>": 51858,
1146
+ "<|29.88|>": 51859,
1147
+ "<|29.90|>": 51860,
1148
+ "<|29.92|>": 51861,
1149
+ "<|29.94|>": 51862,
1150
+ "<|29.96|>": 51863,
1151
+ "<|29.98|>": 51864,
1152
+ "<|3.00|>": 50515,
1153
+ "<|3.02|>": 50516,
1154
+ "<|3.04|>": 50517,
1155
+ "<|3.06|>": 50518,
1156
+ "<|3.08|>": 50519,
1157
+ "<|3.10|>": 50520,
1158
+ "<|3.12|>": 50521,
1159
+ "<|3.14|>": 50522,
1160
+ "<|3.16|>": 50523,
1161
+ "<|3.18|>": 50524,
1162
+ "<|3.20|>": 50525,
1163
+ "<|3.22|>": 50526,
1164
+ "<|3.24|>": 50527,
1165
+ "<|3.26|>": 50528,
1166
+ "<|3.28|>": 50529,
1167
+ "<|3.30|>": 50530,
1168
+ "<|3.32|>": 50531,
1169
+ "<|3.34|>": 50532,
1170
+ "<|3.36|>": 50533,
1171
+ "<|3.38|>": 50534,
1172
+ "<|3.40|>": 50535,
1173
+ "<|3.42|>": 50536,
1174
+ "<|3.44|>": 50537,
1175
+ "<|3.46|>": 50538,
1176
+ "<|3.48|>": 50539,
1177
+ "<|3.50|>": 50540,
1178
+ "<|3.52|>": 50541,
1179
+ "<|3.54|>": 50542,
1180
+ "<|3.56|>": 50543,
1181
+ "<|3.58|>": 50544,
1182
+ "<|3.60|>": 50545,
1183
+ "<|3.62|>": 50546,
1184
+ "<|3.64|>": 50547,
1185
+ "<|3.66|>": 50548,
1186
+ "<|3.68|>": 50549,
1187
+ "<|3.70|>": 50550,
1188
+ "<|3.72|>": 50551,
1189
+ "<|3.74|>": 50552,
1190
+ "<|3.76|>": 50553,
1191
+ "<|3.78|>": 50554,
1192
+ "<|3.80|>": 50555,
1193
+ "<|3.82|>": 50556,
1194
+ "<|3.84|>": 50557,
1195
+ "<|3.86|>": 50558,
1196
+ "<|3.88|>": 50559,
1197
+ "<|3.90|>": 50560,
1198
+ "<|3.92|>": 50561,
1199
+ "<|3.94|>": 50562,
1200
+ "<|3.96|>": 50563,
1201
+ "<|3.98|>": 50564,
1202
+ "<|30.00|>": 51865,
1203
+ "<|4.00|>": 50565,
1204
+ "<|4.02|>": 50566,
1205
+ "<|4.04|>": 50567,
1206
+ "<|4.06|>": 50568,
1207
+ "<|4.08|>": 50569,
1208
+ "<|4.10|>": 50570,
1209
+ "<|4.12|>": 50571,
1210
+ "<|4.14|>": 50572,
1211
+ "<|4.16|>": 50573,
1212
+ "<|4.18|>": 50574,
1213
+ "<|4.20|>": 50575,
1214
+ "<|4.22|>": 50576,
1215
+ "<|4.24|>": 50577,
1216
+ "<|4.26|>": 50578,
1217
+ "<|4.28|>": 50579,
1218
+ "<|4.30|>": 50580,
1219
+ "<|4.32|>": 50581,
1220
+ "<|4.34|>": 50582,
1221
+ "<|4.36|>": 50583,
1222
+ "<|4.38|>": 50584,
1223
+ "<|4.40|>": 50585,
1224
+ "<|4.42|>": 50586,
1225
+ "<|4.44|>": 50587,
1226
+ "<|4.46|>": 50588,
1227
+ "<|4.48|>": 50589,
1228
+ "<|4.50|>": 50590,
1229
+ "<|4.52|>": 50591,
1230
+ "<|4.54|>": 50592,
1231
+ "<|4.56|>": 50593,
1232
+ "<|4.58|>": 50594,
1233
+ "<|4.60|>": 50595,
1234
+ "<|4.62|>": 50596,
1235
+ "<|4.64|>": 50597,
1236
+ "<|4.66|>": 50598,
1237
+ "<|4.68|>": 50599,
1238
+ "<|4.70|>": 50600,
1239
+ "<|4.72|>": 50601,
1240
+ "<|4.74|>": 50602,
1241
+ "<|4.76|>": 50603,
1242
+ "<|4.78|>": 50604,
1243
+ "<|4.80|>": 50605,
1244
+ "<|4.82|>": 50606,
1245
+ "<|4.84|>": 50607,
1246
+ "<|4.86|>": 50608,
1247
+ "<|4.88|>": 50609,
1248
+ "<|4.90|>": 50610,
1249
+ "<|4.92|>": 50611,
1250
+ "<|4.94|>": 50612,
1251
+ "<|4.96|>": 50613,
1252
+ "<|4.98|>": 50614,
1253
+ "<|5.00|>": 50615,
1254
+ "<|5.02|>": 50616,
1255
+ "<|5.04|>": 50617,
1256
+ "<|5.06|>": 50618,
1257
+ "<|5.08|>": 50619,
1258
+ "<|5.10|>": 50620,
1259
+ "<|5.12|>": 50621,
1260
+ "<|5.14|>": 50622,
1261
+ "<|5.16|>": 50623,
1262
+ "<|5.18|>": 50624,
1263
+ "<|5.20|>": 50625,
1264
+ "<|5.22|>": 50626,
1265
+ "<|5.24|>": 50627,
1266
+ "<|5.26|>": 50628,
1267
+ "<|5.28|>": 50629,
1268
+ "<|5.30|>": 50630,
1269
+ "<|5.32|>": 50631,
1270
+ "<|5.34|>": 50632,
1271
+ "<|5.36|>": 50633,
1272
+ "<|5.38|>": 50634,
1273
+ "<|5.40|>": 50635,
1274
+ "<|5.42|>": 50636,
1275
+ "<|5.44|>": 50637,
1276
+ "<|5.46|>": 50638,
1277
+ "<|5.48|>": 50639,
1278
+ "<|5.50|>": 50640,
1279
+ "<|5.52|>": 50641,
1280
+ "<|5.54|>": 50642,
1281
+ "<|5.56|>": 50643,
1282
+ "<|5.58|>": 50644,
1283
+ "<|5.60|>": 50645,
1284
+ "<|5.62|>": 50646,
1285
+ "<|5.64|>": 50647,
1286
+ "<|5.66|>": 50648,
1287
+ "<|5.68|>": 50649,
1288
+ "<|5.70|>": 50650,
1289
+ "<|5.72|>": 50651,
1290
+ "<|5.74|>": 50652,
1291
+ "<|5.76|>": 50653,
1292
+ "<|5.78|>": 50654,
1293
+ "<|5.80|>": 50655,
1294
+ "<|5.82|>": 50656,
1295
+ "<|5.84|>": 50657,
1296
+ "<|5.86|>": 50658,
1297
+ "<|5.88|>": 50659,
1298
+ "<|5.90|>": 50660,
1299
+ "<|5.92|>": 50661,
1300
+ "<|5.94|>": 50662,
1301
+ "<|5.96|>": 50663,
1302
+ "<|5.98|>": 50664,
1303
+ "<|6.00|>": 50665,
1304
+ "<|6.02|>": 50666,
1305
+ "<|6.04|>": 50667,
1306
+ "<|6.06|>": 50668,
1307
+ "<|6.08|>": 50669,
1308
+ "<|6.10|>": 50670,
1309
+ "<|6.12|>": 50671,
1310
+ "<|6.14|>": 50672,
1311
+ "<|6.16|>": 50673,
1312
+ "<|6.18|>": 50674,
1313
+ "<|6.20|>": 50675,
1314
+ "<|6.22|>": 50676,
1315
+ "<|6.24|>": 50677,
1316
+ "<|6.26|>": 50678,
1317
+ "<|6.28|>": 50679,
1318
+ "<|6.30|>": 50680,
1319
+ "<|6.32|>": 50681,
1320
+ "<|6.34|>": 50682,
1321
+ "<|6.36|>": 50683,
1322
+ "<|6.38|>": 50684,
1323
+ "<|6.40|>": 50685,
1324
+ "<|6.42|>": 50686,
1325
+ "<|6.44|>": 50687,
1326
+ "<|6.46|>": 50688,
1327
+ "<|6.48|>": 50689,
1328
+ "<|6.50|>": 50690,
1329
+ "<|6.52|>": 50691,
1330
+ "<|6.54|>": 50692,
1331
+ "<|6.56|>": 50693,
1332
+ "<|6.58|>": 50694,
1333
+ "<|6.60|>": 50695,
1334
+ "<|6.62|>": 50696,
1335
+ "<|6.64|>": 50697,
1336
+ "<|6.66|>": 50698,
1337
+ "<|6.68|>": 50699,
1338
+ "<|6.70|>": 50700,
1339
+ "<|6.72|>": 50701,
1340
+ "<|6.74|>": 50702,
1341
+ "<|6.76|>": 50703,
1342
+ "<|6.78|>": 50704,
1343
+ "<|6.80|>": 50705,
1344
+ "<|6.82|>": 50706,
1345
+ "<|6.84|>": 50707,
1346
+ "<|6.86|>": 50708,
1347
+ "<|6.88|>": 50709,
1348
+ "<|6.90|>": 50710,
1349
+ "<|6.92|>": 50711,
1350
+ "<|6.94|>": 50712,
1351
+ "<|6.96|>": 50713,
1352
+ "<|6.98|>": 50714,
1353
+ "<|7.00|>": 50715,
1354
+ "<|7.02|>": 50716,
1355
+ "<|7.04|>": 50717,
1356
+ "<|7.06|>": 50718,
1357
+ "<|7.08|>": 50719,
1358
+ "<|7.10|>": 50720,
1359
+ "<|7.12|>": 50721,
1360
+ "<|7.14|>": 50722,
1361
+ "<|7.16|>": 50723,
1362
+ "<|7.18|>": 50724,
1363
+ "<|7.20|>": 50725,
1364
+ "<|7.22|>": 50726,
1365
+ "<|7.24|>": 50727,
1366
+ "<|7.26|>": 50728,
1367
+ "<|7.28|>": 50729,
1368
+ "<|7.30|>": 50730,
1369
+ "<|7.32|>": 50731,
1370
+ "<|7.34|>": 50732,
1371
+ "<|7.36|>": 50733,
1372
+ "<|7.38|>": 50734,
1373
+ "<|7.40|>": 50735,
1374
+ "<|7.42|>": 50736,
1375
+ "<|7.44|>": 50737,
1376
+ "<|7.46|>": 50738,
1377
+ "<|7.48|>": 50739,
1378
+ "<|7.50|>": 50740,
1379
+ "<|7.52|>": 50741,
1380
+ "<|7.54|>": 50742,
1381
+ "<|7.56|>": 50743,
1382
+ "<|7.58|>": 50744,
1383
+ "<|7.60|>": 50745,
1384
+ "<|7.62|>": 50746,
1385
+ "<|7.64|>": 50747,
1386
+ "<|7.66|>": 50748,
1387
+ "<|7.68|>": 50749,
1388
+ "<|7.70|>": 50750,
1389
+ "<|7.72|>": 50751,
1390
+ "<|7.74|>": 50752,
1391
+ "<|7.76|>": 50753,
1392
+ "<|7.78|>": 50754,
1393
+ "<|7.80|>": 50755,
1394
+ "<|7.82|>": 50756,
1395
+ "<|7.84|>": 50757,
1396
+ "<|7.86|>": 50758,
1397
+ "<|7.88|>": 50759,
1398
+ "<|7.90|>": 50760,
1399
+ "<|7.92|>": 50761,
1400
+ "<|7.94|>": 50762,
1401
+ "<|7.96|>": 50763,
1402
+ "<|7.98|>": 50764,
1403
+ "<|8.00|>": 50765,
1404
+ "<|8.02|>": 50766,
1405
+ "<|8.04|>": 50767,
1406
+ "<|8.06|>": 50768,
1407
+ "<|8.08|>": 50769,
1408
+ "<|8.10|>": 50770,
1409
+ "<|8.12|>": 50771,
1410
+ "<|8.14|>": 50772,
1411
+ "<|8.16|>": 50773,
1412
+ "<|8.18|>": 50774,
1413
+ "<|8.20|>": 50775,
1414
+ "<|8.22|>": 50776,
1415
+ "<|8.24|>": 50777,
1416
+ "<|8.26|>": 50778,
1417
+ "<|8.28|>": 50779,
1418
+ "<|8.30|>": 50780,
1419
+ "<|8.32|>": 50781,
1420
+ "<|8.34|>": 50782,
1421
+ "<|8.36|>": 50783,
1422
+ "<|8.38|>": 50784,
1423
+ "<|8.40|>": 50785,
1424
+ "<|8.42|>": 50786,
1425
+ "<|8.44|>": 50787,
1426
+ "<|8.46|>": 50788,
1427
+ "<|8.48|>": 50789,
1428
+ "<|8.50|>": 50790,
1429
+ "<|8.52|>": 50791,
1430
+ "<|8.54|>": 50792,
1431
+ "<|8.56|>": 50793,
1432
+ "<|8.58|>": 50794,
1433
+ "<|8.60|>": 50795,
1434
+ "<|8.62|>": 50796,
1435
+ "<|8.64|>": 50797,
1436
+ "<|8.66|>": 50798,
1437
+ "<|8.68|>": 50799,
1438
+ "<|8.70|>": 50800,
1439
+ "<|8.72|>": 50801,
1440
+ "<|8.74|>": 50802,
1441
+ "<|8.76|>": 50803,
1442
+ "<|8.78|>": 50804,
1443
+ "<|8.80|>": 50805,
1444
+ "<|8.82|>": 50806,
1445
+ "<|8.84|>": 50807,
1446
+ "<|8.86|>": 50808,
1447
+ "<|8.88|>": 50809,
1448
+ "<|8.90|>": 50810,
1449
+ "<|8.92|>": 50811,
1450
+ "<|8.94|>": 50812,
1451
+ "<|8.96|>": 50813,
1452
+ "<|8.98|>": 50814,
1453
+ "<|9.00|>": 50815,
1454
+ "<|9.02|>": 50816,
1455
+ "<|9.04|>": 50817,
1456
+ "<|9.06|>": 50818,
1457
+ "<|9.08|>": 50819,
1458
+ "<|9.10|>": 50820,
1459
+ "<|9.12|>": 50821,
1460
+ "<|9.14|>": 50822,
1461
+ "<|9.16|>": 50823,
1462
+ "<|9.18|>": 50824,
1463
+ "<|9.20|>": 50825,
1464
+ "<|9.22|>": 50826,
1465
+ "<|9.24|>": 50827,
1466
+ "<|9.26|>": 50828,
1467
+ "<|9.28|>": 50829,
1468
+ "<|9.30|>": 50830,
1469
+ "<|9.32|>": 50831,
1470
+ "<|9.34|>": 50832,
1471
+ "<|9.36|>": 50833,
1472
+ "<|9.38|>": 50834,
1473
+ "<|9.40|>": 50835,
1474
+ "<|9.42|>": 50836,
1475
+ "<|9.44|>": 50837,
1476
+ "<|9.46|>": 50838,
1477
+ "<|9.48|>": 50839,
1478
+ "<|9.50|>": 50840,
1479
+ "<|9.52|>": 50841,
1480
+ "<|9.54|>": 50842,
1481
+ "<|9.56|>": 50843,
1482
+ "<|9.58|>": 50844,
1483
+ "<|9.60|>": 50845,
1484
+ "<|9.62|>": 50846,
1485
+ "<|9.64|>": 50847,
1486
+ "<|9.66|>": 50848,
1487
+ "<|9.68|>": 50849,
1488
+ "<|9.70|>": 50850,
1489
+ "<|9.72|>": 50851,
1490
+ "<|9.74|>": 50852,
1491
+ "<|9.76|>": 50853,
1492
+ "<|9.78|>": 50854,
1493
+ "<|9.80|>": 50855,
1494
+ "<|9.82|>": 50856,
1495
+ "<|9.84|>": 50857,
1496
+ "<|9.86|>": 50858,
1497
+ "<|9.88|>": 50859,
1498
+ "<|9.90|>": 50860,
1499
+ "<|9.92|>": 50861,
1500
+ "<|9.94|>": 50862,
1501
+ "<|9.96|>": 50863,
1502
+ "<|9.98|>": 50864,
1503
+ "<|af|>": 50327,
1504
+ "<|am|>": 50334,
1505
+ "<|ar|>": 50272,
1506
+ "<|as|>": 50350,
1507
+ "<|az|>": 50304,
1508
+ "<|ba|>": 50355,
1509
+ "<|be|>": 50330,
1510
+ "<|bg|>": 50292,
1511
+ "<|bn|>": 50302,
1512
+ "<|bo|>": 50347,
1513
+ "<|br|>": 50309,
1514
+ "<|bs|>": 50315,
1515
+ "<|ca|>": 50270,
1516
+ "<|cs|>": 50283,
1517
+ "<|cy|>": 50297,
1518
+ "<|da|>": 50285,
1519
+ "<|de|>": 50261,
1520
+ "<|el|>": 50281,
1521
+ "<|endoftext|>": 50257,
1522
+ "<|en|>": 50259,
1523
+ "<|es|>": 50262,
1524
+ "<|et|>": 50307,
1525
+ "<|eu|>": 50310,
1526
+ "<|fa|>": 50300,
1527
+ "<|fi|>": 50277,
1528
+ "<|fo|>": 50338,
1529
+ "<|fr|>": 50265,
1530
+ "<|gl|>": 50319,
1531
+ "<|gu|>": 50333,
1532
+ "<|haw|>": 50352,
1533
+ "<|ha|>": 50354,
1534
+ "<|he|>": 50279,
1535
+ "<|hi|>": 50276,
1536
+ "<|hr|>": 50291,
1537
+ "<|ht|>": 50339,
1538
+ "<|hu|>": 50286,
1539
+ "<|hy|>": 50312,
1540
+ "<|id|>": 50275,
1541
+ "<|is|>": 50311,
1542
+ "<|it|>": 50274,
1543
+ "<|ja|>": 50266,
1544
+ "<|jw|>": 50356,
1545
+ "<|ka|>": 50329,
1546
+ "<|kk|>": 50316,
1547
+ "<|km|>": 50323,
1548
+ "<|kn|>": 50306,
1549
+ "<|ko|>": 50264,
1550
+ "<|la|>": 50294,
1551
+ "<|lb|>": 50345,
1552
+ "<|ln|>": 50353,
1553
+ "<|lo|>": 50336,
1554
+ "<|lt|>": 50293,
1555
+ "<|lv|>": 50301,
1556
+ "<|mg|>": 50349,
1557
+ "<|mi|>": 50295,
1558
+ "<|mk|>": 50308,
1559
+ "<|ml|>": 50296,
1560
+ "<|mn|>": 50314,
1561
+ "<|mr|>": 50320,
1562
+ "<|ms|>": 50282,
1563
+ "<|mt|>": 50343,
1564
+ "<|my|>": 50346,
1565
+ "<|ne|>": 50313,
1566
+ "<|nl|>": 50271,
1567
+ "<|nn|>": 50342,
1568
+ "<|nospeech|>": 50363,
1569
+ "<|notimestamps|>": 50364,
1570
+ "<|no|>": 50288,
1571
+ "<|oc|>": 50328,
1572
+ "<|pa|>": 50321,
1573
+ "<|pl|>": 50269,
1574
+ "<|ps|>": 50340,
1575
+ "<|pt|>": 50267,
1576
+ "<|ro|>": 50284,
1577
+ "<|ru|>": 50263,
1578
+ "<|sa|>": 50344,
1579
+ "<|sd|>": 50332,
1580
+ "<|si|>": 50322,
1581
+ "<|sk|>": 50298,
1582
+ "<|sl|>": 50305,
1583
+ "<|sn|>": 50324,
1584
+ "<|so|>": 50326,
1585
+ "<|sq|>": 50317,
1586
+ "<|sr|>": 50303,
1587
+ "<|startoflm|>": 50361,
1588
+ "<|startofprev|>": 50362,
1589
+ "<|startoftranscript|>": 50258,
1590
+ "<|su|>": 50357,
1591
+ "<|sv|>": 50273,
1592
+ "<|sw|>": 50318,
1593
+ "<|ta|>": 50287,
1594
+ "<|te|>": 50299,
1595
+ "<|tg|>": 50331,
1596
+ "<|th|>": 50289,
1597
+ "<|tk|>": 50341,
1598
+ "<|tl|>": 50348,
1599
+ "<|transcribe|>": 50360,
1600
+ "<|translate|>": 50359,
1601
+ "<|tr|>": 50268,
1602
+ "<|tt|>": 50351,
1603
+ "<|uk|>": 50280,
1604
+ "<|ur|>": 50290,
1605
+ "<|uz|>": 50337,
1606
+ "<|vi|>": 50278,
1607
+ "<|yi|>": 50335,
1608
+ "<|yo|>": 50325,
1609
+ "<|yue|>": 50358,
1610
+ "<|zh|>": 50260
1611
+ }
distil-large-v3-init/config.json ADDED
@@ -0,0 +1,50 @@
+ {
+ "_name_or_path": "openai/whisper-large-v3",
+ "activation_dropout": 0.0,
+ "activation_function": "gelu",
+ "apply_spec_augment": false,
+ "architectures": [
+ "WhisperForConditionalGeneration"
+ ],
+ "attention_dropout": 0.0,
+ "begin_suppress_tokens": [
+ 220,
+ 50257
+ ],
+ "bos_token_id": 50257,
+ "classifier_proj_size": 256,
+ "d_model": 1280,
+ "decoder_attention_heads": 20,
+ "decoder_ffn_dim": 5120,
+ "decoder_layerdrop": 0.0,
+ "decoder_layers": 2,
+ "decoder_start_token_id": 50258,
+ "dropout": 0.0,
+ "encoder_attention_heads": 20,
+ "encoder_ffn_dim": 5120,
+ "encoder_layerdrop": 0.0,
+ "encoder_layers": 32,
+ "eos_token_id": 50257,
+ "init_std": 0.02,
+ "is_encoder_decoder": true,
+ "mask_feature_length": 10,
+ "mask_feature_min_masks": 0,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_min_masks": 2,
+ "mask_time_prob": 0.05,
+ "max_length": 448,
+ "max_source_positions": 1500,
+ "max_target_positions": 448,
+ "median_filter_width": 7,
+ "model_type": "whisper",
+ "num_hidden_layers": 32,
+ "num_mel_bins": 128,
+ "pad_token_id": 50256,
+ "scale_embedding": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.40.1",
+ "use_cache": true,
+ "use_weighted_layer_sum": false,
+ "vocab_size": 51866
+ }
distil-large-v3-init/generation_config.json ADDED
@@ -0,0 +1,255 @@
1
+ {
2
+ "alignment_heads": [
3
+ [
4
+ 7,
5
+ 0
6
+ ],
7
+ [
8
+ 10,
9
+ 17
10
+ ],
11
+ [
12
+ 12,
13
+ 18
14
+ ],
15
+ [
16
+ 13,
17
+ 12
18
+ ],
19
+ [
20
+ 16,
21
+ 1
22
+ ],
23
+ [
24
+ 17,
25
+ 14
26
+ ],
27
+ [
28
+ 19,
29
+ 11
30
+ ],
31
+ [
32
+ 21,
33
+ 4
34
+ ],
35
+ [
36
+ 24,
37
+ 1
38
+ ],
39
+ [
40
+ 25,
41
+ 6
42
+ ]
43
+ ],
44
+ "begin_suppress_tokens": [
45
+ 220,
46
+ 50257
47
+ ],
48
+ "bos_token_id": 50257,
49
+ "decoder_start_token_id": 50258,
50
+ "eos_token_id": 50257,
51
+ "is_multilingual": true,
52
+ "lang_to_id": {
53
+ "<|af|>": 50327,
54
+ "<|am|>": 50334,
55
+ "<|ar|>": 50272,
56
+ "<|as|>": 50350,
57
+ "<|az|>": 50304,
58
+ "<|ba|>": 50355,
59
+ "<|be|>": 50330,
60
+ "<|bg|>": 50292,
61
+ "<|bn|>": 50302,
62
+ "<|bo|>": 50347,
63
+ "<|br|>": 50309,
64
+ "<|bs|>": 50315,
65
+ "<|ca|>": 50270,
66
+ "<|cs|>": 50283,
67
+ "<|cy|>": 50297,
68
+ "<|da|>": 50285,
69
+ "<|de|>": 50261,
70
+ "<|el|>": 50281,
71
+ "<|en|>": 50259,
72
+ "<|es|>": 50262,
73
+ "<|et|>": 50307,
74
+ "<|eu|>": 50310,
75
+ "<|fa|>": 50300,
76
+ "<|fi|>": 50277,
77
+ "<|fo|>": 50338,
78
+ "<|fr|>": 50265,
79
+ "<|gl|>": 50319,
80
+ "<|gu|>": 50333,
81
+ "<|haw|>": 50352,
82
+ "<|ha|>": 50354,
83
+ "<|he|>": 50279,
84
+ "<|hi|>": 50276,
85
+ "<|hr|>": 50291,
86
+ "<|ht|>": 50339,
87
+ "<|hu|>": 50286,
88
+ "<|hy|>": 50312,
89
+ "<|id|>": 50275,
90
+ "<|is|>": 50311,
91
+ "<|it|>": 50274,
92
+ "<|ja|>": 50266,
93
+ "<|jw|>": 50356,
94
+ "<|ka|>": 50329,
95
+ "<|kk|>": 50316,
96
+ "<|km|>": 50323,
97
+ "<|kn|>": 50306,
98
+ "<|ko|>": 50264,
99
+ "<|la|>": 50294,
100
+ "<|lb|>": 50345,
101
+ "<|ln|>": 50353,
102
+ "<|lo|>": 50336,
103
+ "<|lt|>": 50293,
104
+ "<|lv|>": 50301,
105
+ "<|mg|>": 50349,
106
+ "<|mi|>": 50295,
107
+ "<|mk|>": 50308,
108
+ "<|ml|>": 50296,
109
+ "<|mn|>": 50314,
110
+ "<|mr|>": 50320,
111
+ "<|ms|>": 50282,
112
+ "<|mt|>": 50343,
113
+ "<|my|>": 50346,
114
+ "<|ne|>": 50313,
115
+ "<|nl|>": 50271,
116
+ "<|nn|>": 50342,
117
+ "<|no|>": 50288,
118
+ "<|oc|>": 50328,
119
+ "<|pa|>": 50321,
120
+ "<|pl|>": 50269,
121
+ "<|ps|>": 50340,
122
+ "<|pt|>": 50267,
123
+ "<|ro|>": 50284,
124
+ "<|ru|>": 50263,
125
+ "<|sa|>": 50344,
126
+ "<|sd|>": 50332,
127
+ "<|si|>": 50322,
128
+ "<|sk|>": 50298,
129
+ "<|sl|>": 50305,
130
+ "<|sn|>": 50324,
131
+ "<|so|>": 50326,
132
+ "<|sq|>": 50317,
133
+ "<|sr|>": 50303,
134
+ "<|su|>": 50357,
135
+ "<|sv|>": 50273,
136
+ "<|sw|>": 50318,
137
+ "<|ta|>": 50287,
138
+ "<|te|>": 50299,
139
+ "<|tg|>": 50331,
140
+ "<|th|>": 50289,
141
+ "<|tk|>": 50341,
142
+ "<|tl|>": 50348,
143
+ "<|tr|>": 50268,
144
+ "<|tt|>": 50351,
145
+ "<|uk|>": 50280,
146
+ "<|ur|>": 50290,
147
+ "<|uz|>": 50337,
148
+ "<|vi|>": 50278,
149
+ "<|yi|>": 50335,
150
+ "<|yo|>": 50325,
151
+ "<|yue|>": 50358,
152
+ "<|zh|>": 50260
153
+ },
154
+ "max_initial_timestamp_index": 50,
155
+ "max_length": 448,
156
+ "no_timestamps_token_id": 50364,
157
+ "pad_token_id": 50257,
158
+ "prev_sot_token_id": 50362,
159
+ "return_timestamps": false,
160
+ "suppress_tokens": [
161
+ 1,
162
+ 2,
163
+ 7,
164
+ 8,
165
+ 9,
166
+ 10,
167
+ 14,
168
+ 25,
169
+ 26,
170
+ 27,
171
+ 28,
172
+ 29,
173
+ 31,
174
+ 58,
175
+ 59,
176
+ 60,
177
+ 61,
178
+ 62,
179
+ 63,
180
+ 90,
181
+ 91,
182
+ 92,
183
+ 93,
184
+ 359,
185
+ 503,
186
+ 522,
187
+ 542,
188
+ 873,
189
+ 893,
190
+ 902,
191
+ 918,
192
+ 922,
193
+ 931,
194
+ 1350,
195
+ 1853,
196
+ 1982,
197
+ 2460,
198
+ 2627,
199
+ 3246,
200
+ 3253,
201
+ 3268,
202
+ 3536,
203
+ 3846,
204
+ 3961,
205
+ 4183,
206
+ 4667,
207
+ 6585,
208
+ 6647,
209
+ 7273,
210
+ 9061,
211
+ 9383,
212
+ 10428,
213
+ 10929,
214
+ 11938,
215
+ 12033,
216
+ 12331,
217
+ 12562,
218
+ 13793,
219
+ 14157,
220
+ 14635,
221
+ 15265,
222
+ 15618,
223
+ 16553,
224
+ 16604,
225
+ 18362,
226
+ 18956,
227
+ 20075,
228
+ 21675,
229
+ 22520,
230
+ 26130,
231
+ 26161,
232
+ 26435,
233
+ 28279,
234
+ 29464,
235
+ 31650,
236
+ 32302,
237
+ 32470,
238
+ 36865,
239
+ 42863,
240
+ 47425,
241
+ 49870,
242
+ 50254,
243
+ 50258,
244
+ 50359,
245
+ 50360,
246
+ 50361,
247
+ 50362,
248
+ 50363
249
+ ],
250
+ "task_to_id": {
251
+ "transcribe": 50360,
252
+ "translate": 50359
253
+ },
254
+ "transformers_version": "4.40.1"
255
+ }
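The generation config above carries the multilingual prompt tables (`lang_to_id`, `task_to_id`) and the suppressed-token list inherited from large-v3. A minimal sketch of how these mappings back the forced decoder prompt at inference time (the local path is illustrative):

```python
from transformers import WhisperProcessor

# Illustrative local path to the initialised student checkpoint
processor = WhisperProcessor.from_pretrained("distil-large-v3-init")

# lang_to_id / task_to_id above drive the forced decoder prompt,
# e.g. <|en|><|transcribe|><|notimestamps|> for English transcription
prompt_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
print(prompt_ids)
```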
distil-large-v3-init/merges.txt ADDED
The diff for this file is too large to render. See raw diff
distil-large-v3-init/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c5ef44f7f59126b7b66937cc81d3194eb310f9af8b08512bbd6bd55fb0cda9f
3
+ size 3025686376
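The entry above is a Git LFS pointer rather than the weights themselves; the ~3 GB safetensors file is fetched on checkout. As a rough sketch (assuming the file has been pulled locally and the `safetensors` package is installed), the raw state dict can be inspected like this:

```python
from safetensors.torch import load_file

# Load the student weights directly from the safetensors file (local path is illustrative)
state_dict = load_file("distil-large-v3-init/model.safetensors")
print(len(state_dict), "tensors")
```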
distil-large-v3-init/normalizer.json ADDED
@@ -0,0 +1,1742 @@
1
+ {
2
+ "accessorise": "accessorize",
3
+ "accessorised": "accessorized",
4
+ "accessorises": "accessorizes",
5
+ "accessorising": "accessorizing",
6
+ "acclimatisation": "acclimatization",
7
+ "acclimatise": "acclimatize",
8
+ "acclimatised": "acclimatized",
9
+ "acclimatises": "acclimatizes",
10
+ "acclimatising": "acclimatizing",
11
+ "accoutrements": "accouterments",
12
+ "aeon": "eon",
13
+ "aeons": "eons",
14
+ "aerogramme": "aerogram",
15
+ "aerogrammes": "aerograms",
16
+ "aeroplane": "airplane",
17
+ "aeroplanes": "airplanes",
18
+ "aesthete": "esthete",
19
+ "aesthetes": "esthetes",
20
+ "aesthetic": "esthetic",
21
+ "aesthetically": "esthetically",
22
+ "aesthetics": "esthetics",
23
+ "aetiology": "etiology",
24
+ "ageing": "aging",
25
+ "aggrandisement": "aggrandizement",
26
+ "agonise": "agonize",
27
+ "agonised": "agonized",
28
+ "agonises": "agonizes",
29
+ "agonising": "agonizing",
30
+ "agonisingly": "agonizingly",
31
+ "almanack": "almanac",
32
+ "almanacks": "almanacs",
33
+ "aluminium": "aluminum",
34
+ "amortisable": "amortizable",
35
+ "amortisation": "amortization",
36
+ "amortisations": "amortizations",
37
+ "amortise": "amortize",
38
+ "amortised": "amortized",
39
+ "amortises": "amortizes",
40
+ "amortising": "amortizing",
41
+ "amphitheatre": "amphitheater",
42
+ "amphitheatres": "amphitheaters",
43
+ "anaemia": "anemia",
44
+ "anaemic": "anemic",
45
+ "anaesthesia": "anesthesia",
46
+ "anaesthetic": "anesthetic",
47
+ "anaesthetics": "anesthetics",
48
+ "anaesthetise": "anesthetize",
49
+ "anaesthetised": "anesthetized",
50
+ "anaesthetises": "anesthetizes",
51
+ "anaesthetising": "anesthetizing",
52
+ "anaesthetist": "anesthetist",
53
+ "anaesthetists": "anesthetists",
54
+ "anaesthetize": "anesthetize",
55
+ "anaesthetized": "anesthetized",
56
+ "anaesthetizes": "anesthetizes",
57
+ "anaesthetizing": "anesthetizing",
58
+ "analogue": "analog",
59
+ "analogues": "analogs",
60
+ "analyse": "analyze",
61
+ "analysed": "analyzed",
62
+ "analyses": "analyzes",
63
+ "analysing": "analyzing",
64
+ "anglicise": "anglicize",
65
+ "anglicised": "anglicized",
66
+ "anglicises": "anglicizes",
67
+ "anglicising": "anglicizing",
68
+ "annualised": "annualized",
69
+ "antagonise": "antagonize",
70
+ "antagonised": "antagonized",
71
+ "antagonises": "antagonizes",
72
+ "antagonising": "antagonizing",
73
+ "apologise": "apologize",
74
+ "apologised": "apologized",
75
+ "apologises": "apologizes",
76
+ "apologising": "apologizing",
77
+ "appal": "appall",
78
+ "appals": "appalls",
79
+ "appetiser": "appetizer",
80
+ "appetisers": "appetizers",
81
+ "appetising": "appetizing",
82
+ "appetisingly": "appetizingly",
83
+ "arbour": "arbor",
84
+ "arbours": "arbors",
85
+ "archaeologically": "archeologically",
86
+ "archaeologist": "archeologist",
87
+ "archaeologists": "archeologists",
88
+ "archaeology": "archeology</span>",
89
+ "archeological": "archaeological",
90
+ "ardour": "ardor",
91
+ "armour": "armor",
92
+ "armoured": "armored",
93
+ "armourer": "armorer",
94
+ "armourers": "armorers",
95
+ "armouries": "armories",
96
+ "armoury": "armory",
97
+ "artefact": "artifact",
98
+ "artefacts": "artifacts",
99
+ "authorise": "authorize",
100
+ "authorised": "authorized",
101
+ "authorises": "authorizes",
102
+ "authorising": "authorizing",
103
+ "axe": "ax",
104
+ "backpedalled": "backpedaled",
105
+ "backpedalling": "backpedaling",
106
+ "bannister": "banister",
107
+ "bannisters": "banisters",
108
+ "baptise": "baptize",
109
+ "baptised": "baptized",
110
+ "baptises": "baptizes",
111
+ "baptising": "baptizing",
112
+ "bastardise": "bastardize",
113
+ "bastardised": "bastardized",
114
+ "bastardises": "bastardizes",
115
+ "bastardising": "bastardizing",
116
+ "battleax": "battleaxe",
117
+ "baulk": "balk",
118
+ "baulked": "balked",
119
+ "baulking": "balking",
120
+ "baulks": "balks",
121
+ "bedevilled": "bedeviled",
122
+ "bedevilling": "bedeviling",
123
+ "behaviour": "behavior",
124
+ "behavioural": "behavioral",
125
+ "behaviourism": "behaviorism",
126
+ "behaviourist": "behaviorist",
127
+ "behaviourists": "behaviorists",
128
+ "behaviours": "behaviors",
129
+ "behove": "behoove",
130
+ "behoved": "behooved",
131
+ "behoves": "behooves",
132
+ "bejewelled": "bejeweled",
133
+ "belabour": "belabor",
134
+ "belaboured": "belabored",
135
+ "belabouring": "belaboring",
136
+ "belabours": "belabors",
137
+ "bevelled": "beveled",
138
+ "bevvies": "bevies",
139
+ "bevvy": "bevy",
140
+ "biassed": "biased",
141
+ "biassing": "biasing",
142
+ "bingeing": "binging",
143
+ "bougainvillaea": "bougainvillea",
144
+ "bougainvillaeas": "bougainvilleas",
145
+ "bowdlerise": "bowdlerize",
146
+ "bowdlerised": "bowdlerized",
147
+ "bowdlerises": "bowdlerizes",
148
+ "bowdlerising": "bowdlerizing",
149
+ "breathalyse": "breathalyze",
150
+ "breathalysed": "breathalyzed",
151
+ "breathalyser": "breathalyzer",
152
+ "breathalysers": "breathalyzers",
153
+ "breathalyses": "breathalyzes",
154
+ "breathalysing": "breathalyzing",
155
+ "brutalise": "brutalize",
156
+ "brutalised": "brutalized",
157
+ "brutalises": "brutalizes",
158
+ "brutalising": "brutalizing",
159
+ "busses": "buses",
160
+ "bussing": "busing",
161
+ "caesarean": "cesarean",
162
+ "caesareans": "cesareans",
163
+ "calibre": "caliber",
164
+ "calibres": "calibers",
165
+ "calliper": "caliper",
166
+ "callipers": "calipers",
167
+ "callisthenics": "calisthenics",
168
+ "canalise": "canalize",
169
+ "canalised": "canalized",
170
+ "canalises": "canalizes",
171
+ "canalising": "canalizing",
172
+ "cancelation": "cancellation",
173
+ "cancelations": "cancellations",
174
+ "cancelled": "canceled",
175
+ "cancelling": "canceling",
176
+ "candour": "candor",
177
+ "cannibalise": "cannibalize",
178
+ "cannibalised": "cannibalized",
179
+ "cannibalises": "cannibalizes",
180
+ "cannibalising": "cannibalizing",
181
+ "canonise": "canonize",
182
+ "canonised": "canonized",
183
+ "canonises": "canonizes",
184
+ "canonising": "canonizing",
185
+ "capitalise": "capitalize",
186
+ "capitalised": "capitalized",
187
+ "capitalises": "capitalizes",
188
+ "capitalising": "capitalizing",
189
+ "caramelise": "caramelize",
190
+ "caramelised": "caramelized",
191
+ "caramelises": "caramelizes",
192
+ "caramelising": "caramelizing",
193
+ "carbonise": "carbonize",
194
+ "carbonised": "carbonized",
195
+ "carbonises": "carbonizes",
196
+ "carbonising": "carbonizing",
197
+ "carolled": "caroled",
198
+ "carolling": "caroling",
199
+ "catalogue": "catalog",
200
+ "catalogued": "cataloged",
201
+ "catalogues": "catalogs",
202
+ "cataloguing": "cataloging",
203
+ "catalyse": "catalyze",
204
+ "catalysed": "catalyzed",
205
+ "catalyses": "catalyzes",
206
+ "catalysing": "catalyzing",
207
+ "categorise": "categorize",
208
+ "categorised": "categorized",
209
+ "categorises": "categorizes",
210
+ "categorising": "categorizing",
211
+ "cauterise": "cauterize",
212
+ "cauterised": "cauterized",
213
+ "cauterises": "cauterizes",
214
+ "cauterising": "cauterizing",
215
+ "cavilled": "caviled",
216
+ "cavilling": "caviling",
217
+ "centigramme": "centigram",
218
+ "centigrammes": "centigrams",
219
+ "centilitre": "centiliter",
220
+ "centilitres": "centiliters",
221
+ "centimetre": "centimeter",
222
+ "centimetres": "centimeters",
223
+ "centralise": "centralize",
224
+ "centralised": "centralized",
225
+ "centralises": "centralizes",
226
+ "centralising": "centralizing",
227
+ "centre": "center",
228
+ "centred": "centered",
229
+ "centrefold": "centerfold",
230
+ "centrefolds": "centerfolds",
231
+ "centrepiece": "centerpiece",
232
+ "centrepieces": "centerpieces",
233
+ "centres": "centers",
234
+ "channelled": "channeled",
235
+ "channelling": "channeling",
236
+ "characterise": "characterize",
237
+ "characterised": "characterized",
238
+ "characterises": "characterizes",
239
+ "characterising": "characterizing",
240
+ "cheque": "check",
241
+ "chequebook": "checkbook",
242
+ "chequebooks": "checkbooks",
243
+ "chequered": "checkered",
244
+ "cheques": "checks",
245
+ "chilli": "chili",
246
+ "chimaera": "chimera",
247
+ "chimaeras": "chimeras",
248
+ "chiselled": "chiseled",
249
+ "chiselling": "chiseling",
250
+ "circularise": "circularize",
251
+ "circularised": "circularized",
252
+ "circularises": "circularizes",
253
+ "circularising": "circularizing",
254
+ "civilise": "civilize",
255
+ "civilised": "civilized",
256
+ "civilises": "civilizes",
257
+ "civilising": "civilizing",
258
+ "clamour": "clamor",
259
+ "clamoured": "clamored",
260
+ "clamouring": "clamoring",
261
+ "clamours": "clamors",
262
+ "clangour": "clangor",
263
+ "clarinettist": "clarinetist",
264
+ "clarinettists": "clarinetists",
265
+ "collectivise": "collectivize",
266
+ "collectivised": "collectivized",
267
+ "collectivises": "collectivizes",
268
+ "collectivising": "collectivizing",
269
+ "colonisation": "colonization",
270
+ "colonise": "colonize",
271
+ "colonised": "colonized",
272
+ "coloniser": "colonizer",
273
+ "colonisers": "colonizers",
274
+ "colonises": "colonizes",
275
+ "colonising": "colonizing",
276
+ "colour": "color",
277
+ "colourant": "colorant",
278
+ "colourants": "colorants",
279
+ "coloured": "colored",
280
+ "coloureds": "coloreds",
281
+ "colourful": "colorful",
282
+ "colourfully": "colorfully",
283
+ "colouring": "coloring",
284
+ "colourize": "colorize",
285
+ "colourized": "colorized",
286
+ "colourizes": "colorizes",
287
+ "colourizing": "colorizing",
288
+ "colourless": "colorless",
289
+ "colours": "colors",
290
+ "commercialise": "commercialize",
291
+ "commercialised": "commercialized",
292
+ "commercialises": "commercializes",
293
+ "commercialising": "commercializing",
294
+ "compartmentalise": "compartmentalize",
295
+ "compartmentalised": "compartmentalized",
296
+ "compartmentalises": "compartmentalizes",
297
+ "compartmentalising": "compartmentalizing",
298
+ "computerise": "computerize",
299
+ "computerised": "computerized",
300
+ "computerises": "computerizes",
301
+ "computerising": "computerizing",
302
+ "conceptualise": "conceptualize",
303
+ "conceptualised": "conceptualized",
304
+ "conceptualises": "conceptualizes",
305
+ "conceptualising": "conceptualizing",
306
+ "connexion": "connection",
307
+ "connexions": "connections",
308
+ "contextualise": "contextualize",
309
+ "contextualised": "contextualized",
310
+ "contextualises": "contextualizes",
311
+ "contextualising": "contextualizing",
312
+ "cosier": "cozier",
313
+ "cosies": "cozies",
314
+ "cosiest": "coziest",
315
+ "cosily": "cozily",
316
+ "cosiness": "coziness",
317
+ "cosy": "cozy",
318
+ "councillor": "councilor",
319
+ "councillors": "councilors",
320
+ "counselled": "counseled",
321
+ "counselling": "counseling",
322
+ "counsellor": "counselor",
323
+ "counsellors": "counselors",
324
+ "crenelated": "crenellated",
325
+ "criminalise": "criminalize",
326
+ "criminalised": "criminalized",
327
+ "criminalises": "criminalizes",
328
+ "criminalising": "criminalizing",
329
+ "criticise": "criticize",
330
+ "criticised": "criticized",
331
+ "criticises": "criticizes",
332
+ "criticising": "criticizing",
333
+ "crueller": "crueler",
334
+ "cruellest": "cruelest",
335
+ "crystallisation": "crystallization",
336
+ "crystallise": "crystallize",
337
+ "crystallised": "crystallized",
338
+ "crystallises": "crystallizes",
339
+ "crystallising": "crystallizing",
340
+ "cudgelled": "cudgeled",
341
+ "cudgelling": "cudgeling",
342
+ "customise": "customize",
343
+ "customised": "customized",
344
+ "customises": "customizes",
345
+ "customising": "customizing",
346
+ "cypher": "cipher",
347
+ "cyphers": "ciphers",
348
+ "decentralisation": "decentralization",
349
+ "decentralise": "decentralize",
350
+ "decentralised": "decentralized",
351
+ "decentralises": "decentralizes",
352
+ "decentralising": "decentralizing",
353
+ "decriminalisation": "decriminalization",
354
+ "decriminalise": "decriminalize",
355
+ "decriminalised": "decriminalized",
356
+ "decriminalises": "decriminalizes",
357
+ "decriminalising": "decriminalizing",
358
+ "defence": "defense",
359
+ "defenceless": "defenseless",
360
+ "defences": "defenses",
361
+ "dehumanisation": "dehumanization",
362
+ "dehumanise": "dehumanize",
363
+ "dehumanised": "dehumanized",
364
+ "dehumanises": "dehumanizes",
365
+ "dehumanising": "dehumanizing",
366
+ "demeanour": "demeanor",
367
+ "demilitarisation": "demilitarization",
368
+ "demilitarise": "demilitarize",
369
+ "demilitarised": "demilitarized",
370
+ "demilitarises": "demilitarizes",
371
+ "demilitarising": "demilitarizing",
372
+ "demobilisation": "demobilization",
373
+ "demobilise": "demobilize",
374
+ "demobilised": "demobilized",
375
+ "demobilises": "demobilizes",
376
+ "demobilising": "demobilizing",
377
+ "democratisation": "democratization",
378
+ "democratise": "democratize",
379
+ "democratised": "democratized",
380
+ "democratises": "democratizes",
381
+ "democratising": "democratizing",
382
+ "demonise": "demonize",
383
+ "demonised": "demonized",
384
+ "demonises": "demonizes",
385
+ "demonising": "demonizing",
386
+ "demoralisation": "demoralization",
387
+ "demoralise": "demoralize",
388
+ "demoralised": "demoralized",
389
+ "demoralises": "demoralizes",
390
+ "demoralising": "demoralizing",
391
+ "denationalisation": "denationalization",
392
+ "denationalise": "denationalize",
393
+ "denationalised": "denationalized",
394
+ "denationalises": "denationalizes",
395
+ "denationalising": "denationalizing",
396
+ "deodorise": "deodorize",
397
+ "deodorised": "deodorized",
398
+ "deodorises": "deodorizes",
399
+ "deodorising": "deodorizing",
400
+ "depersonalise": "depersonalize",
401
+ "depersonalised": "depersonalized",
402
+ "depersonalises": "depersonalizes",
403
+ "depersonalising": "depersonalizing",
404
+ "deputise": "deputize",
405
+ "deputised": "deputized",
406
+ "deputises": "deputizes",
407
+ "deputising": "deputizing",
408
+ "desensitisation": "desensitization",
409
+ "desensitise": "desensitize",
410
+ "desensitised": "desensitized",
411
+ "desensitises": "desensitizes",
412
+ "desensitising": "desensitizing",
413
+ "destabilisation": "destabilization",
414
+ "destabilise": "destabilize",
415
+ "destabilised": "destabilized",
416
+ "destabilises": "destabilizes",
417
+ "destabilising": "destabilizing",
418
+ "dialled": "dialed",
419
+ "dialling": "dialing",
420
+ "dialogue": "dialog",
421
+ "dialogues": "dialogs",
422
+ "diarrhoea": "diarrhea",
423
+ "digitise": "digitize",
424
+ "digitised": "digitized",
425
+ "digitises": "digitizes",
426
+ "digitising": "digitizing",
427
+ "disc": "disk",
428
+ "discolour": "discolor",
429
+ "discoloured": "discolored",
430
+ "discolouring": "discoloring",
431
+ "discolours": "discolors",
432
+ "discs": "disks",
433
+ "disembowelled": "disemboweled",
434
+ "disembowelling": "disemboweling",
435
+ "disfavour": "disfavor",
436
+ "dishevelled": "disheveled",
437
+ "dishonour": "dishonor",
438
+ "dishonourable": "dishonorable",
439
+ "dishonourably": "dishonorably",
440
+ "dishonoured": "dishonored",
441
+ "dishonouring": "dishonoring",
442
+ "dishonours": "dishonors",
443
+ "disorganisation": "disorganization",
444
+ "disorganised": "disorganized",
445
+ "distil": "distill",
446
+ "distils": "distills",
447
+ "dramatisation": "dramatization",
448
+ "dramatisations": "dramatizations",
449
+ "dramatise": "dramatize",
450
+ "dramatised": "dramatized",
451
+ "dramatises": "dramatizes",
452
+ "dramatising": "dramatizing",
453
+ "draught": "draft",
454
+ "draughtboard": "draftboard",
455
+ "draughtboards": "draftboards",
456
+ "draughtier": "draftier",
457
+ "draughtiest": "draftiest",
458
+ "draughts": "drafts",
459
+ "draughtsman": "draftsman",
460
+ "draughtsmanship": "draftsmanship",
461
+ "draughtsmen": "draftsmen",
462
+ "draughtswoman": "draftswoman",
463
+ "draughtswomen": "draftswomen",
464
+ "draughty": "drafty",
465
+ "drivelled": "driveled",
466
+ "drivelling": "driveling",
467
+ "duelled": "dueled",
468
+ "duelling": "dueling",
469
+ "economise": "economize",
470
+ "economised": "economized",
471
+ "economises": "economizes",
472
+ "economising": "economizing",
473
+ "editorialise": "editorialize",
474
+ "editorialised": "editorialized",
475
+ "editorialises": "editorializes",
476
+ "editorialising": "editorializing",
477
+ "edoema": "edema",
478
+ "empathise": "empathize",
479
+ "empathised": "empathized",
480
+ "empathises": "empathizes",
481
+ "empathising": "empathizing",
482
+ "emphasise": "emphasize",
483
+ "emphasised": "emphasized",
484
+ "emphasises": "emphasizes",
485
+ "emphasising": "emphasizing",
486
+ "enamelled": "enameled",
487
+ "enamelling": "enameling",
488
+ "enamoured": "enamored",
489
+ "encyclopaedia": "encyclopedia",
490
+ "encyclopaedias": "encyclopedias",
491
+ "encyclopaedic": "encyclopedic",
492
+ "endeavour": "endeavor",
493
+ "endeavoured": "endeavored",
494
+ "endeavouring": "endeavoring",
495
+ "endeavours": "endeavors",
496
+ "energise": "energize",
497
+ "energised": "energized",
498
+ "energises": "energizes",
499
+ "energising": "energizing",
500
+ "enrol": "enroll",
501
+ "enrols": "enrolls",
502
+ "enthral": "enthrall",
503
+ "enthrals": "enthralls",
504
+ "epaulette": "epaulet",
505
+ "epaulettes": "epaulets",
506
+ "epicentre": "epicenter",
507
+ "epicentres": "epicenters",
508
+ "epilogue": "epilog",
509
+ "epilogues": "epilogs",
510
+ "epitomise": "epitomize",
511
+ "epitomised": "epitomized",
512
+ "epitomises": "epitomizes",
513
+ "epitomising": "epitomizing",
514
+ "equalisation": "equalization",
515
+ "equalise": "equalize",
516
+ "equalised": "equalized",
517
+ "equaliser": "equalizer",
518
+ "equalisers": "equalizers",
519
+ "equalises": "equalizes",
520
+ "equalising": "equalizing",
521
+ "eulogise": "eulogize",
522
+ "eulogised": "eulogized",
523
+ "eulogises": "eulogizes",
524
+ "eulogising": "eulogizing",
525
+ "evangelise": "evangelize",
526
+ "evangelised": "evangelized",
527
+ "evangelises": "evangelizes",
528
+ "evangelising": "evangelizing",
529
+ "exorcise": "exorcize",
530
+ "exorcised": "exorcized",
531
+ "exorcises": "exorcizes",
532
+ "exorcising": "exorcizing",
533
+ "extemporisation": "extemporization",
534
+ "extemporise": "extemporize",
535
+ "extemporised": "extemporized",
536
+ "extemporises": "extemporizes",
537
+ "extemporising": "extemporizing",
538
+ "externalisation": "externalization",
539
+ "externalisations": "externalizations",
540
+ "externalise": "externalize",
541
+ "externalised": "externalized",
542
+ "externalises": "externalizes",
543
+ "externalising": "externalizing",
544
+ "factorise": "factorize",
545
+ "factorised": "factorized",
546
+ "factorises": "factorizes",
547
+ "factorising": "factorizing",
548
+ "faecal": "fecal",
549
+ "faeces": "feces",
550
+ "familiarisation": "familiarization",
551
+ "familiarise": "familiarize",
552
+ "familiarised": "familiarized",
553
+ "familiarises": "familiarizes",
554
+ "familiarising": "familiarizing",
555
+ "fantasise": "fantasize",
556
+ "fantasised": "fantasized",
557
+ "fantasises": "fantasizes",
558
+ "fantasising": "fantasizing",
559
+ "favour": "favor",
560
+ "favourable": "favorable",
561
+ "favourably": "favorably",
562
+ "favoured": "favored",
563
+ "favouring": "favoring",
564
+ "favourite": "favorite",
565
+ "favourites": "favorites",
566
+ "favouritism": "favoritism",
567
+ "favours": "favors",
568
+ "feminise": "feminize",
569
+ "feminised": "feminized",
570
+ "feminises": "feminizes",
571
+ "feminising": "feminizing",
572
+ "fertilisation": "fertilization",
573
+ "fertilise": "fertilize",
574
+ "fertilised": "fertilized",
575
+ "fertiliser": "fertilizer",
576
+ "fertilisers": "fertilizers",
577
+ "fertilises": "fertilizes",
578
+ "fertilising": "fertilizing",
579
+ "fervour": "fervor",
580
+ "fibre": "fiber",
581
+ "fibreglass": "fiberglass",
582
+ "fibres": "fibers",
583
+ "fictionalisation": "fictionalization",
584
+ "fictionalisations": "fictionalizations",
585
+ "fictionalise": "fictionalize",
586
+ "fictionalised": "fictionalized",
587
+ "fictionalises": "fictionalizes",
588
+ "fictionalising": "fictionalizing",
589
+ "fillet": "filet",
590
+ "filleted": "fileted",
591
+ "filleting": "fileting",
592
+ "fillets": "filets",
593
+ "finalisation": "finalization",
594
+ "finalise": "finalize",
595
+ "finalised": "finalized",
596
+ "finalises": "finalizes",
597
+ "finalising": "finalizing",
598
+ "flautist": "flutist",
599
+ "flautists": "flutists",
600
+ "flavour": "flavor",
601
+ "flavoured": "flavored",
602
+ "flavouring": "flavoring",
603
+ "flavourings": "flavorings",
604
+ "flavourless": "flavorless",
605
+ "flavours": "flavors",
606
+ "flavoursome": "flavorsome",
607
+ "flyer / flier": "flier / flyer",
608
+ "foetal": "fetal",
609
+ "foetid": "fetid",
610
+ "foetus": "fetus",
611
+ "foetuses": "fetuses",
612
+ "formalisation": "formalization",
613
+ "formalise": "formalize",
614
+ "formalised": "formalized",
615
+ "formalises": "formalizes",
616
+ "formalising": "formalizing",
617
+ "fossilisation": "fossilization",
618
+ "fossilise": "fossilize",
619
+ "fossilised": "fossilized",
620
+ "fossilises": "fossilizes",
621
+ "fossilising": "fossilizing",
622
+ "fraternisation": "fraternization",
623
+ "fraternise": "fraternize",
624
+ "fraternised": "fraternized",
625
+ "fraternises": "fraternizes",
626
+ "fraternising": "fraternizing",
627
+ "fulfil": "fulfill",
628
+ "fulfilment": "fulfillment",
629
+ "fulfils": "fulfills",
630
+ "funnelled": "funneled",
631
+ "funnelling": "funneling",
632
+ "gage": "gauge",
633
+ "gaged": "gauged",
634
+ "gages": "gauges",
635
+ "gaging": "gauging",
636
+ "galvanise": "galvanize",
637
+ "galvanised": "galvanized",
638
+ "galvanises": "galvanizes",
639
+ "galvanising": "galvanizing",
640
+ "gambolled": "gamboled",
641
+ "gambolling": "gamboling",
642
+ "gaol": "jail",
643
+ "gaolbird": "jailbird",
644
+ "gaolbirds": "jailbirds",
645
+ "gaolbreak": "jailbreak",
646
+ "gaolbreaks": "jailbreaks",
647
+ "gaoled": "jailed",
648
+ "gaoler": "jailer",
649
+ "gaolers": "jailers",
650
+ "gaoling": "jailing",
651
+ "gaols": "jails",
652
+ "gasses": "gases",
653
+ "generalisation": "generalization",
654
+ "generalisations": "generalizations",
655
+ "generalise": "generalize",
656
+ "generalised": "generalized",
657
+ "generalises": "generalizes",
658
+ "generalising": "generalizing",
659
+ "ghettoise": "ghettoize",
660
+ "ghettoised": "ghettoized",
661
+ "ghettoises": "ghettoizes",
662
+ "ghettoising": "ghettoizing",
663
+ "gipsies": "gypsies",
664
+ "glamor": "glamour",
665
+ "glamorise": "glamorize",
666
+ "glamorised": "glamorized",
667
+ "glamorises": "glamorizes",
668
+ "glamorising": "glamorizing",
669
+ "globalisation": "globalization",
670
+ "globalise": "globalize",
671
+ "globalised": "globalized",
672
+ "globalises": "globalizes",
673
+ "globalising": "globalizing",
674
+ "glueing": "gluing",
675
+ "goitre": "goiter",
676
+ "goitres": "goiters",
677
+ "gonorrhoea": "gonorrhea",
678
+ "gramme": "gram",
679
+ "grammes": "grams",
680
+ "gravelled": "graveled",
681
+ "grey": "gray",
682
+ "greyed": "grayed",
683
+ "greying": "graying",
684
+ "greyish": "grayish",
685
+ "greyness": "grayness",
686
+ "greys": "grays",
687
+ "grovelled": "groveled",
688
+ "grovelling": "groveling",
689
+ "groyne": "groin",
690
+ "groynes": "groins",
691
+ "gruelling": "grueling",
692
+ "gruellingly": "gruelingly",
693
+ "gryphon": "griffin",
694
+ "gryphons": "griffins",
695
+ "gynaecological": "gynecological",
696
+ "gynaecologist": "gynecologist",
697
+ "gynaecologists": "gynecologists",
698
+ "gynaecology": "gynecology",
699
+ "haematological": "hematological",
700
+ "haematologist": "hematologist",
701
+ "haematologists": "hematologists",
702
+ "haematology": "hematology",
703
+ "haemoglobin": "hemoglobin",
704
+ "haemophilia": "hemophilia",
705
+ "haemophiliac": "hemophiliac",
706
+ "haemophiliacs": "hemophiliacs",
707
+ "haemorrhage": "hemorrhage",
708
+ "haemorrhaged": "hemorrhaged",
709
+ "haemorrhages": "hemorrhages",
710
+ "haemorrhaging": "hemorrhaging",
711
+ "haemorrhoids": "hemorrhoids",
712
+ "harbour": "harbor",
713
+ "harboured": "harbored",
714
+ "harbouring": "harboring",
715
+ "harbours": "harbors",
716
+ "harmonisation": "harmonization",
717
+ "harmonise": "harmonize",
718
+ "harmonised": "harmonized",
719
+ "harmonises": "harmonizes",
720
+ "harmonising": "harmonizing",
721
+ "homoeopath": "homeopath",
722
+ "homoeopathic": "homeopathic",
723
+ "homoeopaths": "homeopaths",
724
+ "homoeopathy": "homeopathy",
725
+ "homogenise": "homogenize",
726
+ "homogenised": "homogenized",
727
+ "homogenises": "homogenizes",
728
+ "homogenising": "homogenizing",
729
+ "honour": "honor",
730
+ "honourable": "honorable",
731
+ "honourably": "honorably",
732
+ "honoured": "honored",
733
+ "honouring": "honoring",
734
+ "honours": "honors",
735
+ "hospitalisation": "hospitalization",
736
+ "hospitalise": "hospitalize",
737
+ "hospitalised": "hospitalized",
738
+ "hospitalises": "hospitalizes",
739
+ "hospitalising": "hospitalizing",
740
+ "humanise": "humanize",
741
+ "humanised": "humanized",
742
+ "humanises": "humanizes",
743
+ "humanising": "humanizing",
744
+ "humour": "humor",
745
+ "humoured": "humored",
746
+ "humouring": "humoring",
747
+ "humourless": "humorless",
748
+ "humours": "humors",
749
+ "hybridise": "hybridize",
750
+ "hybridised": "hybridized",
751
+ "hybridises": "hybridizes",
752
+ "hybridising": "hybridizing",
753
+ "hypnotise": "hypnotize",
754
+ "hypnotised": "hypnotized",
755
+ "hypnotises": "hypnotizes",
756
+ "hypnotising": "hypnotizing",
757
+ "hypothesise": "hypothesize",
758
+ "hypothesised": "hypothesized",
759
+ "hypothesises": "hypothesizes",
760
+ "hypothesising": "hypothesizing",
761
+ "idealisation": "idealization",
762
+ "idealise": "idealize",
763
+ "idealised": "idealized",
764
+ "idealises": "idealizes",
765
+ "idealising": "idealizing",
766
+ "idolise": "idolize",
767
+ "idolised": "idolized",
768
+ "idolises": "idolizes",
769
+ "idolising": "idolizing",
770
+ "immobilisation": "immobilization",
771
+ "immobilise": "immobilize",
772
+ "immobilised": "immobilized",
773
+ "immobiliser": "immobilizer",
774
+ "immobilisers": "immobilizers",
775
+ "immobilises": "immobilizes",
776
+ "immobilising": "immobilizing",
777
+ "immortalise": "immortalize",
778
+ "immortalised": "immortalized",
779
+ "immortalises": "immortalizes",
780
+ "immortalising": "immortalizing",
781
+ "immunisation": "immunization",
782
+ "immunise": "immunize",
783
+ "immunised": "immunized",
784
+ "immunises": "immunizes",
785
+ "immunising": "immunizing",
786
+ "impanelled": "impaneled",
787
+ "impanelling": "impaneling",
788
+ "imperilled": "imperiled",
789
+ "imperilling": "imperiling",
790
+ "individualise": "individualize",
791
+ "individualised": "individualized",
792
+ "individualises": "individualizes",
793
+ "individualising": "individualizing",
794
+ "industrialise": "industrialize",
795
+ "industrialised": "industrialized",
796
+ "industrialises": "industrializes",
797
+ "industrialising": "industrializing",
798
+ "inflexion": "inflection",
799
+ "inflexions": "inflections",
800
+ "initialise": "initialize",
801
+ "initialised": "initialized",
802
+ "initialises": "initializes",
803
+ "initialising": "initializing",
804
+ "initialled": "initialed",
805
+ "initialling": "initialing",
806
+ "instal": "install",
807
+ "instalment": "installment",
808
+ "instalments": "installments",
809
+ "instals": "installs",
810
+ "instil": "instill",
811
+ "instils": "instills",
812
+ "institutionalisation": "institutionalization",
813
+ "institutionalise": "institutionalize",
814
+ "institutionalised": "institutionalized",
815
+ "institutionalises": "institutionalizes",
816
+ "institutionalising": "institutionalizing",
817
+ "intellectualise": "intellectualize",
818
+ "intellectualised": "intellectualized",
819
+ "intellectualises": "intellectualizes",
820
+ "intellectualising": "intellectualizing",
821
+ "internalisation": "internalization",
822
+ "internalise": "internalize",
823
+ "internalised": "internalized",
824
+ "internalises": "internalizes",
825
+ "internalising": "internalizing",
826
+ "internationalisation": "internationalization",
827
+ "internationalise": "internationalize",
828
+ "internationalised": "internationalized",
829
+ "internationalises": "internationalizes",
830
+ "internationalising": "internationalizing",
831
+ "ionisation": "ionization",
832
+ "ionise": "ionize",
833
+ "ionised": "ionized",
834
+ "ioniser": "ionizer",
835
+ "ionisers": "ionizers",
836
+ "ionises": "ionizes",
837
+ "ionising": "ionizing",
838
+ "italicise": "italicize",
839
+ "italicised": "italicized",
840
+ "italicises": "italicizes",
841
+ "italicising": "italicizing",
842
+ "itemise": "itemize",
843
+ "itemised": "itemized",
844
+ "itemises": "itemizes",
845
+ "itemising": "itemizing",
846
+ "jeopardise": "jeopardize",
847
+ "jeopardised": "jeopardized",
848
+ "jeopardises": "jeopardizes",
849
+ "jeopardising": "jeopardizing",
850
+ "jewelled": "jeweled",
851
+ "jeweller": "jeweler",
852
+ "jewellers": "jewelers",
853
+ "jewellery": "jewelry",
854
+ "judgement": "judgment",
855
+ "kilogramme": "kilogram",
856
+ "kilogrammes": "kilograms",
857
+ "kilometre": "kilometer",
858
+ "kilometres": "kilometers",
859
+ "labelled": "labeled",
860
+ "labelling": "labeling",
861
+ "labour": "labor",
862
+ "laboured": "labored",
863
+ "labourer": "laborer",
864
+ "labourers": "laborers",
865
+ "labouring": "laboring",
866
+ "labours": "labors",
867
+ "lacklustre": "lackluster",
868
+ "legalisation": "legalization",
869
+ "legalise": "legalize",
870
+ "legalised": "legalized",
871
+ "legalises": "legalizes",
872
+ "legalising": "legalizing",
873
+ "legitimise": "legitimize",
874
+ "legitimised": "legitimized",
875
+ "legitimises": "legitimizes",
876
+ "legitimising": "legitimizing",
877
+ "leukaemia": "leukemia",
878
+ "levelled": "leveled",
879
+ "leveller": "leveler",
880
+ "levellers": "levelers",
881
+ "levelling": "leveling",
882
+ "libelled": "libeled",
883
+ "libelling": "libeling",
884
+ "libellous": "libelous",
885
+ "liberalisation": "liberalization",
886
+ "liberalise": "liberalize",
887
+ "liberalised": "liberalized",
888
+ "liberalises": "liberalizes",
889
+ "liberalising": "liberalizing",
890
+ "licence": "license",
891
+ "licenced": "licensed",
892
+ "licences": "licenses",
893
+ "licencing": "licensing",
894
+ "likeable": "likable",
895
+ "lionisation": "lionization",
896
+ "lionise": "lionize",
897
+ "lionised": "lionized",
898
+ "lionises": "lionizes",
899
+ "lionising": "lionizing",
900
+ "liquidise": "liquidize",
901
+ "liquidised": "liquidized",
902
+ "liquidiser": "liquidizer",
903
+ "liquidisers": "liquidizers",
904
+ "liquidises": "liquidizes",
905
+ "liquidising": "liquidizing",
906
+ "litre": "liter",
907
+ "litres": "liters",
908
+ "localise": "localize",
909
+ "localised": "localized",
910
+ "localises": "localizes",
911
+ "localising": "localizing",
912
+ "louvre": "louver",
913
+ "louvred": "louvered",
914
+ "louvres": "louvers",
915
+ "lustre": "luster",
916
+ "magnetise": "magnetize",
917
+ "magnetised": "magnetized",
918
+ "magnetises": "magnetizes",
919
+ "magnetising": "magnetizing",
920
+ "manoeuvrability": "maneuverability",
921
+ "manoeuvrable": "maneuverable",
922
+ "manoeuvre": "maneuver",
923
+ "manoeuvred": "maneuvered",
924
+ "manoeuvres": "maneuvers",
925
+ "manoeuvring": "maneuvering",
926
+ "manoeuvrings": "maneuverings",
927
+ "marginalisation": "marginalization",
928
+ "marginalise": "marginalize",
929
+ "marginalised": "marginalized",
930
+ "marginalises": "marginalizes",
931
+ "marginalising": "marginalizing",
932
+ "marshalled": "marshaled",
933
+ "marshalling": "marshaling",
934
+ "marvelled": "marveled",
935
+ "marvelling": "marveling",
936
+ "marvellous": "marvelous",
937
+ "marvellously": "marvelously",
938
+ "materialisation": "materialization",
939
+ "materialise": "materialize",
940
+ "materialised": "materialized",
941
+ "materialises": "materializes",
942
+ "materialising": "materializing",
943
+ "maximisation": "maximization",
944
+ "maximise": "maximize",
945
+ "maximised": "maximized",
946
+ "maximises": "maximizes",
947
+ "maximising": "maximizing",
948
+ "meagre": "meager",
949
+ "mechanisation": "mechanization",
950
+ "mechanise": "mechanize",
951
+ "mechanised": "mechanized",
952
+ "mechanises": "mechanizes",
953
+ "mechanising": "mechanizing",
954
+ "mediaeval": "medieval",
955
+ "memorialise": "memorialize",
956
+ "memorialised": "memorialized",
957
+ "memorialises": "memorializes",
958
+ "memorialising": "memorializing",
959
+ "memorise": "memorize",
960
+ "memorised": "memorized",
961
+ "memorises": "memorizes",
962
+ "memorising": "memorizing",
963
+ "mesmerise": "mesmerize",
964
+ "mesmerised": "mesmerized",
965
+ "mesmerises": "mesmerizes",
966
+ "mesmerising": "mesmerizing",
967
+ "metabolise": "metabolize",
968
+ "metabolised": "metabolized",
969
+ "metabolises": "metabolizes",
970
+ "metabolising": "metabolizing",
971
+ "metre": "meter",
972
+ "metres": "meters",
973
+ "mhm": "hmm",
974
+ "micrometre": "micrometer",
975
+ "micrometres": "micrometers",
976
+ "militarise": "militarize",
977
+ "militarised": "militarized",
978
+ "militarises": "militarizes",
979
+ "militarising": "militarizing",
980
+ "milligramme": "milligram",
981
+ "milligrammes": "milligrams",
982
+ "millilitre": "milliliter",
983
+ "millilitres": "milliliters",
984
+ "millimetre": "millimeter",
985
+ "millimetres": "millimeters",
986
+ "miniaturisation": "miniaturization",
987
+ "miniaturise": "miniaturize",
988
+ "miniaturised": "miniaturized",
989
+ "miniaturises": "miniaturizes",
990
+ "miniaturising": "miniaturizing",
991
+ "minibusses": "minibuses",
992
+ "minimise": "minimize",
993
+ "minimised": "minimized",
994
+ "minimises": "minimizes",
995
+ "minimising": "minimizing",
996
+ "misbehaviour": "misbehavior",
997
+ "misdemeanour": "misdemeanor",
998
+ "misdemeanours": "misdemeanors",
999
+ "misspelt": "misspelled",
1000
+ "mitre": "miter",
1001
+ "mitres": "miters",
1002
+ "mm": "hmm",
1003
+ "mmm": "hmm",
1004
+ "mobilisation": "mobilization",
1005
+ "mobilise": "mobilize",
1006
+ "mobilised": "mobilized",
1007
+ "mobilises": "mobilizes",
1008
+ "mobilising": "mobilizing",
1009
+ "modelled": "modeled",
1010
+ "modeller": "modeler",
1011
+ "modellers": "modelers",
1012
+ "modelling": "modeling",
1013
+ "modernise": "modernize",
1014
+ "modernised": "modernized",
1015
+ "modernises": "modernizes",
1016
+ "modernising": "modernizing",
1017
+ "moisturise": "moisturize",
1018
+ "moisturised": "moisturized",
1019
+ "moisturiser": "moisturizer",
1020
+ "moisturisers": "moisturizers",
1021
+ "moisturises": "moisturizes",
1022
+ "moisturising": "moisturizing",
1023
+ "monologue": "monolog",
1024
+ "monologues": "monologs",
1025
+ "monopolisation": "monopolization",
1026
+ "monopolise": "monopolize",
1027
+ "monopolised": "monopolized",
1028
+ "monopolises": "monopolizes",
1029
+ "monopolising": "monopolizing",
1030
+ "moralise": "moralize",
1031
+ "moralised": "moralized",
1032
+ "moralises": "moralizes",
1033
+ "moralising": "moralizing",
1034
+ "motorised": "motorized",
1035
+ "mould": "mold",
1036
+ "moulded": "molded",
1037
+ "moulder": "molder",
1038
+ "mouldered": "moldered",
1039
+ "mouldering": "moldering",
1040
+ "moulders": "molders",
1041
+ "mouldier": "moldier",
1042
+ "mouldiest": "moldiest",
1043
+ "moulding": "molding",
1044
+ "mouldings": "moldings",
1045
+ "moulds": "molds",
1046
+ "mouldy": "moldy",
1047
+ "moult": "molt",
1048
+ "moulted": "molted",
1049
+ "moulting": "molting",
1050
+ "moults": "molts",
1051
+ "moustache": "mustache",
1052
+ "moustached": "mustached",
1053
+ "moustaches": "mustaches",
1054
+ "moustachioed": "mustachioed",
1055
+ "multicoloured": "multicolored",
1056
+ "nationalisation": "nationalization",
1057
+ "nationalisations": "nationalizations",
1058
+ "nationalise": "nationalize",
1059
+ "nationalised": "nationalized",
1060
+ "nationalises": "nationalizes",
1061
+ "nationalising": "nationalizing",
1062
+ "naturalisation": "naturalization",
1063
+ "naturalise": "naturalize",
1064
+ "naturalised": "naturalized",
1065
+ "naturalises": "naturalizes",
1066
+ "naturalising": "naturalizing",
1067
+ "neighbour": "neighbor",
1068
+ "neighbourhood": "neighborhood",
1069
+ "neighbourhoods": "neighborhoods",
1070
+ "neighbouring": "neighboring",
1071
+ "neighbourliness": "neighborliness",
1072
+ "neighbourly": "neighborly",
1073
+ "neighbours": "neighbors",
1074
+ "neutralisation": "neutralization",
1075
+ "neutralise": "neutralize",
1076
+ "neutralised": "neutralized",
1077
+ "neutralises": "neutralizes",
1078
+ "neutralising": "neutralizing",
1079
+ "normalisation": "normalization",
1080
+ "normalise": "normalize",
1081
+ "normalised": "normalized",
1082
+ "normalises": "normalizes",
1083
+ "normalising": "normalizing",
1084
+ "odour": "odor",
1085
+ "odourless": "odorless",
1086
+ "odours": "odors",
1087
+ "oesophagus": "esophagus",
1088
+ "oesophaguses": "esophaguses",
1089
+ "oestrogen": "estrogen",
1090
+ "offence": "offense",
1091
+ "offences": "offenses",
1092
+ "omelette": "omelet",
1093
+ "omelettes": "omelets",
1094
+ "optimise": "optimize",
1095
+ "optimised": "optimized",
1096
+ "optimises": "optimizes",
1097
+ "optimising": "optimizing",
1098
+ "organisation": "organization",
1099
+ "organisational": "organizational",
1100
+ "organisations": "organizations",
1101
+ "organise": "organize",
1102
+ "organised": "organized",
1103
+ "organiser": "organizer",
1104
+ "organisers": "organizers",
1105
+ "organises": "organizes",
1106
+ "organising": "organizing",
1107
+ "orthopaedic": "orthopedic",
1108
+ "orthopaedics": "orthopedics",
1109
+ "ostracise": "ostracize",
1110
+ "ostracised": "ostracized",
1111
+ "ostracises": "ostracizes",
1112
+ "ostracising": "ostracizing",
1113
+ "outmanoeuvre": "outmaneuver",
1114
+ "outmanoeuvred": "outmaneuvered",
1115
+ "outmanoeuvres": "outmaneuvers",
1116
+ "outmanoeuvring": "outmaneuvering",
1117
+ "overemphasise": "overemphasize",
1118
+ "overemphasised": "overemphasized",
1119
+ "overemphasises": "overemphasizes",
1120
+ "overemphasising": "overemphasizing",
1121
+ "oxidisation": "oxidization",
1122
+ "oxidise": "oxidize",
1123
+ "oxidised": "oxidized",
1124
+ "oxidises": "oxidizes",
1125
+ "oxidising": "oxidizing",
1126
+ "paederast": "pederast",
1127
+ "paederasts": "pederasts",
1128
+ "paediatric": "pediatric",
1129
+ "paediatrician": "pediatrician",
1130
+ "paediatricians": "pediatricians",
1131
+ "paediatrics": "pediatrics",
1132
+ "paedophile": "pedophile",
1133
+ "paedophiles": "pedophiles",
1134
+ "paedophilia": "pedophilia",
1135
+ "palaeolithic": "paleolithic",
1136
+ "palaeontologist": "paleontologist",
1137
+ "palaeontologists": "paleontologists",
1138
+ "palaeontology": "paleontology",
1139
+ "panelled": "paneled",
1140
+ "panelling": "paneling",
1141
+ "panellist": "panelist",
1142
+ "panellists": "panelists",
1143
+ "paralyse": "paralyze",
1144
+ "paralysed": "paralyzed",
1145
+ "paralyses": "paralyzes",
1146
+ "paralysing": "paralyzing",
1147
+ "parcelled": "parceled",
1148
+ "parcelling": "parceling",
1149
+ "parlour": "parlor",
1150
+ "parlours": "parlors",
1151
+ "particularise": "particularize",
1152
+ "particularised": "particularized",
1153
+ "particularises": "particularizes",
1154
+ "particularising": "particularizing",
1155
+ "passivisation": "passivization",
1156
+ "passivise": "passivize",
1157
+ "passivised": "passivized",
1158
+ "passivises": "passivizes",
1159
+ "passivising": "passivizing",
1160
+ "pasteurisation": "pasteurization",
1161
+ "pasteurise": "pasteurize",
1162
+ "pasteurised": "pasteurized",
1163
+ "pasteurises": "pasteurizes",
1164
+ "pasteurising": "pasteurizing",
1165
+ "patronise": "patronize",
1166
+ "patronised": "patronized",
1167
+ "patronises": "patronizes",
1168
+ "patronising": "patronizing",
1169
+ "patronisingly": "patronizingly",
1170
+ "pedalled": "pedaled",
1171
+ "pedalling": "pedaling",
1172
+ "pedestrianisation": "pedestrianization",
1173
+ "pedestrianise": "pedestrianize",
1174
+ "pedestrianised": "pedestrianized",
1175
+ "pedestrianises": "pedestrianizes",
1176
+ "pedestrianising": "pedestrianizing",
1177
+ "penalise": "penalize",
1178
+ "penalised": "penalized",
1179
+ "penalises": "penalizes",
1180
+ "penalising": "penalizing",
1181
+ "pencilled": "penciled",
1182
+ "pencilling": "penciling",
1183
+ "personalise": "personalize",
1184
+ "personalised": "personalized",
1185
+ "personalises": "personalizes",
1186
+ "personalising": "personalizing",
1187
+ "pharmacopoeia": "pharmacopeia",
1188
+ "pharmacopoeias": "pharmacopeias",
1189
+ "philosophise": "philosophize",
1190
+ "philosophised": "philosophized",
1191
+ "philosophises": "philosophizes",
1192
+ "philosophising": "philosophizing",
1193
+ "philtre": "filter",
1194
+ "philtres": "filters",
1195
+ "phoney": "phony",
1196
+ "plagiarise": "plagiarize",
1197
+ "plagiarised": "plagiarized",
1198
+ "plagiarises": "plagiarizes",
1199
+ "plagiarising": "plagiarizing",
1200
+ "plough": "plow",
1201
+ "ploughed": "plowed",
1202
+ "ploughing": "plowing",
1203
+ "ploughman": "plowman",
1204
+ "ploughmen": "plowmen",
1205
+ "ploughs": "plows",
1206
+ "ploughshare": "plowshare",
1207
+ "ploughshares": "plowshares",
1208
+ "polarisation": "polarization",
1209
+ "polarise": "polarize",
1210
+ "polarised": "polarized",
1211
+ "polarises": "polarizes",
1212
+ "polarising": "polarizing",
1213
+ "politicisation": "politicization",
1214
+ "politicise": "politicize",
1215
+ "politicised": "politicized",
1216
+ "politicises": "politicizes",
1217
+ "politicising": "politicizing",
1218
+ "popularisation": "popularization",
1219
+ "popularise": "popularize",
1220
+ "popularised": "popularized",
1221
+ "popularises": "popularizes",
1222
+ "popularising": "popularizing",
1223
+ "pouffe": "pouf",
1224
+ "pouffes": "poufs",
1225
+ "practise": "practice",
1226
+ "practised": "practiced",
1227
+ "practises": "practices",
1228
+ "practising": "practicing",
1229
+ "praesidium": "presidium",
1230
+ "praesidiums": "presidiums",
1231
+ "pressurisation": "pressurization",
1232
+ "pressurise": "pressurize",
1233
+ "pressurised": "pressurized",
1234
+ "pressurises": "pressurizes",
1235
+ "pressurising": "pressurizing",
1236
+ "pretence": "pretense",
1237
+ "pretences": "pretenses",
1238
+ "primaeval": "primeval",
1239
+ "prioritisation": "prioritization",
1240
+ "prioritise": "prioritize",
1241
+ "prioritised": "prioritized",
1242
+ "prioritises": "prioritizes",
1243
+ "prioritising": "prioritizing",
1244
+ "privatisation": "privatization",
1245
+ "privatisations": "privatizations",
1246
+ "privatise": "privatize",
1247
+ "privatised": "privatized",
1248
+ "privatises": "privatizes",
1249
+ "privatising": "privatizing",
1250
+ "professionalisation": "professionalization",
1251
+ "professionalise": "professionalize",
1252
+ "professionalised": "professionalized",
1253
+ "professionalises": "professionalizes",
1254
+ "professionalising": "professionalizing",
1255
+ "programme": "program",
1256
+ "programmes": "programs",
1257
+ "prologue": "prolog",
1258
+ "prologues": "prologs",
1259
+ "propagandise": "propagandize",
1260
+ "propagandised": "propagandized",
1261
+ "propagandises": "propagandizes",
1262
+ "propagandising": "propagandizing",
1263
+ "proselytise": "proselytize",
1264
+ "proselytised": "proselytized",
1265
+ "proselytiser": "proselytizer",
1266
+ "proselytisers": "proselytizers",
1267
+ "proselytises": "proselytizes",
1268
+ "proselytising": "proselytizing",
1269
+ "psychoanalyse": "psychoanalyze",
1270
+ "psychoanalysed": "psychoanalyzed",
1271
+ "psychoanalyses": "psychoanalyzes",
1272
+ "psychoanalysing": "psychoanalyzing",
1273
+ "publicise": "publicize",
1274
+ "publicised": "publicized",
1275
+ "publicises": "publicizes",
1276
+ "publicising": "publicizing",
1277
+ "pulverisation": "pulverization",
1278
+ "pulverise": "pulverize",
1279
+ "pulverised": "pulverized",
1280
+ "pulverises": "pulverizes",
1281
+ "pulverising": "pulverizing",
1282
+ "pummelled": "pummel",
1283
+ "pummelling": "pummeled",
1284
+ "pyjama": "pajama",
1285
+ "pyjamas": "pajamas",
1286
+ "pzazz": "pizzazz",
1287
+ "quarrelled": "quarreled",
1288
+ "quarrelling": "quarreling",
1289
+ "radicalise": "radicalize",
1290
+ "radicalised": "radicalized",
1291
+ "radicalises": "radicalizes",
1292
+ "radicalising": "radicalizing",
1293
+ "rancour": "rancor",
1294
+ "randomise": "randomize",
1295
+ "randomised": "randomized",
1296
+ "randomises": "randomizes",
1297
+ "randomising": "randomizing",
1298
+ "rationalisation": "rationalization",
1299
+ "rationalisations": "rationalizations",
1300
+ "rationalise": "rationalize",
1301
+ "rationalised": "rationalized",
1302
+ "rationalises": "rationalizes",
1303
+ "rationalising": "rationalizing",
1304
+ "ravelled": "raveled",
1305
+ "ravelling": "raveling",
1306
+ "realisable": "realizable",
1307
+ "realisation": "realization",
1308
+ "realisations": "realizations",
1309
+ "realise": "realize",
1310
+ "realised": "realized",
1311
+ "realises": "realizes",
1312
+ "realising": "realizing",
1313
+ "recognisable": "recognizable",
1314
+ "recognisably": "recognizably",
1315
+ "recognisance": "recognizance",
1316
+ "recognise": "recognize",
1317
+ "recognised": "recognized",
1318
+ "recognises": "recognizes",
1319
+ "recognising": "recognizing",
1320
+ "reconnoitre": "reconnoiter",
1321
+ "reconnoitred": "reconnoitered",
1322
+ "reconnoitres": "reconnoiters",
1323
+ "reconnoitring": "reconnoitering",
1324
+ "refuelled": "refueled",
1325
+ "refuelling": "refueling",
1326
+ "regularisation": "regularization",
1327
+ "regularise": "regularize",
1328
+ "regularised": "regularized",
1329
+ "regularises": "regularizes",
1330
+ "regularising": "regularizing",
1331
+ "remodelled": "remodeled",
1332
+ "remodelling": "remodeling",
1333
+ "remould": "remold",
1334
+ "remoulded": "remolded",
1335
+ "remoulding": "remolding",
1336
+ "remoulds": "remolds",
1337
+ "reorganisation": "reorganization",
1338
+ "reorganisations": "reorganizations",
1339
+ "reorganise": "reorganize",
1340
+ "reorganised": "reorganized",
1341
+ "reorganises": "reorganizes",
1342
+ "reorganising": "reorganizing",
1343
+ "revelled": "reveled",
1344
+ "reveller": "reveler",
1345
+ "revellers": "revelers",
1346
+ "revelling": "reveling",
1347
+ "revitalise": "revitalize",
1348
+ "revitalised": "revitalized",
1349
+ "revitalises": "revitalizes",
1350
+ "revitalising": "revitalizing",
1351
+ "revolutionise": "revolutionize",
1352
+ "revolutionised": "revolutionized",
1353
+ "revolutionises": "revolutionizes",
1354
+ "revolutionising": "revolutionizing",
1355
+ "rhapsodise": "rhapsodize",
1356
+ "rhapsodised": "rhapsodized",
1357
+ "rhapsodises": "rhapsodizes",
1358
+ "rhapsodising": "rhapsodizing",
1359
+ "rigour": "rigor",
1360
+ "rigours": "rigors",
1361
+ "ritualised": "ritualized",
1362
+ "rivalled": "rivaled",
1363
+ "rivalling": "rivaling",
1364
+ "romanticise": "romanticize",
1365
+ "romanticised": "romanticized",
1366
+ "romanticises": "romanticizes",
1367
+ "romanticising": "romanticizing",
1368
+ "rumour": "rumor",
1369
+ "rumoured": "rumored",
1370
+ "rumours": "rumors",
1371
+ "sabre": "saber",
1372
+ "sabres": "sabers",
1373
+ "saltpetre": "saltpeter",
1374
+ "sanitise": "sanitize",
1375
+ "sanitised": "sanitized",
1376
+ "sanitises": "sanitizes",
1377
+ "sanitising": "sanitizing",
1378
+ "satirise": "satirize",
1379
+ "satirised": "satirized",
1380
+ "satirises": "satirizes",
1381
+ "satirising": "satirizing",
1382
+ "saviour": "savior",
1383
+ "saviours": "saviors",
1384
+ "savour": "savor",
1385
+ "savoured": "savored",
1386
+ "savouries": "savories",
1387
+ "savouring": "savoring",
1388
+ "savours": "savors",
1389
+ "savoury": "savory",
1390
+ "scandalise": "scandalize",
1391
+ "scandalised": "scandalized",
1392
+ "scandalises": "scandalizes",
1393
+ "scandalising": "scandalizing",
1394
+ "sceptic": "skeptic",
1395
+ "sceptical": "skeptical",
1396
+ "sceptically": "skeptically",
1397
+ "scepticism": "skepticism",
1398
+ "sceptics": "skeptics",
1399
+ "sceptre": "scepter",
1400
+ "sceptres": "scepters",
1401
+ "scrutinise": "scrutinize",
1402
+ "scrutinised": "scrutinized",
1403
+ "scrutinises": "scrutinizes",
1404
+ "scrutinising": "scrutinizing",
1405
+ "secularisation": "secularization",
1406
+ "secularise": "secularize",
1407
+ "secularised": "secularized",
1408
+ "secularises": "secularizes",
1409
+ "secularising": "secularizing",
1410
+ "sensationalise": "sensationalize",
1411
+ "sensationalised": "sensationalized",
1412
+ "sensationalises": "sensationalizes",
1413
+ "sensationalising": "sensationalizing",
1414
+ "sensitise": "sensitize",
1415
+ "sensitised": "sensitized",
1416
+ "sensitises": "sensitizes",
1417
+ "sensitising": "sensitizing",
1418
+ "sentimentalise": "sentimentalize",
1419
+ "sentimentalised": "sentimentalized",
1420
+ "sentimentalises": "sentimentalizes",
1421
+ "sentimentalising": "sentimentalizing",
1422
+ "sepulchre": "sepulcher",
1423
+ "sepulchres": "sepulchers",
1424
+ "serialisation": "serialization",
1425
+ "serialisations": "serializations",
1426
+ "serialise": "serialize",
1427
+ "serialised": "serialized",
1428
+ "serialises": "serializes",
1429
+ "serialising": "serializing",
1430
+ "sermonise": "sermonize",
1431
+ "sermonised": "sermonized",
1432
+ "sermonises": "sermonizes",
1433
+ "sermonising": "sermonizing",
1434
+ "sheikh": "sheik",
1435
+ "shovelled": "shoveled",
1436
+ "shovelling": "shoveling",
1437
+ "shrivelled": "shriveled",
1438
+ "shrivelling": "shriveling",
1439
+ "signalise": "signalize",
1440
+ "signalised": "signalized",
1441
+ "signalises": "signalizes",
1442
+ "signalising": "signalizing",
1443
+ "signalled": "signaled",
1444
+ "signalling": "signaling",
1445
+ "smoulder": "smolder",
1446
+ "smouldered": "smoldered",
1447
+ "smouldering": "smoldering",
1448
+ "smoulders": "smolders",
1449
+ "snivelled": "sniveled",
1450
+ "snivelling": "sniveling",
1451
+ "snorkelled": "snorkeled",
1452
+ "snorkelling": "snorkeling",
1453
+ "snowplough": "snowplow",
1454
+ "snowploughs": "snowplow",
1455
+ "socialisation": "socialization",
1456
+ "socialise": "socialize",
1457
+ "socialised": "socialized",
1458
+ "socialises": "socializes",
1459
+ "socialising": "socializing",
1460
+ "sodomise": "sodomize",
1461
+ "sodomised": "sodomized",
1462
+ "sodomises": "sodomizes",
1463
+ "sodomising": "sodomizing",
1464
+ "solemnise": "solemnize",
1465
+ "solemnised": "solemnized",
1466
+ "solemnises": "solemnizes",
1467
+ "solemnising": "solemnizing",
1468
+ "sombre": "somber",
1469
+ "specialisation": "specialization",
1470
+ "specialisations": "specializations",
1471
+ "specialise": "specialize",
1472
+ "specialised": "specialized",
1473
+ "specialises": "specializes",
1474
+ "specialising": "specializing",
1475
+ "spectre": "specter",
1476
+ "spectres": "specters",
1477
+ "spiralled": "spiraled",
1478
+ "spiralling": "spiraling",
1479
+ "splendour": "splendor",
1480
+ "splendours": "splendors",
1481
+ "squirrelled": "squirreled",
1482
+ "squirrelling": "squirreling",
1483
+ "stabilisation": "stabilization",
1484
+ "stabilise": "stabilize",
1485
+ "stabilised": "stabilized",
1486
+ "stabiliser": "stabilizer",
1487
+ "stabilisers": "stabilizers",
1488
+ "stabilises": "stabilizes",
1489
+ "stabilising": "stabilizing",
1490
+ "standardisation": "standardization",
1491
+ "standardise": "standardize",
1492
+ "standardised": "standardized",
1493
+ "standardises": "standardizes",
1494
+ "standardising": "standardizing",
1495
+ "stencilled": "stenciled",
1496
+ "stencilling": "stenciling",
1497
+ "sterilisation": "sterilization",
1498
+ "sterilisations": "sterilizations",
1499
+ "sterilise": "sterilize",
1500
+ "sterilised": "sterilized",
1501
+ "steriliser": "sterilizer",
1502
+ "sterilisers": "sterilizers",
1503
+ "sterilises": "sterilizes",
1504
+ "sterilising": "sterilizing",
1505
+ "stigmatisation": "stigmatization",
1506
+ "stigmatise": "stigmatize",
1507
+ "stigmatised": "stigmatized",
1508
+ "stigmatises": "stigmatizes",
1509
+ "stigmatising": "stigmatizing",
1510
+ "storey": "story",
1511
+ "storeys": "stories",
1512
+ "subsidisation": "subsidization",
1513
+ "subsidise": "subsidize",
1514
+ "subsidised": "subsidized",
1515
+ "subsidiser": "subsidizer",
1516
+ "subsidisers": "subsidizers",
1517
+ "subsidises": "subsidizes",
1518
+ "subsidising": "subsidizing",
1519
+ "succour": "succor",
1520
+ "succoured": "succored",
1521
+ "succouring": "succoring",
1522
+ "succours": "succors",
1523
+ "sulphate": "sulfate",
1524
+ "sulphates": "sulfates",
1525
+ "sulphide": "sulfide",
1526
+ "sulphides": "sulfides",
1527
+ "sulphur": "sulfur",
1528
+ "sulphurous": "sulfurous",
1529
+ "summarise": "summarize",
1530
+ "summarised": "summarized",
1531
+ "summarises": "summarizes",
1532
+ "summarising": "summarizing",
1533
+ "swivelled": "swiveled",
1534
+ "swivelling": "swiveling",
1535
+ "symbolise": "symbolize",
1536
+ "symbolised": "symbolized",
1537
+ "symbolises": "symbolizes",
1538
+ "symbolising": "symbolizing",
1539
+ "sympathise": "sympathize",
1540
+ "sympathised": "sympathized",
1541
+ "sympathiser": "sympathizer",
1542
+ "sympathisers": "sympathizers",
1543
+ "sympathises": "sympathizes",
1544
+ "sympathising": "sympathizing",
1545
+ "synchronisation": "synchronization",
1546
+ "synchronise": "synchronize",
1547
+ "synchronised": "synchronized",
1548
+ "synchronises": "synchronizes",
1549
+ "synchronising": "synchronizing",
1550
+ "synthesise": "synthesize",
1551
+ "synthesised": "synthesized",
1552
+ "synthesiser": "synthesizer",
1553
+ "synthesisers": "synthesizers",
1554
+ "synthesises": "synthesizes",
1555
+ "synthesising": "synthesizing",
1556
+ "syphon": "siphon",
1557
+ "syphoned": "siphoned",
1558
+ "syphoning": "siphoning",
1559
+ "syphons": "siphons",
1560
+ "systematisation": "systematization",
1561
+ "systematise": "systematize",
1562
+ "systematised": "systematized",
1563
+ "systematises": "systematizes",
1564
+ "systematising": "systematizing",
1565
+ "tantalise": "tantalize",
1566
+ "tantalised": "tantalized",
1567
+ "tantalises": "tantalizes",
1568
+ "tantalising": "tantalizing",
1569
+ "tantalisingly": "tantalizingly",
1570
+ "tasselled": "tasseled",
1571
+ "technicolour": "technicolor",
1572
+ "temporise": "temporize",
1573
+ "temporised": "temporized",
1574
+ "temporises": "temporizes",
1575
+ "temporising": "temporizing",
1576
+ "tenderise": "tenderize",
1577
+ "tenderised": "tenderized",
1578
+ "tenderises": "tenderizes",
1579
+ "tenderising": "tenderizing",
1580
+ "terrorise": "terrorize",
1581
+ "terrorised": "terrorized",
1582
+ "terrorises": "terrorizes",
1583
+ "terrorising": "terrorizing",
1584
+ "theatre": "theater",
1585
+ "theatregoer": "theatergoer",
1586
+ "theatregoers": "theatergoers",
1587
+ "theatres": "theaters",
1588
+ "theorise": "theorize",
1589
+ "theorised": "theorized",
1590
+ "theorises": "theorizes",
1591
+ "theorising": "theorizing",
1592
+ "tonne": "ton",
1593
+ "tonnes": "tons",
1594
+ "towelled": "toweled",
1595
+ "towelling": "toweling",
1596
+ "toxaemia": "toxemia",
1597
+ "tranquillise": "tranquilize",
1598
+ "tranquillised": "tranquilized",
1599
+ "tranquilliser": "tranquilizer",
1600
+ "tranquillisers": "tranquilizers",
1601
+ "tranquillises": "tranquilizes",
1602
+ "tranquillising": "tranquilizing",
1603
+ "tranquillity": "tranquility",
1604
+ "tranquillize": "tranquilize",
1605
+ "tranquillized": "tranquilized",
1606
+ "tranquillizer": "tranquilizer",
1607
+ "tranquillizers": "tranquilizers",
1608
+ "tranquillizes": "tranquilizes",
1609
+ "tranquillizing": "tranquilizing",
1610
+ "tranquilly": "tranquility",
1611
+ "transistorised": "transistorized",
1612
+ "traumatise": "traumatize",
1613
+ "traumatised": "traumatized",
1614
+ "traumatises": "traumatizes",
1615
+ "traumatising": "traumatizing",
1616
+ "travelled": "traveled",
1617
+ "traveller": "traveler",
1618
+ "travellers": "travelers",
1619
+ "travelling": "traveling",
1620
+ "travelog": "travelogue",
1621
+ "travelogs": "travelogues",
1622
+ "trialled": "trialed",
1623
+ "trialling": "trialing",
1624
+ "tricolour": "tricolor",
1625
+ "tricolours": "tricolors",
1626
+ "trivialise": "trivialize",
1627
+ "trivialised": "trivialized",
1628
+ "trivialises": "trivializes",
1629
+ "trivialising": "trivializing",
1630
+ "tumour": "tumor",
1631
+ "tumours": "tumors",
1632
+ "tunnelled": "tunneled",
1633
+ "tunnelling": "tunneling",
1634
+ "tyrannise": "tyrannize",
1635
+ "tyrannised": "tyrannized",
1636
+ "tyrannises": "tyrannizes",
1637
+ "tyrannising": "tyrannizing",
1638
+ "tyre": "tire",
1639
+ "tyres": "tires",
1640
+ "unauthorised": "unauthorized",
1641
+ "uncivilised": "uncivilized",
1642
+ "underutilised": "underutilized",
1643
+ "unequalled": "unequaled",
1644
+ "unfavourable": "unfavorable",
1645
+ "unfavourably": "unfavorably",
1646
+ "unionisation": "unionization",
1647
+ "unionise": "unionize",
1648
+ "unionised": "unionized",
1649
+ "unionises": "unionizes",
1650
+ "unionising": "unionizing",
1651
+ "unorganised": "unorganized",
1652
+ "unravelled": "unraveled",
1653
+ "unravelling": "unraveling",
1654
+ "unrecognisable": "unrecognizable",
1655
+ "unrecognised": "unrecognized",
1656
+ "unrivalled": "unrivaled",
1657
+ "unsavoury": "unsavory",
1658
+ "untrammelled": "untrammeled",
1659
+ "urbanisation": "urbanization",
1660
+ "urbanise": "urbanize",
1661
+ "urbanised": "urbanized",
1662
+ "urbanises": "urbanizes",
1663
+ "urbanising": "urbanizing",
1664
+ "utilisable": "utilizable",
1665
+ "utilisation": "utilization",
1666
+ "utilise": "utilize",
1667
+ "utilised": "utilized",
1668
+ "utilises": "utilizes",
1669
+ "utilising": "utilizing",
1670
+ "valour": "valor",
1671
+ "vandalise": "vandalize",
1672
+ "vandalised": "vandalized",
1673
+ "vandalises": "vandalizes",
1674
+ "vandalising": "vandalizing",
1675
+ "vaporisation": "vaporization",
1676
+ "vaporise": "vaporize",
1677
+ "vaporised": "vaporized",
1678
+ "vaporises": "vaporizes",
1679
+ "vaporising": "vaporizing",
1680
+ "vapour": "vapor",
1681
+ "vapours": "vapors",
1682
+ "verbalise": "verbalize",
1683
+ "verbalised": "verbalized",
1684
+ "verbalises": "verbalizes",
1685
+ "verbalising": "verbalizing",
1686
+ "victimisation": "victimization",
1687
+ "victimise": "victimize",
1688
+ "victimised": "victimized",
1689
+ "victimises": "victimizes",
1690
+ "victimising": "victimizing",
1691
+ "videodisc": "videodisk",
1692
+ "videodiscs": "videodisks",
1693
+ "vigour": "vigor",
1694
+ "visualisation": "visualization",
1695
+ "visualisations": "visualizations",
1696
+ "visualise": "visualize",
1697
+ "visualised": "visualized",
1698
+ "visualises": "visualizes",
1699
+ "visualising": "visualizing",
1700
+ "vocalisation": "vocalization",
1701
+ "vocalisations": "vocalizations",
1702
+ "vocalise": "vocalize",
1703
+ "vocalised": "vocalized",
1704
+ "vocalises": "vocalizes",
1705
+ "vocalising": "vocalizing",
1706
+ "vulcanised": "vulcanized",
1707
+ "vulgarisation": "vulgarization",
1708
+ "vulgarise": "vulgarize",
1709
+ "vulgarised": "vulgarized",
1710
+ "vulgarises": "vulgarizes",
1711
+ "vulgarising": "vulgarizing",
1712
+ "waggon": "wagon",
1713
+ "waggons": "wagons",
1714
+ "watercolour": "watercolor",
1715
+ "watercolours": "watercolors",
1716
+ "weaselled": "weaseled",
1717
+ "weaselling": "weaseling",
1718
+ "westernisation": "westernization",
1719
+ "westernise": "westernize",
1720
+ "westernised": "westernized",
1721
+ "westernises": "westernizes",
1722
+ "westernising": "westernizing",
1723
+ "womanise": "womanize",
1724
+ "womanised": "womanized",
1725
+ "womaniser": "womanizer",
1726
+ "womanisers": "womanizers",
1727
+ "womanises": "womanizes",
1728
+ "womanising": "womanizing",
1729
+ "woollen": "woolen",
1730
+ "woollens": "woolens",
1731
+ "woollies": "woolies",
1732
+ "woolly": "wooly",
1733
+ "worshipped": "worshiped",
1734
+ "worshipper": "worshiper",
1735
+ "worshipping": "worshiping",
1736
+ "yodelled": "yodeled",
1737
+ "yodelling": "yodeling",
1738
+ "yoghourt": "yogurt",
1739
+ "yoghourts": "yogurts",
1740
+ "yoghurt": "yogurt",
1741
+ "yoghurts": "yogurts"
1742
+ }
distil-large-v3-init/preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_extractor_type": "WhisperFeatureExtractor",
4
+ "feature_size": 128,
5
+ "hop_length": 160,
6
+ "n_fft": 400,
7
+ "n_samples": 480000,
8
+ "nb_max_frames": 3000,
9
+ "padding_side": "right",
10
+ "padding_value": 0.0,
11
+ "processor_class": "WhisperProcessor",
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000
14
+ }
distil-large-v3-init/special_tokens_map.json ADDED
@@ -0,0 +1,139 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|startoftranscript|>",
4
+ "<|en|>",
5
+ "<|zh|>",
6
+ "<|de|>",
7
+ "<|es|>",
8
+ "<|ru|>",
9
+ "<|ko|>",
10
+ "<|fr|>",
11
+ "<|ja|>",
12
+ "<|pt|>",
13
+ "<|tr|>",
14
+ "<|pl|>",
15
+ "<|ca|>",
16
+ "<|nl|>",
17
+ "<|ar|>",
18
+ "<|sv|>",
19
+ "<|it|>",
20
+ "<|id|>",
21
+ "<|hi|>",
22
+ "<|fi|>",
23
+ "<|vi|>",
24
+ "<|he|>",
25
+ "<|uk|>",
26
+ "<|el|>",
27
+ "<|ms|>",
28
+ "<|cs|>",
29
+ "<|ro|>",
30
+ "<|da|>",
31
+ "<|hu|>",
32
+ "<|ta|>",
33
+ "<|no|>",
34
+ "<|th|>",
35
+ "<|ur|>",
36
+ "<|hr|>",
37
+ "<|bg|>",
38
+ "<|lt|>",
39
+ "<|la|>",
40
+ "<|mi|>",
41
+ "<|ml|>",
42
+ "<|cy|>",
43
+ "<|sk|>",
44
+ "<|te|>",
45
+ "<|fa|>",
46
+ "<|lv|>",
47
+ "<|bn|>",
48
+ "<|sr|>",
49
+ "<|az|>",
50
+ "<|sl|>",
51
+ "<|kn|>",
52
+ "<|et|>",
53
+ "<|mk|>",
54
+ "<|br|>",
55
+ "<|eu|>",
56
+ "<|is|>",
57
+ "<|hy|>",
58
+ "<|ne|>",
59
+ "<|mn|>",
60
+ "<|bs|>",
61
+ "<|kk|>",
62
+ "<|sq|>",
63
+ "<|sw|>",
64
+ "<|gl|>",
65
+ "<|mr|>",
66
+ "<|pa|>",
67
+ "<|si|>",
68
+ "<|km|>",
69
+ "<|sn|>",
70
+ "<|yo|>",
71
+ "<|so|>",
72
+ "<|af|>",
73
+ "<|oc|>",
74
+ "<|ka|>",
75
+ "<|be|>",
76
+ "<|tg|>",
77
+ "<|sd|>",
78
+ "<|gu|>",
79
+ "<|am|>",
80
+ "<|yi|>",
81
+ "<|lo|>",
82
+ "<|uz|>",
83
+ "<|fo|>",
84
+ "<|ht|>",
85
+ "<|ps|>",
86
+ "<|tk|>",
87
+ "<|nn|>",
88
+ "<|mt|>",
89
+ "<|sa|>",
90
+ "<|lb|>",
91
+ "<|my|>",
92
+ "<|bo|>",
93
+ "<|tl|>",
94
+ "<|mg|>",
95
+ "<|as|>",
96
+ "<|tt|>",
97
+ "<|haw|>",
98
+ "<|ln|>",
99
+ "<|ha|>",
100
+ "<|ba|>",
101
+ "<|jw|>",
102
+ "<|su|>",
103
+ "<|yue|>",
104
+ "<|translate|>",
105
+ "<|transcribe|>",
106
+ "<|startoflm|>",
107
+ "<|startofprev|>",
108
+ "<|nospeech|>",
109
+ "<|notimestamps|>"
110
+ ],
111
+ "bos_token": {
112
+ "content": "<|endoftext|>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "eos_token": {
119
+ "content": "<|endoftext|>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ },
125
+ "pad_token": {
126
+ "content": "<|endoftext|>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false
131
+ },
132
+ "unk_token": {
133
+ "content": "<|endoftext|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false
138
+ }
139
+ }
distil-large-v3-init/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
distil-large-v3-init/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
distil-whisper/events.out.tfevents.1714645175.server02.624510.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7ca2a1958b3abb793b2ce49e63d14d6f42ffaee9b5d164ac0d50a6b4dd095d5
3
+ size 88
distil-whisper/events.out.tfevents.1715051424.server02.1325731.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a4ace7442e1821fa3c6bbadbec1d2b7e54ff922368b1fdf2a867425c20ac45f
3
+ size 1608
distil-whisper/events.out.tfevents.1715051868.server02.1327224.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06b63119c28bace4a28c57a82dd2c1bb212634ff4ae1b97def4415e6304532ab
3
+ size 696
distil_whisper.egg-info/PKG-INFO ADDED
@@ -0,0 +1,580 @@
1
+ Metadata-Version: 2.1
2
+ Name: distil_whisper
3
+ Version: 0.0.0
4
+ Summary: Toolkit for distilling OpenAI's Whisper model.
5
+ Description-Content-Type: text/markdown
6
+ Requires-Dist: torch>=1.10
7
+ Requires-Dist: transformers>=4.35.1
8
+ Requires-Dist: datasets[audio]>=2.14.7
9
+ Requires-Dist: accelerate>=0.24.1
10
+ Requires-Dist: jiwer
11
+ Requires-Dist: evaluate>=0.4.1
12
+ Requires-Dist: wandb
13
+ Requires-Dist: tensorboard
14
+ Requires-Dist: nltk
15
+ Provides-Extra: dev
16
+ Requires-Dist: ruff==0.1.5; extra == "dev"
17
+
18
+ ## Training Distil-Whisper
19
+
20
+ This sub-folder contains all the scripts required to train a Distil-Whisper model in your choice of language. They are
21
+ slightly modified from the original scripts used to distill Whisper for English ASR (as per the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
22
+ The main difference is that these scripts are written in [PyTorch](https://pytorch.org), whereas the original scripts
23
+ are in [JAX](https://jax.readthedocs.io/en/latest/#)/[Flax](https://flax.readthedocs.io/en/latest/). These scripts are
24
+ also made to be easier to run end-to-end, whereas the original scripts require more steps and are somewhat hard-coded
25
+ for English ASR. Both sets of scripts achieve equivalent downstream results when the hyper-parameters are set equal.
26
+
27
+ If you are interested in reproducing the original Distil-Whisper checkpoints, we refer you to the sub-folder [Flax Training](./flax/README.md).
28
+ Otherwise, if you wish to distill Whisper on your own language/dataset, we recommend you use these scripts for ease of use
29
+ and the configurability they provide.
30
+
31
+ Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
32
+
33
+ 1. [Pseudo-labelling](#1-pseudo-labelling)
34
+ 2. [Initialisation](#2-initialisation)
35
+ 3. [Training](#3-training)
36
+ 4. [Evaluation](#4-evaluation)
37
+
38
+ This README is partitioned according to the four stages. Each section provides a minimal example for running the
39
+ scripts used in the project. We will use a running example of distilling the Whisper model for Hindi speech recognition
40
+ on the Common Voice dataset. Note that this dataset only contains ~20 hours of audio data. Thus, it can be run extremely
41
+ quickly, but does not provide sufficient data to achieve optimal performance. We recommend training on upwards of 1000
42
+ hours of data should you want to match the performance of Whisper on high-resource languages.
43
+
44
+ ## Requirements
45
+
46
+ The Distil-Whisper training code is written in [PyTorch](https://pytorch.org) and [Accelerate](https://huggingface.co/docs/accelerate/index).
47
+ It heavily leverages the Whisper implementation in [🤗 Transformers](https://github.com/huggingface/transformers) for both
48
+ training and inference.
49
+
50
+ The instructions for installing the package are as follows:
51
+ 1. Install PyTorch from the [official instructions](https://pytorch.org/get-started/locally/), ensuring you install the correct version for your hardware and CUDA version.
52
+ 2. Fork the `distil-whisper` repository by clicking on the [fork](https://github.com/huggingface/distil-whisper/fork) button on the repository's page.
53
+ 3. Clone the `distil-whisper` repository and add the base repository as a remote. This will allow you to "pull" any upstream changes that are made to the base repository:
54
+
55
+ ```bash
56
+ git clone https://github.com/<your GitHub handle>/distil-whisper.git
57
+ cd distil-whisper
58
+ git remote add upstream https://github.com/huggingface/distil-whisper.git
59
+ ```
60
+ 4. pip install the required packages from the [setup.py](./setup.py) file:
61
+ ```bash
62
+ cd training
63
+ pip install -e .
64
+ cd ../..
65
+ ```
66
+
67
+ 5. Configure Accelerate by running the following command. Note that you should set the number of GPUs you wish to use for distillation, and also the data type (dtype) to your preferred dtype for training/inference (e.g. `bfloat16` on A100 GPUs, `float16` on V100 GPUs, etc.):
68
+
69
+ ```bash
70
+ accelerate config
71
+ ```
72
+
73
+ 6. The last thing we need to do is link our Hugging Face account so that we can pull/push model repositories on the Hub. This will allow us to save our final distilled weights on the Hub so that we can share them with the community. Run the command:
74
+
75
+ ```bash
76
+ git config --global credential.helper store
77
+ huggingface-cli login
78
+ ```
79
+ And then enter an authentication token from https://huggingface.co/settings/tokens. Create a new token if you do not have one already. You should make sure that this token has "write" privileges.
80
+
81
+ To confirm that you have a working environment, first accept the terms of use of the Common Voice 16.1 dataset on the Hub: https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
82
+
83
+ You can run the following code cell to stream one sample of data from the Common Voice dataset, and check that you can
84
+ perform inference using the "tiny" Whisper model:
85
+
86
+ ```python
87
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
88
+ from datasets import load_dataset, Audio
89
+
90
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", low_cpu_mem_usage=True)
91
+ processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
92
+
93
+ model.to("cuda")
94
+
95
+ common_voice = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="validation", streaming=True)
96
+ common_voice = common_voice.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
97
+
98
+ inputs = processor(next(iter(common_voice))["audio"]["array"], sampling_rate=16000, return_tensors="pt")
99
+ input_features = inputs.input_features
100
+
101
+ generated_ids = model.generate(input_features.to("cuda"), max_new_tokens=128)
102
+ pred_text = processor.decode(generated_ids[0], skip_special_tokens=True)
103
+
104
+ print("Pred text:", pred_text)
105
+ print("Environment set up successful?", generated_ids.shape[-1] == 20)
106
+ ```
107
+
108
+ ## 1. Pseudo-Labelling
109
+
110
+ The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used
111
+ to generate pseudo-labels under a range of settings, including using both greedy and beam-search. It is also compatible
112
+ with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio
113
+ datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the
114
+ blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet).
115
+
116
+ > As of the latest Distil-Whisper release, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3), this
117
+ pseudo-labelling script also performs the added operation of concatenating (or packing) the audio inputs to 30-seconds.
118
+ Not only does this lead to a WER improvement when using the sequential long-form decoding algorithm, but concatenating audios
119
+ to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised.
120
+
121
+ The following script demonstrates how to pseudo-label the Hindi split of the Common Voice 16.1 dataset with greedy sampling:
122
+
123
+ ```bash
124
+ #!/usr/bin/env bash
125
+
126
+ accelerate launch run_pseudo_labelling.py \
127
+ --model_name_or_path "openai/whisper-large-v3" \
128
+ --dataset_name "mozilla-foundation/common_voice_16_1" \
129
+ --dataset_config_name "hi" \
130
+ --dataset_split_name "train+validation+test" \
131
+ --text_column_name "sentence" \
132
+ --id_column_name "path" \
133
+ --output_dir "./common_voice_16_1_hi_pseudo_labelled" \
134
+ --wandb_project "distil-whisper-labelling" \
135
+ --per_device_eval_batch_size 64 \
136
+ --dtype "bfloat16" \
137
+ --attn_implementation "sdpa" \
138
+ --logging_steps 500 \
139
+ --max_label_length 256 \
140
+ --concatenate_audio \
141
+ --preprocessing_batch_size 500 \
142
+ --preprocessing_num_workers 8 \
143
+ --dataloader_num_workers 8 \
144
+ --report_to "wandb" \
145
+ --language "hi" \
146
+ --task "transcribe" \
147
+ --return_timestamps \
148
+ --streaming False \
149
+ --generation_num_beams 1 \
150
+ --push_to_hub
151
+ ```
152
+
153
+ On an 80 GB A100 GPU, the above script takes approximately 5 minutes to concatenate and pre-process the 20 hours of
154
+ audio data, and a further 10 minutes to transcribe the pseudo-labels. The pseudo-labelled dataset corresponding to this
155
+ script is available on the Hugging Face Hub under [sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled](https://huggingface.co/datasets/sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled).
156
+ The WER of the pre-trained Whisper large-v3 model is 17.2% on the test split. We will compare the performance of our distilled model against this number.
157
+
158
+ There are three noteworthy arguments that configure the dataset concatenation (or packing) process:
159
+ 1. `concatenate_audio`: whether or not to concatenate (or pack) the audios to 30-second chunks. The latest Distil-Whisper model, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3#differences-with-distil-large-v2), highlights the WER improvements obtained using the sequential long-form decoding algorithm when concatenated audios are used. Concatenating audios to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised. Hence, it is highly recommended to set `--concatenate_audio=True`.
160
+ 2. `preprocessing_batch_size`: the batch size to use when concatenating (or packing) the audios. Using a larger batch size results in a greater portion of audio samples being packed to 30-seconds, at the expense of higher memory consumption. If you exceed your system's RAM when performing the concatenation operation, reduce the `preprocessing_batch_size` by a factor of 2 to 250 or even 125.
161
+ 3. `preprocessing_num_workers`: the number of multiprocessing workers to use when concatenating the audios. Using more workers will result in faster pre-processing, at the expense of higher memory consumption. Ensure you do not exceed the maximum number of CPUs on your device.
162
+
163
+ In addition, the following arguments configure the inference of the Whisper model:
164
+ 1. `language`: explicitly setting the language token during inference substantially improves the generation performance of the Whisper model, since the model is forced always to predict in the given language. We recommend you set the language to the language you wish to distil the Whisper model on. The only exception is when distilling an English-only model (i.e. where the model id is appended with an `.en`, e.g. `small.en`), the language argument should be set to None, since there is no language token used during training/inference.
165
+ 2. `return_timestamps`: whether or not to predict timestamps in the pseudo-labels. Timestamp prediction is required should you want your distilled model to be able to predict timestamps at inference time (e.g. for the original OpenAI long-form transcription algorithm). However, the pseudo-labels are marginally less accurate than not using timestamps. We recommend pseudo-labelling **with** timestamps to ensure the distilled model is as general as possible.
166
+ 3. `attn_implementation`: which attention implementation to use for inference. Set to `sdpa` for [PyTorch SDPA](https://huggingface.co/docs/transformers/v4.35.2/en/perf_infer_gpu_one#bettertransformer), or `flash_attn_2` if your hardware supports Flash Attention 2 and you have the [package installed](https://github.com/Dao-AILab/flash-attention).
167
+ 4. `streaming`: whether or not to use Datasets' streaming mode. If enabled, the audio data will be streamed from the Hugging Face Hub with no disk space requirements. However, the user is then responsible for adding the pseudo-labels to the dataset script in a follow-up step (see [Using Streaming Mode](#TODO)). If set to `False`, the audio data will be downloaded and pre-processed offline. At the end of pseudo-labelling, the pseudo-labels will be automatically appended to the original dataset, meaning the dataset is ready to be used for the subsequent training step without any additional steps.
168
+ 5. `generation_num_beams`: how many beams to use while decoding. In practice, we found the distilled model to perform comparably when the data was pseudo-labelled with `generation_num_beams=1` (greedy) or `generation_num_beams>1` (beam). This is likely because the WER filter compensates for the lower quality pseudo-labels obtained using greedy search. However, using `generation_num_beams=1` gives substantially faster inference time for the pseudo-labelling step, and so we recommend this configuration.
169
+
170
+ Should you have your own audio dataset, you can first [convert it](https://huggingface.co/docs/datasets/audio_dataset) to
171
+ Hugging Face Datasets format and push it to the Hugging Face Hub. You can then pseudo-label it using the script above,
172
+ replacing the `--dataset_name` with the name of your dataset on the Hub.
173
+
174
+ Otherwise, you may wish to use an open-source dataset already available on the Hugging Face Hub. We provide a summary of
175
+ the three most popular multilingual datasets in the table below. For more details, refer to the blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#multilingual-speech-recognition).
176
+
177
+ | Dataset | Languages | Domain | Speaking Style | License | Text Column | ID Column |
178
+ |-----------------------------------------------------------------------------------------------|-----------|---------------------------------------|----------------|-----------|---------------------|--------------|
179
+ | [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) | 6 | Audiobooks | Narrated | CC-BY-4.0 | `"text"` | `"id"` |
180
+ | [Common Voice 16](https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1) | 120 | Wikipedia text & crowd-sourced speech | Narrated | CC0-1.0 | `"sentence"` | `"path"` |
181
+ | [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) | 15 | European Parliament recordings | Spontaneous | CC0 | `"normalized_text"` | `"audio_id"` |
182
+
183
+ To achieve *robustness* to different distributions of audio data, it is recommended to train on multiple datasets where possible.
184
+ For example, the above three datasets all have splits for the German language. Thus, if distilling a Whisper model for German,
185
+ it would be wise to use a combination of the three datasets during training, in order to cover at least three distinct domains
186
+ (audiobooks, crowd-sourced speech, parliament recordings). You may wish to use a combination of open-source datasets, or
187
+ a combination of open-source and individually owned datasets to cover multiple distributions and domains.
188
+
189
+ ## 2. Initialisation
190
+
191
+ The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model
192
+ from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is
193
+ initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002)
194
+ recommendations.
195
+
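+ To illustrate what "maximally spaced" means in practice, the following minimal sketch (an illustration only, not the exact code in [`create_student_model.py`](create_student_model.py)) selects evenly spaced teacher layer indices, always retaining the first and last layer:
+
+ ```python
+ import numpy as np
+
+ def maximally_spaced_layers(num_teacher_layers, num_student_layers):
+     # Evenly spaced teacher layer indices, always keeping the first and last layer
+     return np.linspace(0, num_teacher_layers - 1, num_student_layers, dtype=int).tolist()
+
+ # For a 32-layer teacher decoder and a 2-layer student decoder this selects [0, 31],
+ # i.e. teacher layers 1 and 32 in 1-indexed terms
+ print(maximally_spaced_layers(32, 2))
+ ```
+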
196
+ First, we need to create a model repository on the Hugging Face Hub. This repository will contain all the required files
197
+ to reproduce the training run, alongside model weights, training logs and a README.md card. You can either create a model
198
+ repository directly on the Hugging Face Hub using the link: https://huggingface.co/new, or via the CLI, as we'll show here.
199
+
200
+ Let's pick a name for our distilled model: `distil-whisper-large-v3-hi`. We can run the following command to create a repository under this name:
201
+
202
+ ```bash
203
+ huggingface-cli repo create distil-whisper-large-v3-hi
204
+ ```
205
+
206
+ We can now see the model on the Hub, e.g. under https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
207
+
208
+ Let's clone the repository so that we can place our training script and model weights inside:
209
+
210
+ ```bash
211
+ git lfs install
212
+ git clone https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
213
+ ```
214
+
215
+ Be sure to change the repo address to `https://huggingface.co/<your-user-name>/<your-repo-name>`
216
+
217
+ We can now copy the relevant training scripts to the repository:
218
+ ```bash
219
+ cd distil-whisper-large-v3-hi
220
+
221
+ cp ../distil-whisper/training/create_student_model.py .
222
+ cp ../distil-whisper/training/run_distillation.py .
223
+ ```
224
+
225
+ The following command demonstrates how to initialise a student model from the Whisper [large-v3](https://huggingface.co/openai/whisper-large-v3)
226
+ checkpoint, with all 32 encoder layers and 2 decoder layers. The 2 student decoder layers are copied from teacher layers
227
+ 1 and 32 respectively, as the maximally spaced layers:
228
+
229
+ ```bash
230
+ #!/usr/bin/env bash
231
+
232
+ python create_student_model.py \
233
+ --teacher_checkpoint "openai/whisper-large-v3" \
234
+ --encoder_layers 32 \
235
+ --decoder_layers 2 \
236
+ --save_dir "./distil-large-v3-init"
237
+ ```
238
+
239
+ The initialised model will be saved to the sub-directory `distil-large-v3-init` in our model repository.
240
+
241
+ ## 3. Training
242
+
243
+ The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple
244
+ datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation
245
+ from the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), which is a weighted sum of the cross-entropy and
246
+ KL-divergence loss terms.
247
+
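+ As a rough illustration of this objective (the coefficients and temperature below are placeholders, not the exact values used in [`run_distillation.py`](run_distillation.py)), the per-batch loss can be sketched as:
+
+ ```python
+ import torch.nn.functional as F
+
+ def distillation_loss(student_logits, teacher_logits, labels, alpha_ce=1.0, alpha_kl=1.0, temperature=2.0):
+     # Cross-entropy between the student predictions and the (pseudo-)labels
+     ce_loss = F.cross_entropy(
+         student_logits.view(-1, student_logits.shape[-1]), labels.view(-1), ignore_index=-100
+     )
+     # KL-divergence between the student and teacher output distributions
+     kl_loss = F.kl_div(
+         F.log_softmax(student_logits / temperature, dim=-1),
+         F.softmax(teacher_logits / temperature, dim=-1),
+         reduction="batchmean",
+     ) * temperature**2
+     return alpha_ce * ce_loss + alpha_kl * kl_loss
+ ```
+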
248
+ The following command takes the Common Voice dataset that was pseudo-labelled in the first stage and trains the
249
+ 2-layer decoder model initialised in the previous step. We pass the local path to the pseudo-labelled Common Voice dataset
250
+ (`../common_voice_16_1_hi_pseudo_labelled`), which you can change to the path where your local pseudo-labelled dataset is
251
+ saved.
252
+
253
+ In this example, we will combine the train and validation splits to give our training set, and evaluate on the test split
254
+ only. This is purely to demonstrate how to combine multiple pseudo-labelled datasets for training, rather than recommended
255
+ advice for defining train/validation splits. We advise that you train on the train splits of your dataset, evaluate and
256
+ tune hyper-parameters on the validation split, and only test the final checkpoint on the test split. Note how multiple
257
+ training datasets and splits can be loaded by separating the dataset arguments by `+` symbols. Thus, the script generalises
258
+ to any number of training datasets.
259
+
260
+ ```bash
261
+ #!/usr/bin/env bash
262
+
263
+ accelerate launch run_distillation.py \
264
+ --model_name_or_path "./distil-large-v3-init" \
265
+ --teacher_model_name_or_path "openai/whisper-large-v3" \
266
+ --train_dataset_name "../common_voice_16_1_hi_pseudo_labelled+../common_voice_16_1_hi_pseudo_labelled" \
267
+ --train_split_name "train+validation" \
268
+ --text_column_name "sentence+sentence" \
269
+ --train_dataset_samples "7+4" \
270
+ --eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
271
+ --eval_split_name "test" \
272
+ --eval_text_column_name "sentence" \
273
+ --eval_steps 1000 \
274
+ --save_steps 1000 \
275
+ --warmup_steps 50 \
276
+ --learning_rate 0.0001 \
277
+ --lr_scheduler_type "constant_with_warmup" \
278
+ --timestamp_probability 0.2 \
279
+ --condition_on_prev_probability 0.2 \
280
+ --language "hi" \
281
+ --task "transcribe" \
282
+ --logging_steps 25 \
283
+ --save_total_limit 1 \
284
+ --max_steps 5000 \
285
+ --wer_threshold 20 \
286
+ --per_device_train_batch_size 32 \
287
+ --per_device_eval_batch_size 32 \
288
+ --dataloader_num_workers 8 \
289
+ --preprocessing_num_workers 8 \
290
+ --ddp_timeout 7200 \
291
+ --dtype "bfloat16" \
292
+ --attn_implementation "sdpa" \
293
+ --output_dir "./" \
294
+ --do_train \
295
+ --do_eval \
296
+ --gradient_checkpointing \
297
+ --overwrite_output_dir \
298
+ --predict_with_generate \
299
+ --freeze_encoder \
300
+ --freeze_embed_positions \
301
+ --streaming False \
302
+ --push_to_hub
303
+
304
+ ```
305
+
306
+ The above training script will take approximately 3 hours to complete on an 80 GB A100 GPU and yield a final WER of 76%.
307
+ While the generations are starting to take form, there is still a 59% WER gap to the teacher model. This is hardly
308
+ surprising given we only have 15 hours of un-filtered data, and closer to just 1.5 hours with data filtering.
309
+ As mentioned above, using upwards of 1000 hours of data and training for 10k steps will likely yield
310
+ more competitive performance. For the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), we trained on 21k hours
311
+ of audio data for 80k steps. We found that upwards of 13k hours of audio data was required to reach convergence on English
312
+ ASR (see Section 9.2 of the [paper](https://arxiv.org/abs/2311.00430)), so the more data you have, the better!
313
+
314
+ Scaling to multiple GPUs using [distributed data parallelism (DDP)](https://pytorch.org/tutorials/beginner/ddp_series_theory.html)
315
+ is trivial: simply run `accelerate config` and select the multi-GPU option, specifying the IDs of the GPUs you wish to use. The
316
+ above script can then be run using DDP with no code changes.
317
+
318
+ Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a
319
+ saved checkpoint pushed to the Hugging Face Hub can be found here: [sanchit-gandhi/distil-whisper-large-v3-hi](https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi).
320
+
321
+ There are a few noteworthy data arguments:
322
+ 1. `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split, however this might require downloading the dataset offline to compute these statistics.
323
+ 2. `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong. In our English distillation experiments, we found a WER threshold of 10% provides the optimal trade-off between ensuring high-quality transcriptions, and not filtering unnecessary amounts of training data. For multilingual distillation, the threshold should be set in accordance with the WER achieved by the pre-trained model on the test set.
324
+ 3. `streaming`: whether or not to use Datasets' streaming mode. Recommended for large datasets, where the audio data can be streamed from the Hugging Face Hub with no disk space requirements.
325
+ 4. `timestamp_probability`: the per-sample probability for retaining timestamp tokens in the labels (should they contain them). Retaining some portion of timestamp tokens in the training data is required to ensure the distilled model can predict timestamps at inference time. In our experiments, we found that training on timestamps with high-probability hurts the distilled model's transcription performance. Thus, we recommend setting this to a value below 0.5. Typically, a value of 0.2 works well, giving good transcription and timestamp performance.
326
+ 5. `condition_on_prev_probability`: the per-sample probability for conditioning on previous labels. Conditioning on previous tokens is required to ensure the distilled model can be used with the "sequential" long-form transcription algorithm at inference time. We did not experiment with this parameter, but found values around 0.2 to provide adequate performance. OpenAI pre-trained Whisper with a 50% probability of conditioning on previous tokens. Thus, you might wish to try higher values.
327
+
328
+ As well as a few noteworthy model arguments that can be configured to give optimal training performance:
329
+ 1. `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
330
+ 2. `freeze_embed_positions`: whether to freeze the student model's decoder positional embeddings. Using the same embed positions as the teacher model, which is designed to handle context lengths up to 448 tokens, helps the student model retain its input id representation up to the full max input length.
331
+ 3. `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
332
+
333
+ And finally, a few noteworthy training arguments:
334
+ 1. `max_steps`: defines the total number of optimisation steps (forward + backward pass) during training. To reach convergence, you should use a dataset of at least 1k hours and train for a minimum of 50k steps.
335
+ 2. `lr_scheduler_type`: defines the learning rate schedule, one of `constant_with_warmup` or `linear`. When experimenting with a training set-up or training for very few steps (< 5k), using `constant_with_warmup` is typically beneficial, since the learning rate remains high over the short training run. When performing long training runs (> 5k), using a `linear` schedule generally results in superior downstream performance of the distilled model.
336
+
337
+ TODO:
338
+ - [ ] Template for model cards
339
+
340
+ ## 4. Evaluation
341
+
342
+ There are four types of evaluation performed in Distil-Whisper:
343
+ 1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
344
+ 2. Sequential long form: evaluation on audio samples longer than 30s in duration using the original "sequential" long-form algorithm. Examples include entire TED talks or earnings calls.
345
+ 3. Chunked long form: evaluation on audio samples longer than 30s in duration using the Transformers "chunked" long-form algorithm.
346
+ 4. Speculative decoding: evaluation on audio samples less than 30s in duration, where a faster, distilled model is used as the assistant to a slower, teacher model.
347
+
348
+ All four forms of evaluation are performed using the script [`run_eval.py`](run_eval.py). Unlike the pseudo-labelling
349
+ and training scripts, the evaluation script assumes that only one GPU accelerator is used. We can copy the corresponding
350
+ evaluation script to the model repository using the following command:
351
+
352
+ ```bash
353
+ cp ../distil-whisper/training/run_eval.py .
354
+ ```
355
+
356
+ Models are assessed jointly using the following two metrics (a minimal sketch of computing both is given after the list):
357
+ 1. The *word-error rate (WER)* metric: measures the number of substitution, deletion and insertion errors relative to the total number of words. A lower WER indicates a more accurate model.
358
+ 2. The *inverse real-time factor (RTFx)* metric: measures the ratio of `audio input time : model compute time`. A higher RTFx indicates a faster model.
359
+
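+ The sketch below shows how the two metrics can be computed; `transcribe`, `audio_samples`, `references` and `total_audio_seconds` are placeholders for your own inference call and evaluation data:
+
+ ```python
+ import time
+ import evaluate
+
+ wer_metric = evaluate.load("wer")
+
+ start = time.time()
+ predictions = transcribe(audio_samples)  # placeholder for the model's batched inference call
+ compute_time = time.time() - start
+
+ # WER: word-level substitution, deletion and insertion errors relative to the references (lower is better)
+ wer = 100 * wer_metric.compute(references=references, predictions=predictions)
+
+ # RTFx: seconds of audio transcribed per second of compute time (higher is faster)
+ rtfx = total_audio_seconds / compute_time
+ ```
+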
360
+ In all cases, it is particularly important to evaluate the final model on data that is *out-of-distribution (OOD)* with
361
+ the training data. Evaluating on OOD data provides insight as to how well the distilled model is likely to generalise to
362
+ different audio distributions at inference time. In our example, the Common Voice test set is *in-distribution (ID)*
363
+ with our training data, since it is taken from the same distribution as the Common Voice training set, whereas the FLEURS
364
+ test set is OOD, since it is not used as part of the training set.
365
+
366
+ ### Short Form
367
+
368
+ The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple short-form
369
+ validation sets. The following example demonstrates how to evaluate the student model trained in the previous step on
370
+ the Common Voice `test` set (ID) and also the FLEURS `test` set (OOD). Again, it leverages streaming mode to bypass
371
+ the need to download the data offline:
372
+
373
+ ```bash
374
+ #!/usr/bin/env bash
375
+
376
+ python run_eval.py \
377
+ --model_name_or_path "./" \
378
+ --dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
379
+ --dataset_config_name "default+hi_in" \
380
+ --dataset_split_name "test+test" \
381
+ --text_column_name "sentence+transcription" \
382
+ --batch_size 16 \
383
+ --dtype "bfloat16" \
384
+ --generation_max_length 256 \
385
+ --language "hi" \
386
+ --attn_implementation "sdpa" \
387
+ --streaming
388
+
389
+ ```
390
+
391
+ The student model achieves an average WER of TODO% with an RTFx of TODO for a batch size of 16. We can easily adapt the above
392
+ script to evaluate the teacher model, simply by switching the `model_name_or_path` to `openai/whisper-large-v3`, which
393
+ achieves an average WER of TODO% with an RTFx of TODO. Therefore, for a batch size of 16, the student model is a factor of TODO
394
+ times faster than the teacher. The WER gap can be closed by training on more data (at least 1k hours) for more training
395
+ steps (at least 50k).
396
+
397
+ ### Sequential Long Form
398
+
399
+ The original Whisper paper presents a long-form transcription algorithm that sequentially transcribes 30-second segments
400
+ of audio and shifts the sliding window according to the timestamps predicted by the model. This style of sequential
401
+ inference is performed directly using the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
402
+ method in Transformers.
403
+
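+ As a minimal sketch of this sequential inference (assuming `long_audio_array` is a 16 kHz waveform longer than 30 seconds), the full audio is passed to the processor without truncation and `.generate` handles the sliding window:
+
+ ```python
+ import torch
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+ processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3")
+ model = WhisperForConditionalGeneration.from_pretrained(
+     "distil-whisper/distil-large-v3", torch_dtype=torch.float16
+ ).to("cuda")
+
+ # Pass the full audio without truncation so that generate() runs the sequential
+ # sliding-window algorithm, shifting the window according to the predicted timestamps
+ inputs = processor(
+     long_audio_array, sampling_rate=16000, return_tensors="pt",
+     truncation=False, padding="longest", return_attention_mask=True,
+ ).to("cuda", torch.float16)
+
+ generated_ids = model.generate(**inputs, return_timestamps=True)
+ print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+ ```
+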
404
+ The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
405
+ long-form evaluation sets using the sequential algorithm. Since we don't have a long-form validation set for Hindi to hand,
406
+ in this example we'll evaluate the official Distil-Whisper model [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3)
407
+ on the TED-LIUM validation set:
408
+
409
+ ```bash
410
+ #!/usr/bin/env bash
411
+
412
+ accelerate launch run_eval.py \
413
+ --model_name_or_path "distil-whisper/distil-large-v3" \
414
+ --dataset_name "distil-whisper/tedlium-long-form" \
415
+ --dataset_config_name "default" \
416
+ --dataset_split_name "validation" \
417
+ --text_column_name "text" \
418
+ --batch_size 16 \
419
+ --dtype "bfloat16" \
420
+ --generation_max_length 256 \
421
+ --language "en" \
422
+ --attn_implementation "sdpa" \
423
+ --streaming
424
+
425
+ ```
426
+
427
+ ### Chunked Long Form
428
+
429
+ Chunked long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and
430
+ inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction.
431
+ A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks.
432
+
433
+ This style of chunked inference is performed using the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines)
434
+ class, which provides a wrapper around the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
435
+ function for long-form inference.
436
+
437
+ The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
438
+ long-form evaluation sets using the pipeline class. Again, in this example we'll evaluate distil-large-v3 on the
439
+ TED-LIUM validation set:
440
+
441
+ ```bash
442
+ #!/usr/bin/env bash
443
+
444
+ python run_eval.py \
445
+ --model_name_or_path "distil-whisper/distil-large-v3" \
446
+ --dataset_name "distil-whisper/tedlium-long-form" \
447
+ --dataset_config_name "default" \
448
+ --dataset_split_name "validation" \
449
+ --text_column_name "text" \
450
+ --use_pipeline \
451
+ --chunk_length_s 25.0 \
452
+ --language "en" \
453
+ --return_timestamps \
454
+ --dtype "bfloat16" \
455
+ --streaming
456
+
457
+ ```
458
+
459
+ The argument `chunk_length_s` controls the length of the chunked audio samples. It should be set to match the typical
460
+ length of audio the student model was trained on. If unsure about what value of `chunk_length_s` is optimal for your case,
461
+ it is recommended to run a *sweep* over all possible values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
462
+ can be found under [`run_chunk_length_s_sweep.yaml`](flax/long_form_transcription_scripts/run_chunk_length_s_sweep.yaml).
463
+
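+ For reference, the same chunked algorithm can be used outside of the evaluation script via the `pipeline` class. The following is a minimal sketch, where `"audio.mp3"` is a placeholder for your own audio file:
+
+ ```python
+ import torch
+ from transformers import pipeline
+
+ # Chunked long-form inference: the audio is split into chunk_length_s windows with a small
+ # overlap (stride), the chunks are transcribed in parallel batches, and the outputs are merged
+ asr = pipeline(
+     "automatic-speech-recognition",
+     model="distil-whisper/distil-large-v3",
+     torch_dtype=torch.float16,
+     device="cuda:0",
+     chunk_length_s=25.0,
+     batch_size=16,
+ )
+ print(asr("audio.mp3")["text"])
+ ```
+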
464
+ ### Speculative Decoding
465
+
466
+ Speculative decoding, or assisted generation, relies on the premise that a faster, assistant model can be used to speed-up
467
+ the generation of a slower, assistant model. Speculative decoding mathematically ensures that exactly the same outputs as
468
+ Whisper are obtained, while being ~2 times faster. This makes it the perfect drop-in replacement for existing Whisper
469
+ pipelines, since exactly the same outputs are guaranteed.
470
+
471
+ Distil-Whisper checkpoints can be designed to be efficient assistant models to Whisper for speculative decoding. More precisely,
472
+ by freezing the encoder during training, the distilled model can share the same encoder weights as Whisper during inference, since
473
+ the encoder weights are un-changed. In doing so, only the distilled 2-layer decoder has to be loaded in addition to the
474
+ original Whisper model, which is approximately an 8% increase to the total parameter count, with up to 2x faster inference
475
+ for low batch sizes. For more details on speculative decoding, the reader is advised to refer to the following blog post:
476
+ [Speculative Decoding for 2x Faster Whisper Inference](https://huggingface.co/blog/whisper-speculative-decoding).
477
+
478
+ In the example below, we use our distilled model as an assistant to the large-v3 teacher model during inference:
479
+
480
+ ```bash
481
+ #!/usr/bin/env bash
482
+
483
+ python run_eval.py \
484
+ --model_name_or_path "openai/whisper-large-v3" \
485
+ --assistant_model_name_or_path "./" \
486
+ --dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
487
+ --dataset_config_name "default+hi_in" \
488
+ --dataset_split_name "test+test" \
489
+ --text_column_name "sentence+transcription" \
490
+ --batch_size 16 \
491
+ --dtype "bfloat16" \
492
+ --generation_max_length 256 \
493
+ --language "hi" \
494
+ --attn_implementation "sdpa" \
495
+ --streaming
496
+
497
+ ```
498
+
499
+ We see that we achieve a WER of TODO%, the same as what we obtained with the large-v3 model, but with an RTFx of TODO,
500
+ a factor of TODO faster than using the large-v3 model alone. The RTFx value can be improved by training the student on
501
+ more data and for more training steps, since this will improve the number of predicted tokens that match the teacher
502
+ predictions.
503
+
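+ Outside of the evaluation script, speculative decoding can be run by passing the distilled checkpoint as the `assistant_model` to `.generate`. The following is a minimal sketch, assuming `audio_array` is a 16 kHz waveform and that the distilled weights have been pushed to the repository created earlier:
+
+ ```python
+ import torch
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+
+ processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
+ teacher = AutoModelForSpeechSeq2Seq.from_pretrained(
+     "openai/whisper-large-v3", torch_dtype=torch.float16
+ ).to("cuda")
+ assistant = AutoModelForSpeechSeq2Seq.from_pretrained(
+     "sanchit-gandhi/distil-whisper-large-v3-hi", torch_dtype=torch.float16
+ ).to("cuda")
+
+ inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt").to("cuda", torch.float16)
+ # The assistant proposes candidate tokens; the teacher verifies them in a single forward pass,
+ # guaranteeing identical outputs to running the teacher alone
+ generated_ids = teacher.generate(inputs.input_features, assistant_model=assistant, max_new_tokens=128)
+ print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+ ```
+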
504
+ ## Overview of Training Methods
505
+
506
+ ### 1. Fine-Tuning
507
+
508
+ For fine-tuning, we take the original Whisper checkpoint and train it on one or more datasets using the standard
509
+ cross-entropy loss. As such, there is no involvement from the teacher checkpoint during training, and so the fine-tuned
510
+ model is permitted to *overfit* to the distribution of the training data we provide. This makes it appealing for "low-resource"
511
+ languages where the original Whisper model performs poorly, since we can boost the performance of the model on a single
512
+ language by *overfitting* to that distribution of data. Note that this means the fine-tuned model is prone to losing
513
+ its robustness to different audio distributions, which is the trade-off with improving performance on a specified dataset.
514
+
515
+ As a rule of thumb, fine-tuning is appropriate for languages where the original Whisper model performs > 20% WER, and we
516
+ have a relatively small quantity of training data available (< 1000 hours). With fine-tuning, we require as little as **10 hours**
517
+ of training data to significantly boost the performance of the Whisper model. For an in-depth guide to fine-tuning Whisper,
518
+ the reader is advised to refer to the blog post: [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper).
519
+
520
+ ### 2. Shrink and Fine-Tune
521
+
522
+ Shrink and fine-tune (SFT) is a knowledge distillation (KD) technique in which we first *shrink* the teacher model to a
523
+ smaller student model by copying maximally spaced layers, and then *fine-tune* the student model on the cross-entropy loss
524
+ as described above. Typically, we retain the full encoder from the Whisper model and only shrink the decoder. Retaining
525
+ the entire encoder helps significantly with maintaining Whisper's robustness to different audio distributions (_c.f._
526
+ Section 9.3 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
527
+
528
+ We can either train the student model on a dataset of (audio, text) pairs as above, or we can use the pre-trained
529
+ Whisper model to generate *pseudo-labels* for our audio data, and train on the (audio, pseudo-label) pairs.
530
+
531
+ Pseudo-labels can be used when either:
532
+ 1. The original text transcriptions are normalised (lower-cased or no punctuation): the Whisper generated pseudo-labels contain both punctuation and casing, and so can be used as a substitute for the normalised transcriptions
533
+ 2. The pre-trained Whisper model achieves < 20% WER on the languages: we then know the majority of the pseudo-labels will be accurate enough for us to train on.
534
+
535
+ They are not recommended when both of the following are true:
536
+ 1. The original text is punctuated and cased
537
+ 2. The pre-trained Whisper model achieves > 20% WER on the languages: in this case, we want to overfit to the particular distribution of the language, and so train directly on the original text data
538
+
539
+ To discard inaccurate pseudo-labels during training, we employ a simple WER heuristic to filter our pseudo-labelled
540
+ training data. We first normalise the original text and the pseudo-labelled text using the Whisper normaliser. If the
541
+ WER between the normalised texts exceeds a 10% WER threshold, we discard the training sample. Otherwise, we retain it for training.
542
+ Section 9.1 of the Distil-Whisper [paper](https://arxiv.org/abs/2311.00430) demonstrates the importance of using this
543
+ threshold for training.
544
+
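+ A minimal sketch of this heuristic is shown below (the training script implements the same idea via its `wer_threshold` argument; here the threshold is expressed as a fraction rather than a percentage):
+
+ ```python
+ from jiwer import wer
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+
+ normalizer = BasicTextNormalizer()
+
+ def keep_sample(ground_truth, pseudo_label, wer_threshold=0.10):
+     # Normalise both texts, then retain the sample only if its WER is within the threshold
+     reference = normalizer(ground_truth)
+     prediction = normalizer(pseudo_label)
+     return wer(reference, prediction) <= wer_threshold
+ ```
+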
545
+ ### 3. KL Divergence
546
+
547
+ In the KL Divergence setting, the student model is initialised by shrinking the teacher as before, and then trained to
548
+ match the predictions of the teacher during training.
549
+
550
+ ### Summary of Methods
551
+
552
+ The following table summarises the two training paradigms: fine-tuning and knowledge distillation (KD). It suggests
553
+ minimum values for the pre-trained WER / training data to achieve reasonable performance:
554
+
555
+ | Method | Pre-Trained WER / % | Training Data / h |
556
+ |-------------|---------------------|-------------------|
557
+ | Fine-tuning | > 20 | < 1000 |
558
+ | KD | < 20 | > 1000 |
559
+
560
+ ## Acknowledgements
561
+
562
+ * OpenAI for the Whisper [model](https://huggingface.co/openai/whisper-large-v3) and [original codebase](https://github.com/openai/whisper)
563
+ * Hugging Face 🤗 [Transformers](https://github.com/huggingface/transformers) for the Whisper model implementation
564
+ * Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program for Cloud TPU v4s used to train the official Distil-Whisper models
565
+ * The Hugging Face 🤗 cluster for enabling experimentation with the PyTorch scripts
566
+
567
+ ## Citation
568
+
569
+ If you use this code-base, please consider citing the Distil-Whisper paper:
570
+
571
+ ```
572
+ @misc{gandhi2023distilwhisper,
573
+ title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
574
+ author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
575
+ year={2023},
576
+ eprint={2311.00430},
577
+ archivePrefix={arXiv},
578
+ primaryClass={cs.CL}
579
+ }
580
+ ```
distil_whisper.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,8 @@
1
+ README.md
2
+ pyproject.toml
3
+ setup.py
4
+ distil_whisper.egg-info/PKG-INFO
5
+ distil_whisper.egg-info/SOURCES.txt
6
+ distil_whisper.egg-info/dependency_links.txt
7
+ distil_whisper.egg-info/requires.txt
8
+ distil_whisper.egg-info/top_level.txt
distil_whisper.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
distil_whisper.egg-info/requires.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch>=1.10
2
+ transformers>=4.35.1
3
+ datasets[audio]>=2.14.7
4
+ accelerate>=0.24.1
5
+ jiwer
6
+ evaluate>=0.4.1
7
+ wandb
8
+ tensorboard
9
+ nltk
10
+
11
+ [dev]
12
+ ruff==0.1.5
distil_whisper.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+
flax/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
flax/Makefile ADDED
@@ -0,0 +1,9 @@
1
+ check_dirs := .
2
+
3
+ quality:
4
+ black --check $(check_dirs)
5
+ ruff $(check_dirs)
6
+
7
+ style:
8
+ black $(check_dirs)
9
+ ruff $(check_dirs) --fix
flax/README.md ADDED
@@ -0,0 +1,293 @@
1
+ ## Reproducing Distil-Whisper
2
+
3
+ This sub-folder contains all the training and inference scripts to reproduce the Distil-Whisper project. Distil-Whisper
4
+ is written in JAX to leverage the fast training and inference speed offered by TPU v4 hardware. However, it also works
5
+ efficiently on GPU hardware without any additional code changes.
6
+
7
+ Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
8
+
9
+ 1. [Pseudo-labelling](#pseudo-labelling)
10
+ 2. [Initialisation](#initialisation)
11
+ 3. [Training](#training)
12
+ 4. [Evaluation](#evaluation)
13
+
14
+ This README is partitioned according to the four stages. Each section provides a minimal example for running the
15
+ scripts used in the project. The final scripts used to train the model are referenced in-line.
16
+
17
+ It is worth noting that the experiments performed in JAX/Flax have been on English ASR only. For multilingual training code,
18
+ the [PyTorch Training Code](../README.md) can easily be used, enabling anyone to run Whisper distillation on a language of their choice.
19
+
20
+ ## Requirements
21
+
22
+ Distil-Whisper is written in Python, JAX and Flax, and heavily leverages the Flax Whisper implementation in
23
+ [🤗 Transformers](https://github.com/huggingface/transformers). The instructions for installing the package are as follows:
24
+ 1. Install JAX from the [official instructions](https://github.com/google/jax#installation), ensuring you install the correct version for your hardware (GPU or TPU).
25
+ 2. Install the `distil_whisper` package by cloning the repository and performing an editable installation:
26
+
27
+ ```bash
28
+ git clone https://github.com/huggingface/distil-whisper.git
29
+ cd distil-whisper/training/flax
30
+ pip install -e .
31
+ ```
32
+
33
+ ## Pseudo-Labelling
34
+
35
+ Pseudo-labelling is the process of generating target text predictions for the input audio data using the teacher model.
36
+ The generated text labels then replace the ground truth text labels when performing distillation. The rationale for
37
+ using pseudo-labels instead of ground truth labels is to circumvent the issue of inconsistent transcription formatting
38
+ across datasets.
39
+
40
+ The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used
41
+ to generate pseudo-labels under a range of settings, including greedy and beam-search decoding. It is also compatible
42
+ with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio
43
+ datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the
44
+ blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet).
45
+
46
+ The following script demonstrates how to pseudo-label the [LibriSpeech 960h](https://huggingface.co/datasets/librispeech_asr)
47
+ dataset with greedy sampling and streaming mode:
48
+
49
+ ```bash
50
+ #!/usr/bin/env bash
51
+
52
+ python run_pseudo_labelling.py \
53
+ --model_name_or_path "openai/whisper-large-v2" \
54
+ --dataset_name "librispeech_asr" \
55
+ --dataset_config_name "all" \
56
+ --data_split_name "train.clean.100+train.clean.360+train.other.500" \
57
+ --text_column_name "text" \
58
+ --output_dir "./transcriptions" \
59
+ --per_device_eval_batch_size 16 \
60
+ --max_label_length 256 \
61
+ --dtype "bfloat16" \
62
+ --report_to "wandb" \
63
+ --dataloader_num_workers 16 \
64
+ --streaming \
65
+ --push_to_hub \
66
+ --generation_num_beams 1 # for greedy, set >1 for beam
67
+
68
+ ```
69
+
70
+ The script will save the generated pseudo-labels alongside the file ids to the output directory `output_dir`. Adding the
71
+ `--push_to_hub` argument uploads the generated pseudo-labels to the Hugging Face Hub on save.
72
+
73
+ The directory [`pseudo_labelling_scripts`](pseudo_labelling_scripts) contains a collection of bash scripts for
74
+ pseudo-labelling all 10 audio datasets used in the project. The datasets with the Whisper generated transcriptions
75
+ can be found on the Hugging Face Hub under the [Distil Whisper organisation](https://huggingface.co/datasets?sort=trending&search=distil-whisper%2F).
76
+ They can be re-used should you wish to bypass the data labelling stage of the reproduction.
77
+
78
+ <!--- TODO(SG): Combine PS with source audio to create dataset --->
79
+
80
+ ## Initialisation
81
+
82
+ The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model
83
+ from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is
84
+ initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002)
85
+ recommendations.
86
+
87
+ The following command demonstrates how to initialise a student model from the [large-v2](https://huggingface.co/openai/whisper-large-v2)
88
+ checkpoint, with all 32 encoder layers and 2 decoder layers. The 2 student decoder layers are copied from teacher layers
89
+ 1 and 32 respectively, as the maximally spaced layers.
90
+
91
+ ```bash
92
+ #!/usr/bin/env bash
93
+
94
+ python create_student_model.py \
95
+ --teacher_checkpoint "openai/whisper-large-v2" \
96
+ --encoder_layers 32 \
97
+ --decoder_layers 2 \
98
+ --save_dir "./large-32-2" \
99
+ --push_to_hub
100
+ ```
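+
+ For intuition, the maximally spaced layer selection can be sketched in a couple of lines of NumPy (illustrative only;
+ the actual mapping is computed inside `create_student_model.py`):
+
+ ```python
+ import numpy as np
+
+ teacher_decoder_layers = 32
+ student_decoder_layers = 2
+
+ # maximally spaced teacher layer indices to copy into the student (0-indexed)
+ decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_decoder_layers, dtype=int)
+ print(decoder_mapping)  # [ 0 31] -> teacher layers 1 and 32
+ ```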
101
+
102
+
103
+ ## Training
104
+
105
+ The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple
106
+ datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation
107
+ from [DistilBart](https://arxiv.org/abs/2010.13002), which is a combination of a cross-entropy, KL-divergence and
108
+ mean-square error (MSE) loss:
109
+
110
+ https://github.com/huggingface/distil-whisper/blob/4dd831543e6c40b1159f1ec951db7f4fe0e86850/run_distillation.py#L1725
111
+
112
+ The weight assigned to the MSE loss is configurable. The others are fixed to the values from the DistilBART paper.
113
+
114
+ The following command takes the LibriSpeech 960h dataset that was pseudo-labelled in the first stage and trains the
115
+ 2-layer decoder model initialised in the previous step. Note that multiple training datasets and splits can be loaded
116
+ by separating the dataset arguments by `+` symbols. Thus, the script generalises to any number of training datasets.
117
+
118
+ ```bash
119
+ #!/usr/bin/env bash
120
+
121
+ python3 run_distillation.py \
122
+ --model_name_or_path "./large-32-2" \
123
+ --teacher_model_name_or_path "openai/whisper-large-v2" \
124
+ --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr" \
125
+ --train_dataset_config_name "all+all+all" \
126
+ --train_split_name "train.clean.100+train.clean.360+train.other.500" \
127
+ --train_dataset_samples "100+360+500" \
128
+ --eval_dataset_name "librispeech_asr" \
129
+ --eval_dataset_config_name "all" \
130
+ --eval_split_name "validation.clean" \
131
+ --eval_steps 5000 \
132
+ --save_steps 5000 \
133
+ --warmup_steps 500 \
134
+ --learning_rate 0.0001 \
135
+ --lr_scheduler_type "constant_with_warmup" \
136
+ --logging_steps 25 \
137
+ --save_total_limit 1 \
138
+ --max_steps 20000 \
139
+ --wer_threshold 10 \
140
+ --per_device_train_batch_size 64 \
141
+ --per_device_eval_batch_size 64 \
142
+ --dataloader_num_workers 16 \
143
+ --dtype "bfloat16" \
144
+ --output_dir "./" \
145
+ --do_train \
146
+ --do_eval \
147
+ --use_scan \
148
+ --gradient_checkpointing \
149
+ --overwrite_output_dir \
150
+ --predict_with_generate \
151
+ --freeze_encoder \
152
+ --streaming \
153
+ --use_auth_token \
154
+ --push_to_hub
155
+
156
+ ```
157
+
158
+ The above training script will take approximately 20 hours to complete on a TPU v4-8 and yield a final WER of 2.3%.
159
+
160
+ Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a
161
+ saved checkpoint pushed to the Hugging Face Hub can be found here: [large-32-2](https://huggingface.co/distil-whisper/large-32-2).
162
+
163
+ There are a few noteworthy arguments that can be configured to give optimal training performance:
164
+ * `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split; however, this might require downloading the dataset offline to compute these statistics. A short sketch of how these values translate into sampling probabilities is given after this list.
165
+ * `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong.
166
+ * `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
167
+ * `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
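+
+ As a minimal sketch of how the `train_dataset_samples` values translate into sampling probabilities for the
+ interleaved training datasets (illustrative only; the training script performs the equivalent computation internally):
+
+ ```python
+ import numpy as np
+
+ # per-dataset values passed via --train_dataset_samples "100+360+500"
+ train_dataset_samples = np.array([100, 360, 500], dtype=float)
+
+ # probability of drawing the next training sample from each dataset
+ sampling_probs = train_dataset_samples / train_dataset_samples.sum()
+ print(sampling_probs)  # approximately [0.104 0.375 0.521]
+ ```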
168
+
169
+ The Distil Whisper project extends the above script to train on a combined dataset formed from 12 open-source ASR datasets,
170
+ totalling 22k hours and over 50k speakers. Template scripts to run training on this composite dataset can be found
171
+ in the directory [`distillation_scripts`](distillation_scripts).
172
+
173
+ ## Evaluation
174
+
175
+ There are two types of evaluation performed in Distil-Whisper:
176
+ 1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
177
+ 2. Long form: evaluation on audio samples longer than 30s in duration. Examples include entire TED talks or earnings calls.
178
+
179
+ Both forms of evaluation are performed using the *word-error rate (WER)* metric.
180
+
181
+ ### Short Form
182
+
183
+ The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple validation sets.
184
+ The following example demonstrates how to evaluate the student model trained in the previous step on the LibriSpeech
185
+ `validation.clean` and `validation.other` dev sets. Again, it leverages streaming mode to bypass the need to download
186
+ the data offline:
187
+
188
+ ```bash
189
+ #!/usr/bin/env bash
190
+
191
+ python run_eval.py \
192
+ --model_name_or_path "./large-32-2" \
193
+ --dataset_name "librispeech_asr+librispeech_asr" \
194
+ --dataset_config_name "all+all" \
195
+ --dataset_split_name "validation.clean+validation.other" \
196
+ --output_dir "./large-32-2" \
197
+ --per_device_eval_batch_size 64 \
198
+ --dtype "bfloat16" \
199
+ --dataloader_num_workers 16 \
200
+ --report_to "wandb" \
201
+ --streaming \
202
+ --predict_with_generate
203
+
204
+ ```
205
+
206
+ ### Long Form
207
+
208
+ Long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and
209
+ inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction.
210
+ A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks.
211
+
212
+ This style of chunked inference is performed using the [`FlaxWhisperPipeline`](https://github.com/huggingface/distil-whisper/blob/6426022e3b3a0a498b4150a636b54e2e3898bf1a/distil_whisper/pipeline.py#L61)
213
+ class, which is heavily inspired by [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax/tree/main#pipeline-usage).
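+
+ A minimal usage sketch is given below (assuming the pipeline exposes `dtype`, `batch_size` and `chunk_length_s`
+ arguments as in Whisper JAX; the checkpoint and audio paths are illustrative):
+
+ ```python
+ import jax.numpy as jnp
+ from distil_whisper import FlaxWhisperPipeline
+
+ # load the trained student checkpoint into the chunked inference pipeline
+ pipeline = FlaxWhisperPipeline("./large-32-2", dtype=jnp.bfloat16, batch_size=16)
+
+ # transcribe a long audio file by splitting it into 15-second chunks with a small stride
+ transcription = pipeline("audio.mp3", chunk_length_s=15.0)
+ ```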
214
+
215
+ The script [`run_long_form_transcription.py`](run_long_form_transcription.py) can be used to evaluate the trained
216
+ student model on an arbitrary number of long-form evaluation sets. The following script demonstrates how to evaluate
217
+ the example student model on two such test sets, [Earnings 21](https://huggingface.co/datasets/distil-whisper/earnings21)
218
+ and [Earnings 22](https://huggingface.co/datasets/distil-whisper/earnings22):
219
+
220
+ ```bash
221
+ #!/usr/bin/env bash
222
+
223
+ python run_long_form_transcription.py \
224
+ --model_name_or_path "./large-32-2" \
225
+ --dataset_name "distil-whisper/earnings21+distil-whisper/earnings22" \
226
+ --dataset_config_name "default+default" \
227
+ --dataset_split_name "test+test" \
228
+ --text_column_name "transcription+transcription" \
229
+ --output_dir "./large-32-2" \
230
+ --per_device_eval_batch_size 64 \
231
+ --chunk_length_s 15 \
232
+ --dtype "bfloat16" \
233
+ --report_to "wandb" \
234
+ --streaming
235
+
236
+ ```
237
+
238
+ The argument `chunk_length_s` controls the length of the chunked audio samples. It should be set to match the typical
239
+ length of audio the student model was trained on. If unsure about what value of `chunk_length_s` is optimal for your case,
240
+ it is recommended to run a *sweep* over all possible values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
241
+ can be found under [`run_chunk_length_s_sweep.yaml`](long_form_transcription_scripts/run_chunk_length_s_sweep.yaml).
242
+
243
+ ### 1. Pseudo Labelling
244
+
245
+ #### Greedy vs Beam
246
+
247
+ We found there to be little-to-no difference in the downstream performance of the distilled model after pseudo labelling
248
+ using either greedy or beam-search. We attribute this to the minimal difference in performance of the pre-trained Whisper
249
+ model under greedy and beam-search decoding, giving pseudo-labelled transcriptions of similar quality. We encourage
250
+ users to generate pseudo-labels using greedy decoding given it runs significantly faster. Beam search is only advised if
251
+ the pre-trained model is hallucinating significantly on the audio inputs, in which case it helps reduce the frequency and
252
+ severity of hallucinations. If using beam search, the number of beams can be kept low: even 2 beams helps reduce the
253
+ amount of hallucinations significantly.
254
+
255
+ #### Timestamps
256
+
257
+ Whisper is trained on a timestamp prediction task as part of the pre-training set-up. Here, a fixed proportion of the
258
+ pre-training data includes sequence-level *timestamps* as part of the transcription labels:
259
+
260
+ ```bash
261
+ <|0.00|> Hey, this is a test transcription. <|3.42|>
262
+ ```
263
+
264
+ Timestamp prediction is useful for enriching the transcriptions with timing information for downstream tasks, such as
265
+ aligning the Whisper transcription with the output of a speaker diarization system, and also reduces the frequency of
266
+ hallucinations.
267
+
268
+ The pseudo-labelling script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) can be extended to predict timestamp
269
+ information in the audio data by appending the `--return_timestamps` flag to the launch command. The timestamped labelled
270
+ data can be passed to the training script in exactly the same way as the non-timestamped version, and the pre-processing
271
+ function will take care of encoding the timestamps and appending the required task tokens.
272
+
273
+ #### Previous Context
274
+
275
+ Whisper is also pre-trained on a prompting task, where the transcription for the preceding utterance is fed as context
276
+ to the current one:
277
+
278
+ ```bash
279
+ <|startofprev|> This is the previous context from the preceding utterance.<|startoftranscript|> And this is the current utterance.<|endoftranscript|>
280
+ ```
281
+
282
+ Annotating the transcriptions with previous context labels is only possible for datasets where we have consecutive files
283
+ and unique speaker ids, since we need to ensure segment `i` directly follows on from segment `i-1` if we use it as the
284
+ prompt.
285
+
286
+ As per the Whisper paper, we mask out the loss over the previous context tokens. At inference time, we can replace the
287
+ previous context with a “prompt” to encourage the model to generate text in the style of the prompt (i.e. for specific
288
+ named entities, or styles of transcription).
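+
+ A minimal sketch of the loss masking is shown below (illustrative only; the function name and the use of `-100` as the
+ ignore index are assumptions about how the labels are prepared, not the exact training-script implementation):
+
+ ```python
+ def mask_prompt_tokens(label_ids, start_of_transcript_id):
+     # the previous-context (prompt) tokens precede <|startoftranscript|>; setting their
+     # labels to -100 excludes them from the cross-entropy loss
+     bos_index = label_ids.index(start_of_transcript_id)
+     return [-100] * bos_index + label_ids[bos_index:]
+
+ # e.g. [prev_1, prev_2, sot, tok_1, tok_2] -> [-100, -100, sot, tok_1, tok_2]
+ ```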
289
+
290
+ ## Acknowledgements
291
+
292
+ * 🤗 Hugging Face Transformers for the base Whisper implementation
293
+ * Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) programme for their generous provision of Cloud TPUs
flax/conversion_scripts/run_convert_distilled_train_state_to_hf.sh ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
2
+
3
+ TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python convert_train_state_to_hf.py \
4
+ --model_name_or_path "distil-whisper/large-32-2" \
5
+ --output_dir "./" \
6
+ --resume_from_checkpoint "checkpoint-15000" \
7
+ --cache_dir "/home/sanchitgandhi/.cache" \
8
+ --use_scan
flax/convert_train_state_to_hf.py ADDED
@@ -0,0 +1,327 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Convert a Flax training state to HF Transformers Whisper weights.
18
+ """
19
+
20
+ import logging
21
+ import os
22
+ import sys
23
+ from dataclasses import field
24
+ from pathlib import Path
25
+ from typing import Callable, Optional
26
+
27
+ import flax
28
+ import jax
29
+ import jax.numpy as jnp
30
+ import optax
31
+ from flax import jax_utils, traverse_util
32
+ from flax.serialization import from_bytes
33
+ from flax.training import train_state
34
+ from flax.training.common_utils import shard_prng_key
35
+ from huggingface_hub import Repository, create_repo
36
+ from optax._src import linear_algebra
37
+ from transformers import (
38
+ AutoConfig,
39
+ HfArgumentParser,
40
+ Seq2SeqTrainingArguments,
41
+ )
42
+ from transformers.file_utils import get_full_repo_name
43
+ from transformers.utils import check_min_version
44
+ from transformers.utils.versions import require_version
45
+
46
+ from distil_whisper import FlaxWhisperForConditionalGeneration
47
+
48
+
49
+ # initialise JAX for multi-host set-up on TPU
50
+ jax.distributed.initialize()
51
+
52
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
53
+ check_min_version("4.27.0.dev0")
54
+
55
+ require_version(
56
+ "datasets>=1.18.0",
57
+ "To fix: pip install -r examples/flax/speech-recognition/requirements.txt",
58
+ )
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ @flax.struct.dataclass
64
+ class ModelArguments:
65
+ """
66
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
67
+ """
68
+
69
+ model_name_or_path: str = field(
70
+ metadata={"help": ("Path to pretrained student model or model identifier from huggingface.co/models")}
71
+ )
72
+ config_name: Optional[str] = field(
73
+ default=None,
74
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
75
+ )
76
+ cache_dir: Optional[str] = field(
77
+ default=None,
78
+ metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
79
+ )
80
+ use_fast_tokenizer: bool = field(
81
+ default=True,
82
+ metadata={"help": ("Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.")},
83
+ )
84
+ model_revision: str = field(
85
+ default="main",
86
+ metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
87
+ )
88
+ use_auth_token: bool = field(
89
+ default=False,
90
+ metadata={
91
+ "help": (
92
+ "Will use the token generated when running `transformers-cli login`"
93
+ " (necessary to use this script with private models)."
94
+ )
95
+ },
96
+ )
97
+ dtype: Optional[str] = field(
98
+ default="float32",
99
+ metadata={
100
+ "help": (
101
+ "Floating-point format in which the model weights should be initialized"
102
+ " and trained. Choose one of `[float32, float16, bfloat16]`."
103
+ )
104
+ },
105
+ )
106
+ load_with_scan_weights: bool = field(
107
+ default=False,
108
+ metadata={
109
+ "help": "Whether the pre-trained checkpoint has its weights stored in scan format. Set to True for scanned "
110
+ "weights, defaults to False for non-scan (unrolled) weights."
111
+ },
112
+ )
113
+ use_scan: bool = field(
114
+ default=True,
115
+ metadata={"help": ("Whether or not to use `scan_with_axes` over the encoder and decoder blocks.")},
116
+ )
117
+
118
+
119
+ def create_learning_rate_fn(
120
+ num_train_steps: int, lr_scheduler_type: str, num_warmup_steps: int, learning_rate: float
121
+ ) -> Callable[[int], jnp.array]:
122
+ """Returns a linear warmup, linear_decay learning rate function."""
123
+ lr_scheduler_types = ("linear", "constant_with_warmup")
124
+
125
+ if lr_scheduler_type not in lr_scheduler_types:
126
+ raise ValueError(
127
+ f"lr_scheduler_type of type {lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
128
+ )
129
+
130
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
131
+ decay_fn = optax.linear_schedule(
132
+ init_value=learning_rate,
133
+ end_value=0 if lr_scheduler_type == "linear" else learning_rate,
134
+ transition_steps=num_train_steps - num_warmup_steps,
135
+ )
136
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
137
+ return schedule_fn
138
+
139
+
140
+ class TrainState(train_state.TrainState):
141
+ dropout_rng: jnp.ndarray
142
+ max_grad_norm: float
143
+
144
+ def apply_gradients(self, *, grads, **kwargs):
145
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value, clipping the
146
+ gradients by the maximum grad norm.
147
+
148
+ Note that internally this function calls `.tx.update()` followed by a call
149
+ to `optax.apply_updates()` to update `params` and `opt_state`.
150
+
151
+ Args:
152
+ grads: Gradients that have the same pytree structure as `.params`.
153
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
154
+
155
+ Returns:
156
+ An updated instance of `self` with `step` incremented by one, `params`
157
+ and `opt_state` updated by applying `grads`, and additional attributes
158
+ replaced as specified by `kwargs`.
159
+ """
160
+ # clip gradients by global l2 norm
161
+ g_norm = linear_algebra.global_norm(grads)
162
+ g_norm = jnp.maximum(self.max_grad_norm, g_norm)
163
+ grads = jax.tree_map(lambda t: (t / g_norm) * self.max_grad_norm, grads)
164
+
165
+ updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
166
+ new_params = optax.apply_updates(self.params, updates)
167
+
168
+ return self.replace(
169
+ step=self.step + 1,
170
+ params=new_params,
171
+ opt_state=new_opt_state,
172
+ **kwargs,
173
+ )
174
+
175
+ def replicate(self):
176
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
177
+
178
+ def unreplicate(self):
179
+ return jax_utils.unreplicate(self)
180
+
181
+
182
+ def main():
183
+ # 1. Parse input arguments
184
+ # See all possible arguments in src/transformers/training_args.py
185
+ # or by passing the --help flag to this script.
186
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
187
+ parser = HfArgumentParser(
188
+ (
189
+ ModelArguments,
190
+ Seq2SeqTrainingArguments,
191
+ )
192
+ )
193
+
194
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
195
+ # If we pass only one argument to the script and it's the path to a json file,
196
+ # let's parse it to get our arguments.
197
+ model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
198
+ else:
199
+ model_args, training_args = parser.parse_args_into_dataclasses()
200
+
201
+ # Handle the repository creation
202
+ if training_args.push_to_hub:
203
+ if training_args.hub_model_id is None:
204
+ repo_name = get_full_repo_name(
205
+ Path(training_args.output_dir).absolute().name,
206
+ token=training_args.hub_token,
207
+ )
208
+ else:
209
+ repo_name = training_args.hub_model_id
210
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
211
+ repo = Repository(
212
+ training_args.output_dir,
213
+ clone_from=repo_name,
214
+ token=training_args.hub_token,
215
+ )
216
+
217
+ # 5. Load pretrained config, model and processor
218
+ config = AutoConfig.from_pretrained(
219
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
220
+ cache_dir=model_args.cache_dir,
221
+ revision=model_args.model_revision,
222
+ use_auth_token=True if model_args.use_auth_token else None,
223
+ )
224
+ student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
225
+ model_args.model_name_or_path,
226
+ config=config,
227
+ dtype=getattr(jnp, model_args.dtype),
228
+ cache_dir=model_args.cache_dir,
229
+ revision=model_args.model_revision,
230
+ use_auth_token=True if model_args.use_auth_token else None,
231
+ _do_init=False,
232
+ use_scan=model_args.load_with_scan_weights,
233
+ )
234
+
235
+ # enable scan / gradient checkpointing if necessary in the student model
236
+ if model_args.use_scan:
237
+ student_model.enable_scan() # to enable scan in the nn.Module
238
+ student_params = student_model.convert_unroll_to_scan(student_params) # to convert the unrolled params to scan
239
+
240
+ # Initialize our student state
241
+ rng = jax.random.PRNGKey(training_args.seed)
242
+ rng, dropout_rng = jax.random.split(rng)
243
+
244
+ total_train_steps = int(training_args.max_steps)
245
+
246
+ # Create learning rate schedule
247
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
248
+ total_train_steps,
249
+ training_args.lr_scheduler_type,
250
+ training_args.warmup_steps,
251
+ training_args.learning_rate,
252
+ )
253
+
254
+ # We use Optax's "masking" functionality to not apply weight decay
255
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
256
+ # mask boolean with the same structure as the parameters.
257
+ # The mask is True for parameters that should be decayed.
258
+ def decay_mask_fn(params):
259
+ flat_params = traverse_util.flatten_dict(params)
260
+ # find out all LayerNorm parameters
261
+ layer_norm_candidates = [
262
+ "layer_norm",
263
+ "self_attn_layer_norm",
264
+ "final_layer_norm",
265
+ "encoder_attn_layer_norm",
266
+ ]
267
+ layer_norm_named_params = {
268
+ layer[-2:]
269
+ for layer_norm_name in layer_norm_candidates
270
+ for layer in flat_params.keys()
271
+ if layer_norm_name in "".join(layer).lower()
272
+ }
273
+ flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
274
+ return traverse_util.unflatten_dict(flat_mask)
275
+
276
+ # create adam optimizer
277
+ adamw = optax.adamw(
278
+ learning_rate=linear_decay_lr_schedule_fn,
279
+ b1=training_args.adam_beta1,
280
+ b2=training_args.adam_beta2,
281
+ eps=training_args.adam_epsilon,
282
+ weight_decay=training_args.weight_decay,
283
+ mask=decay_mask_fn,
284
+ )
285
+
286
+ # Setup train state
287
+ student_state = TrainState.create(
288
+ apply_fn=student_model.__call__,
289
+ params=student_params,
290
+ tx=adamw,
291
+ dropout_rng=dropout_rng,
292
+ max_grad_norm=training_args.max_grad_norm,
293
+ )
294
+
295
+ if training_args.resume_from_checkpoint is not None:
296
+ if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
297
+ logger.info(
298
+ f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
299
+ "this behavior, omit the resume_from_checkpoint argument."
300
+ )
301
+ with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
302
+ student_state = from_bytes(student_state, f.read())
303
+ else:
304
+ logger.warning(
305
+ f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
306
+ f"you pass the path to a folder with a valid checkpoint for your model."
307
+ )
308
+
309
+ cur_step = int(jax.device_get(student_state.step))
310
+
311
+ # save weights in HF Transformers format
312
+ if jax.process_index() == 0:
313
+ student_model.disable_scan()
314
+ student_state_params = student_model.convert_scan_to_unroll(student_state.params)
315
+ student_params = jax.device_get(student_state_params)
316
+ student_model.save_pretrained(
317
+ os.path.join(training_args.output_dir, f"checkpoint-{cur_step}"), params=student_params
318
+ )
319
+ if training_args.push_to_hub:
320
+ repo.push_to_hub(
321
+ commit_message=f"Saving weights of step {cur_step}",
322
+ blocking=False,
323
+ )
324
+
325
+
326
+ if __name__ == "__main__":
327
+ main()
flax/create_student_model.py ADDED
@@ -0,0 +1,226 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Initialise a student Whisper model from a pre-trained teacher model for
18
+ teacher-student distillation.
19
+ """
20
+
21
+ import argparse
22
+ import copy
23
+ import logging
24
+
25
+ import jax
26
+ import numpy as np
27
+ from flax.core import freeze, unfreeze
28
+ from transformers import GenerationConfig, WhisperFeatureExtractor, WhisperProcessor
29
+
30
+ from distil_whisper import FlaxWhisperForConditionalGeneration
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ def parse_args():
37
+ parser = argparse.ArgumentParser(
38
+ description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
39
+ )
40
+ parser.add_argument(
41
+ "--teacher_checkpoint",
42
+ type=str,
43
+ required=True,
44
+ help="The HF Hub ID of the teacher checkpoint.",
45
+ )
46
+ parser.add_argument(
47
+ "--subfolder",
48
+ type=str,
49
+ default="",
50
+ help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
51
+ "can specify the folder name here.",
52
+ )
53
+ parser.add_argument(
54
+ "--encoder_layers",
55
+ type=int,
56
+ default=None,
57
+ help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
58
+ )
59
+ parser.add_argument(
60
+ "--decoder_layers",
61
+ type=int,
62
+ default=2,
63
+ help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
64
+ )
65
+ parser.add_argument(
66
+ "--max_source_positions",
67
+ type=int,
68
+ default=None,
69
+ help="The maximum sequence length of log-mel filter-bank features that this model might ever be used with. Can "
70
+ "be used to create a student model with a shorter context length than the teacher model. Defaults to the number "
71
+ "of source positions in the teacher model (1500).",
72
+ )
73
+ parser.add_argument(
74
+ "--save_dir",
75
+ type=str,
76
+ required=True,
77
+ help="Where to save the student weights and processor.",
78
+ )
79
+ parser.add_argument(
80
+ "--push_to_hub",
81
+ type=bool,
82
+ required=False,
83
+ default=False,
84
+ help="Whether to push the student weights and processor to the Hub.",
85
+ )
86
+ parser.add_argument(
87
+ "--cache_dir",
88
+ type=str,
89
+ default=None,
90
+ help="Where to store the pretrained models downloaded from huggingface.co",
91
+ )
92
+
93
+ args = parser.parse_args()
94
+ return args
95
+
96
+
97
+ def init_student_model_from_teacher(
98
+ teacher_checkpoint,
99
+ encoder_layers=None,
100
+ decoder_layers=2,
101
+ max_source_positions=None,
102
+ save_dir=None,
103
+ push_to_hub=None,
104
+ cache_dir=None,
105
+ subfolder="",
106
+ ):
107
+ teacher_model, teacher_params = FlaxWhisperForConditionalGeneration.from_pretrained(
108
+ teacher_checkpoint,
109
+ _do_init=False,
110
+ cache_dir=cache_dir,
111
+ subfolder=subfolder,
112
+ )
113
+ processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
114
+ generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)
115
+
116
+ teacher_config = teacher_model.config
117
+ teacher_encoder_layers = teacher_config.encoder_layers
118
+ teacher_decoder_layers = teacher_config.decoder_layers
119
+
120
+ student_config = copy.deepcopy(teacher_config)
121
+ student_config.update(
122
+ {
123
+ "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
124
+ "decoder_layers": decoder_layers,
125
+ "max_source_positions": (
126
+ max_source_positions if max_source_positions is not None else student_config.max_source_positions
127
+ ),
128
+ }
129
+ )
130
+
131
+ encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
132
+ encoder_mapping[-1] = teacher_encoder_layers - 1
133
+
134
+ encoder_map = {}
135
+ for student_layer, teacher_layer in enumerate(encoder_mapping):
136
+ encoder_map[str(teacher_layer)] = str(student_layer)
137
+
138
+ decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
139
+ decoder_mapping[-1] = teacher_decoder_layers - 1
140
+
141
+ decoder_map = {}
142
+ for student_layer, teacher_layer in enumerate(decoder_mapping):
143
+ decoder_map[str(teacher_layer)] = str(student_layer)
144
+
145
+ # init the student params from the teacher model
146
+ student_params = unfreeze(teacher_params)
147
+ student_params["model"]["decoder"]["layers"] = {}
148
+
149
+ for layer in teacher_params["model"]["decoder"]["layers"]:
150
+ if layer in decoder_map:
151
+ # re-introduce pre-defined layers from the teacher
152
+ student_params["model"]["decoder"]["layers"][decoder_map[layer]] = teacher_params["model"]["decoder"][
153
+ "layers"
154
+ ][layer]
155
+
156
+ if encoder_layers is not None:
157
+ student_params["model"]["encoder"]["layers"] = {}
158
+ for layer in teacher_params["model"]["encoder"]["layers"]:
159
+ if layer in encoder_map:
160
+ # re-introduce pre-defined layers from the teacher
161
+ student_params["model"]["encoder"]["layers"][encoder_map[layer]] = teacher_params["model"]["encoder"][
162
+ "layers"
163
+ ][layer]
164
+
165
+ if max_source_positions is not None:
166
+ # slice the first MAX_SOURCE_POSITIONS embedding weights
167
+ student_params["model"]["encoder"]["embed_positions"]["embedding"] = teacher_params["model"]["encoder"][
168
+ "embed_positions"
169
+ ]["embedding"][: student_config.max_source_positions, :]
170
+ # update the feature extractor to handle the new input length
171
+ chunk_length = int(student_config.max_source_positions * 2 / 100)
172
+ processor.feature_extractor = WhisperFeatureExtractor(chunk_length=chunk_length)
173
+
174
+ # remove the teacher params and model
175
+ del teacher_params, teacher_model
176
+
177
+ # save the converted weights and model
178
+ student_params = freeze(student_params)
179
+ student_model = FlaxWhisperForConditionalGeneration(student_config, _do_init=False)
180
+
181
+ if save_dir is not None:
182
+ student_model.save_pretrained(save_dir, params=student_params)
183
+ # we also need to correctly save the processor and generation config
184
+ processor.save_pretrained(save_dir)
185
+ generation_config.save_pretrained(save_dir)
186
+
187
+ # check we can do a forward pass with the saved model - first load the weights and processor
188
+ logger.info("Checking we can load the saved model...")
189
+ student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
190
+ save_dir,
191
+ _do_init=False,
192
+ )
193
+ processor = WhisperProcessor.from_pretrained(save_dir)
194
+
195
+ # define some random inputs
196
+ input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="np").input_features
197
+ decoder_start_token_id = student_model.config.decoder_start_token_id
198
+ decoder_input_ids = np.ones((input_features.shape[0], 1)) * decoder_start_token_id
199
+
200
+ # do a forward pass - outputs will be gibberish for the initialised model so we can't check them
201
+ logger.info("Checking we can run the converted model forward...")
202
+ _ = student_model(input_features, decoder_input_ids=decoder_input_ids, params=student_params).logits
203
+ logger.info("Conversion successful!")
204
+
205
+ if push_to_hub:
206
+ student_model.push_to_hub(save_dir, params=student_params)
207
+ processor.push_to_hub(save_dir)
208
+ generation_config.push_to_hub(save_dir)
209
+
210
+
211
+ if __name__ == "__main__":
212
+ args = parse_args()
213
+
214
+ # Set the verbosity to info of the logger - we only want one process per machine to log things on the screen
215
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
216
+
217
+ init_student_model_from_teacher(
218
+ teacher_checkpoint=args.teacher_checkpoint,
219
+ encoder_layers=args.encoder_layers,
220
+ decoder_layers=args.decoder_layers,
221
+ max_source_positions=args.max_source_positions,
222
+ save_dir=args.save_dir,
223
+ push_to_hub=args.push_to_hub,
224
+ cache_dir=args.cache_dir,
225
+ subfolder=args.subfolder,
226
+ )
flax/distil_whisper/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ __version__ = "0.0.1"
17
+
18
+ from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
19
+ from .partitioner import PjitPartitioner
20
+ from .pipeline import FlaxWhisperPipeline
21
+ from .train_state import InferenceState
flax/distil_whisper/layers.py ADDED
@@ -0,0 +1,1338 @@
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dense attention classes and mask/weighting functions."""
16
+
17
+ # pylint: disable=attribute-defined-outside-init,g-bare-generic
18
+
19
+ import dataclasses
20
+ import functools
21
+ import operator
22
+ from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import numpy as np
27
+ from flax import linen as nn
28
+ from flax.linen import partitioning as nn_partitioning
29
+ from flax.linen.dtypes import promote_dtype
30
+ from jax import lax, random
31
+
32
+
33
+ # from flax.linen.partitioning import param_with_axes, with_sharding_constraint
34
+ param_with_axes = nn_partitioning.param_with_axes
35
+ with_sharding_constraint = nn_partitioning.with_sharding_constraint
36
+
37
+
38
+ # Type annotations
39
+ Array = jnp.ndarray
40
+ DType = jnp.dtype
41
+ PRNGKey = jnp.ndarray
42
+ Shape = Iterable[int]
43
+ Activation = Callable[..., Array]
44
+ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]]
45
+ DotGeneralT = Callable[..., Array]
46
+ ConvGeneralDilatedT = Callable[..., Array]
47
+ PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
48
+ LaxPadding = Union[str, Sequence[Tuple[int, int]]]
49
+
50
+ # Parameter initializers.
51
+ Initializer = Callable[[PRNGKey, Shape, DType], Array]
52
+ InitializerAxis = Union[int, Tuple[int, ...]]
53
+ NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]
54
+
55
+ default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
56
+
57
+
58
+ # ------------------------------------------------------------------------------
59
+ # Temporary inlined JAX N-d initializer code
60
+ # TODO(levskaya): remove once new JAX release is out.
61
+ # ------------------------------------------------------------------------------
62
+ def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
63
+ """Inlined JAX `nn.initializer._compute_fans`."""
64
+ if isinstance(in_axis, int):
65
+ in_size = shape[in_axis]
66
+ else:
67
+ in_size = int(np.prod([shape[i] for i in in_axis]))
68
+ if isinstance(out_axis, int):
69
+ out_size = shape[out_axis]
70
+ else:
71
+ out_size = int(np.prod([shape[i] for i in out_axis]))
72
+ receptive_field_size = shape.total / in_size / out_size
73
+ fan_in = in_size * receptive_field_size
74
+ fan_out = out_size * receptive_field_size
75
+ return fan_in, fan_out
76
+
77
+
78
+ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
79
+ """Inlined JAX `nn.initializer.variance_scaling`."""
80
+
81
+ def init(key, shape, dtype=dtype):
82
+ return jnp.zeros(shape, dtype=dtype)
83
+ dtype = jax.dtypes.canonicalize_dtype(dtype)
84
+ shape = jax.core.as_named_shape(shape)
85
+ fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
86
+ if mode == "fan_in":
87
+ denominator = fan_in
88
+ elif mode == "fan_out":
89
+ denominator = fan_out
90
+ elif mode == "fan_avg":
91
+ denominator = (fan_in + fan_out) / 2
92
+ else:
93
+ raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
94
+ variance = jnp.array(scale / denominator, dtype=dtype)
95
+
96
+ if distribution == "truncated_normal":
97
+ # constant is stddev of standard normal truncated to (-2, 2)
98
+ stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
99
+ return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
100
+ elif distribution == "normal":
101
+ return random.normal(key, shape, dtype) * jnp.sqrt(variance)
102
+ elif distribution == "uniform":
103
+ return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
104
+ else:
105
+ raise ValueError("invalid distribution for variance scaling initializer: {}".format(distribution))
106
+
107
+ return init
108
+
109
+
110
+ # ------------------------------------------------------------------------------
111
+
112
+
113
+ def nd_dense_init(scale, mode, distribution):
114
+ """Initializer with in_axis, out_axis set at call time."""
115
+
116
+ def init_fn(key, shape, dtype, in_axis, out_axis):
117
+ fn = variance_scaling(scale, mode, distribution, in_axis, out_axis)
118
+ return fn(key, shape, dtype)
119
+
120
+ return init_fn
121
+
122
+
123
+ def dot_product_attention(
124
+ query: Array,
125
+ key: Array,
126
+ value: Array,
127
+ bias: Optional[Array] = None,
128
+ dropout_rng: Optional[PRNGKey] = None,
129
+ dropout_rate: float = 0.0,
130
+ deterministic: bool = False,
131
+ dtype: DType = jnp.float32,
132
+ float32_logits: bool = False,
133
+ ):
134
+ """Computes dot-product attention given query, key, and value.
135
+
136
+ This is the core function for applying attention based on
137
+ https://arxiv.org/abs/1706.03762. It calculates the attention weights given
138
+ query and key and combines the values using the attention weights.
139
+
140
+ Args:
141
+ query: queries for calculating attention with shape of `[batch, q_length,
142
+ num_heads, qk_depth_per_head]`.
143
+ key: keys for calculating attention with shape of `[batch, kv_length,
144
+ num_heads, qk_depth_per_head]`.
145
+ value: values to be used in attention with shape of `[batch, kv_length,
146
+ num_heads, v_depth_per_head]`.
147
+ bias: bias for the attention weights. This should be broadcastable to the
148
+ shape `[batch, num_heads, q_length, kv_length]` This can be used for
149
+ incorporating causal masks, padding masks, proximity bias, etc.
150
+ dropout_rng: JAX PRNGKey: to be used for dropout
151
+ dropout_rate: dropout rate
152
+ deterministic: bool, deterministic or not (to apply dropout)
153
+ dtype: the dtype of the computation (default: float32)
154
+ float32_logits: bool, if True then compute logits in float32 to avoid
155
+ numerical issues with bfloat16.
156
+
157
+ Returns:
158
+ Output of shape `[batch, length, num_heads, v_depth_per_head]`.
159
+ """
160
+ assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
161
+ assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
162
+ assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
163
+ assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
164
+ assert query.shape[-1] == key.shape[-1], "q, k depths must match."
165
+
166
+ # Casting logits and softmax computation for float32 for model stability.
167
+ if float32_logits:
168
+ query = query.astype(jnp.float32)
169
+ key = key.astype(jnp.float32)
170
+
171
+ # `attn_weights`: [batch, num_heads, q_length, kv_length]
172
+ attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)
173
+
174
+ # Apply attention bias: masking, dropout, proximity bias, etc.
175
+ if bias is not None:
176
+ attn_weights = attn_weights + bias.astype(attn_weights.dtype)
177
+
178
+ # Normalize the attention weights across `kv_length` dimension.
179
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
180
+
181
+ # Apply attention dropout.
182
+ if not deterministic and dropout_rate > 0.0:
183
+ keep_prob = 1.0 - dropout_rate
184
+ # T5 broadcasts along the "length" dim, but unclear which one that
185
+ # corresponds to in positional dimensions here, assuming query dim.
186
+ dropout_shape = list(attn_weights.shape)
187
+ dropout_shape[-2] = 1
188
+ keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
189
+ keep = jnp.broadcast_to(keep, attn_weights.shape)
190
+ multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
191
+ attn_weights = attn_weights * multiplier
192
+
193
+ # Take the linear combination of `value`.
194
+ return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
195
+
196
+
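A minimal shape-check sketch for the attention primitive above, assuming toy dimensions (batch=2, q_length=3, kv_length=5, num_heads=4, depths 8 and 16); the output keeps the query length and the value depth:

import jax
import jax.numpy as jnp

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 3, 4, 8))    # [batch, q_length, num_heads, qk_depth_per_head]
k = jax.random.normal(kk, (2, 5, 4, 8))    # [batch, kv_length, num_heads, qk_depth_per_head]
v = jax.random.normal(kv, (2, 5, 4, 16))   # [batch, kv_length, num_heads, v_depth_per_head]
out = dot_product_attention(q, k, v, deterministic=True)   # function defined above
assert out.shape == (2, 3, 4, 16)          # [batch, q_length, num_heads, v_depth_per_head]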
197
+ dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
198
+
199
+
200
+ class MultiHeadDotProductAttention(nn.Module):
201
+ """Multi-head dot-product attention.
202
+
203
+ Attributes:
204
+ num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
205
+ should be divisible by the number of heads.
206
+ head_dim: dimension of each head.
207
+ dtype: the dtype of the computation.
208
+ dropout_rate: dropout rate
209
+ kernel_init: initializer for the kernel of the Dense layers.
210
+ float32_logits: bool, if True then compute logits in float32 to avoid
211
+ numerical issues with bfloat16.
212
+ """
213
+
214
+ num_heads: int
215
+ head_dim: int
216
+ dtype: DType = jnp.float32
217
+ dropout_rate: float = 0.0
218
+ kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
219
+ float32_logits: bool = False # computes logits in float32 for stability.
220
+
221
+ @nn.compact
222
+ def __call__(
223
+ self,
224
+ inputs_q: Array,
225
+ inputs_kv: Array,
226
+ mask: Optional[Array] = None,
227
+ bias: Optional[Array] = None,
228
+ *,
229
+ decode: bool = False,
230
+ deterministic: bool = False,
231
+ ) -> Array:
232
+ """Applies multi-head dot product attention on the input data.
233
+
234
+ Projects the inputs into multi-headed query, key, and value vectors,
235
+ applies dot-product attention and projects the results to an output vector.
236
+
237
+ There are two modes: decoding and non-decoding (e.g., training). The mode is
238
+ determined by `decode` argument. For decoding, this method is called twice,
239
+ first to initialize the cache and then for an actual decoding process. The
240
+ two calls are differentiated by the presence of 'cached_key' in the variable
241
+ dict. In the cache initialization stage, the cache variables are initialized
242
+ as zeros and will be filled in the subsequent decoding process.
243
+
244
+ In the cache initialization call, `inputs_q` has a shape [batch, length,
245
+ q_features] and `inputs_kv`: [batch, length, kv_features]. During the
246
+ incremental decoding stage, query, key and value all have the shape [batch,
247
+ 1, qkv_features] corresponding to a single step.
248
+
249
+ Args:
250
+ inputs_q: input queries of shape `[batch, q_length, q_features]`.
251
+ inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
252
+ mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
253
+ bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
254
+ decode: Whether to prepare and use an autoregressive cache.
255
+ deterministic: Disables dropout if set to True.
256
+
257
+ Returns:
258
+ output of shape `[batch, length, q_features]`.
259
+ """
260
+ projection = functools.partial(
261
+ DenseGeneral,
262
+ axis=-1,
263
+ features=(self.num_heads, self.head_dim),
264
+ kernel_axes=("embed", "heads", "kv"),
265
+ dtype=self.dtype,
266
+ )
267
+
268
+ # NOTE: T5 does not explicitly rescale the attention logits by
269
+ # 1/sqrt(depth_kq)! This is folded into the initializers of the
270
+ # linear transformations, which is equivalent under Adafactor.
271
+ depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
272
+
273
+ def query_init(*args):
274
+ return self.kernel_init(*args) / depth_scaling
275
+
276
+ # Project inputs_q to multi-headed q/k/v
277
+ # dimensions are then [batch, length, num_heads, head_dim]
278
+ query = projection(kernel_init=query_init, name="query")(inputs_q)
279
+ key = projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
280
+ value = projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
281
+
282
+ query = with_sharding_constraint(query, ("batch", "length", "heads", "kv"))
283
+ key = with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
284
+ value = with_sharding_constraint(value, ("batch", "length", "heads", "kv"))
285
+
286
+ if decode:
287
+ # Detect if we're initializing by absence of existing cache data.
288
+ is_initialized = self.has_variable("cache", "cached_key")
289
+
290
+ # The key and value have dimension [batch, length, num_heads, head_dim],
291
+ # but we cache them as [batch, num_heads, head_dim, length] as a TPU
292
+ # fusion optimization. This also enables the "scatter via one-hot
293
+ # broadcast" trick, which means we do a one-hot broadcast instead of a
294
+ # scatter/gather operations, resulting in a 3-4x speedup in practice.
295
+ def swap_dims(x):
296
+ return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
297
+
298
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
299
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
300
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
301
+ if is_initialized:
302
+ batch, num_heads, head_dim, length = cached_key.value.shape
303
+ # During fast autoregressive decoding, we feed one position at a time,
304
+ # and cache the keys and values step by step.
305
+ # Sanity shape check of cached key against input query.
306
+ expected_shape = (batch, 1, num_heads, head_dim)
307
+ if expected_shape != query.shape:
308
+ raise ValueError(
309
+ "Autoregressive cache shape error, "
310
+ "expected query shape %s instead got %s." % (expected_shape, query.shape)
311
+ )
312
+
313
+ # Create a OHE of the current index. NOTE: the index is increased below.
314
+ cur_index = cache_index.value
315
+ one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
316
+ # In order to update the key, value caches with the current key and
317
+ # value, we move the length axis to the back, similar to what we did for
318
+ # the cached ones above.
319
+ # Note these are currently the key and value of a single position, since
320
+ # we feed one position at a time.
321
+ one_token_key = jnp.moveaxis(key, -3, -1)
322
+ one_token_value = jnp.moveaxis(value, -3, -1)
323
+ # Update key, value caches with our new 1d spatial slices.
324
+ # We implement an efficient scatter into the cache via one-hot
325
+ # broadcast and addition.
326
+ key = cached_key.value + one_token_key * one_hot_indices
327
+ value = cached_value.value + one_token_value * one_hot_indices
328
+ cached_key.value = key
329
+ cached_value.value = value
330
+ cache_index.value = cache_index.value + 1
331
+ # Move the keys and values back to their original shapes.
332
+ key = jnp.moveaxis(key, -1, -3)
333
+ value = jnp.moveaxis(value, -1, -3)
334
+
335
+ # Causal mask for cached decoder self-attention: our single query
336
+ # position should only attend to those key positions that have already
337
+ # been generated and cached, not the remaining zero elements.
338
+ mask = combine_masks(
339
+ mask,
340
+ jnp.broadcast_to(
341
+ jnp.arange(length) <= cur_index,
342
+ # (1, 1, length) represent (head dim, query length, key length)
343
+ # query length is 1 because during decoding we deal with one
344
+ # index.
345
+ # The same mask is applied to all batch elements and heads.
346
+ (batch, 1, 1, length),
347
+ ),
348
+ )
349
+
350
+ # Grab the correct relative attention bias during decoding. This is
351
+ # only required during single step decoding.
352
+ if bias is not None:
353
+ # The bias is a full attention matrix, but during decoding we only
354
+ # have to take a slice of it.
355
+ # This is equivalent to bias[..., cur_index:cur_index+1, :].
356
+ bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)
357
+
358
+ # Convert the boolean attention mask to an attention bias.
359
+ if mask is not None:
360
+ # attention mask in the form of attention bias
361
+ attention_bias = lax.select(
362
+ mask > 0,
363
+ jnp.full(mask.shape, 0.0).astype(self.dtype),
364
+ jnp.full(mask.shape, -1e10).astype(self.dtype),
365
+ )
366
+ else:
367
+ attention_bias = None
368
+
369
+ # Add provided bias term (e.g. relative position embedding).
370
+ if bias is not None:
371
+ attention_bias = combine_biases(attention_bias, bias)
372
+
373
+ dropout_rng = None
374
+ if not deterministic and self.dropout_rate > 0.0:
375
+ dropout_rng = self.make_rng("dropout")
376
+
377
+ # Apply attention.
378
+ x = dot_product_attention(
379
+ query,
380
+ key,
381
+ value,
382
+ bias=attention_bias,
383
+ dropout_rng=dropout_rng,
384
+ dropout_rate=self.dropout_rate,
385
+ deterministic=deterministic,
386
+ dtype=self.dtype,
387
+ float32_logits=self.float32_logits,
388
+ )
389
+
390
+ # Back to the original inputs dimensions.
391
+ out = DenseGeneral(
392
+ features=inputs_q.shape[-1], # output dim is set to the input dim.
393
+ axis=(-2, -1),
394
+ kernel_init=self.kernel_init,
395
+ kernel_axes=("heads", "kv", "embed"),
396
+ dtype=self.dtype,
397
+ name="out",
398
+ )(x)
399
+ return out
400
+
401
+
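The decode path above avoids a scatter when writing the current key/value into the cache: it builds a one-hot vector over the cache length and adds the new column into the zero-initialised slot. A standalone sketch of the same trick, assuming a cache laid out as [batch, heads, head_dim, length]:

import jax
import jax.numpy as jnp

cache = jnp.zeros((1, 2, 4, 6))                            # [batch, heads, head_dim, length]
new_kv = jnp.ones((1, 2, 4, 1))                            # key/value for the single current position
cur_index = jnp.array(2)                                   # decoding step being written
one_hot = jax.nn.one_hot(cur_index, 6, dtype=cache.dtype)  # [length]
cache = cache + new_kv * one_hot                           # broadcasts to update column 2 only
assert bool((cache[..., 2] == 1.0).all()) and bool((cache[..., 3] == 0.0).all())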
402
+ def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
403
+ # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
404
+ return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
405
+
406
+
407
+ def _canonicalize_tuple(x):
408
+ if isinstance(x, Iterable):
409
+ return tuple(x)
410
+ else:
411
+ return (x,)
412
+
413
+
414
+ # ------------------------------------------------------------------------------
415
+ # DenseGeneral for attention layers.
416
+ # ------------------------------------------------------------------------------
417
+ class DenseGeneral(nn.Module):
418
+ """A linear transformation (without bias) with flexible axes.
419
+
420
+ Attributes:
421
+ features: tuple with numbers of output features.
422
+ axis: tuple with axes to apply the transformation on.
423
+ dtype: the dtype of the computation (default: float32).
424
+ kernel_init: initializer function for the weight matrix.
425
+ """
426
+
427
+ features: Union[Iterable[int], int]
428
+ axis: Union[Iterable[int], int] = -1
429
+ dtype: DType = jnp.float32
430
+ params_dtype: DType = jnp.float32
431
+ kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
432
+ kernel_axes: Tuple[str, ...] = ()
433
+ use_bias: bool = True
434
+ bias_init: Any = nn.initializers.zeros
435
+
436
+ @nn.compact
437
+ def __call__(self, inputs: Array) -> Array:
438
+ """Applies a linear transformation to the inputs along multiple dimensions.
439
+
440
+ Args:
441
+ inputs: The nd-array to be transformed.
442
+
443
+ Returns:
444
+ The transformed input.
445
+ """
446
+ features = _canonicalize_tuple(self.features)
447
+ axis = _canonicalize_tuple(self.axis)
448
+
449
+ inputs = jnp.asarray(inputs, self.dtype)
450
+ axis = _normalize_axes(axis, inputs.ndim)
451
+
452
+ kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
453
+ kernel_in_axis = np.arange(len(axis))
454
+ kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
455
+ kernel = param_with_axes(
456
+ "kernel",
457
+ self.kernel_init,
458
+ kernel_shape,
459
+ self.params_dtype,
460
+ kernel_in_axis,
461
+ kernel_out_axis,
462
+ axes=self.kernel_axes,
463
+ )
464
+ if self.use_bias:
465
+ bias = param_with_axes(
466
+ "bias",
467
+ self.bias_init,
468
+ features,
469
+ self.params_dtype,
470
+ axes=(self.kernel_axes[-1],),
471
+ )
472
+ kernel = jnp.asarray(kernel, self.dtype)
473
+
474
+ contract_ind = tuple(range(0, len(axis)))
475
+ y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
476
+ if self.use_bias:
477
+ bias = jnp.asarray(bias, self.dtype)
478
+ # y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
479
+ y += jnp.reshape(bias, (1,) * (len(features) - y.ndim) + bias.shape[:])
480
+ return y
481
+
482
+
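DenseGeneral contracts the selected input axes against the leading kernel axes with `lax.dot_general`. A minimal sketch of that contraction, assuming toy shapes matching the attention "out" projection above ([batch, length, heads, kv] back to an embedding dimension):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3, 4, 8))      # [batch, length, heads, kv]
kernel = jnp.ones((4, 8, 16))   # [heads, kv, embed]
# contract input axes (-2, -1) against kernel axes (0, 1)
y = lax.dot_general(x, kernel, (((2, 3), (0, 1)), ((), ())))
assert y.shape == (2, 3, 16)    # [batch, length, embed]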
483
+ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
484
+ """Convert a string to an activation function."""
485
+ if fn_or_string == "linear":
486
+ return lambda x: x
487
+ elif isinstance(fn_or_string, str):
488
+ return getattr(nn, fn_or_string)
489
+ elif callable(fn_or_string):
490
+ return fn_or_string
491
+ else:
492
+ raise ValueError("don't know how to convert %s to an activation function" % (fn_or_string,))
493
+
494
+
495
+ class MlpBlock(nn.Module):
496
+ """Transformer MLP / feed-forward block.
497
+
498
+ Attributes:
499
+ intermediate_dim: Shared dimension of hidden layers.
500
+ activations: Type of activations for each layer. Each element is either
501
+ 'linear', a string function name in flax.linen, or a function.
502
+ kernel_init: Kernel function, passed to the dense layers.
503
+ deterministic: Whether the dropout layers should be deterministic.
504
+ intermediate_dropout_rate: Dropout rate used after the intermediate layers.
505
+ dtype: Type for the dense layer.
506
+ """
507
+
508
+ intermediate_dim: int = 2048
509
+ activations: Sequence[Union[str, Callable]] = ("relu",)
510
+ kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal")
511
+ intermediate_dropout_rate: float = 0.1
512
+ dtype: Any = jnp.float32
513
+
514
+ @nn.compact
515
+ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
516
+ """Applies Transformer MlpBlock module."""
517
+ # Iterate over specified MLP input activation functions.
518
+ # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
519
+ activations = []
520
+ for idx, act_fn in enumerate(self.activations):
521
+ dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
522
+ x = DenseGeneral(
523
+ self.intermediate_dim,
524
+ dtype=self.dtype,
525
+ kernel_init=self.kernel_init,
526
+ kernel_axes=("embed", "mlp"),
527
+ name=dense_name,
528
+ )(inputs)
529
+ x = _convert_to_activation_function(act_fn)(x)
530
+ activations.append(x)
531
+
532
+ # Take elementwise product of above intermediate activations.
533
+ x = functools.reduce(operator.mul, activations)
534
+ # Apply dropout and final dense output projection.
535
+ x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
536
+ x, deterministic=deterministic
537
+ ) # Broadcast along length.
538
+ x = with_sharding_constraint(x, ("batch", "length", "mlp"))
539
+ output = DenseGeneral(
540
+ inputs.shape[-1],
541
+ dtype=self.dtype,
542
+ kernel_init=self.kernel_init,
543
+ kernel_axes=("mlp", "embed"),
544
+ name="wo",
545
+ )(x)
546
+ return output
547
+
548
+
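With activations=("gelu", "linear") the block above computes a gated GELU: one intermediate projection is passed through GELU, the other stays linear, and the two are multiplied elementwise before the output projection. A toy sketch of that gating, assuming plain matrices in place of DenseGeneral:

import jax
import jax.numpy as jnp

k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jnp.ones((2, 5, 8))                       # [batch, length, embed]
wi_0 = jax.random.normal(k0, (8, 32))         # gated branch ("gelu")
wi_1 = jax.random.normal(k1, (8, 32))         # linear branch
wo = jax.random.normal(k2, (32, 8))
hidden = jax.nn.gelu(x @ wi_0) * (x @ wi_1)   # elementwise product of the branch activations
out = hidden @ wo
assert out.shape == (2, 5, 8)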
549
+ class Embed(nn.Module):
550
+ """A parameterized function from integers [0, n) to d-dimensional vectors.
551
+
552
+ Attributes:
553
+ num_embeddings: number of embeddings.
554
+ features: number of feature dimensions for each embedding.
555
+ dtype: the dtype of the embedding vectors (default: float32).
556
+ embedding_init: embedding initializer.
557
+ one_hot: performs the gather with a one-hot contraction rather than a true
558
+ gather. This is currently needed for SPMD partitioning.
559
+ """
560
+
561
+ num_embeddings: int
562
+ features: int
563
+ cast_input_dtype: Optional[DType] = None
564
+ dtype: DType = jnp.float32
565
+ params_dtype: DType = jnp.float32
566
+ attend_dtype: Optional[DType] = None
567
+ embedding_init: Initializer = default_embed_init
568
+ one_hot: bool = True
569
+ embedding: Array = dataclasses.field(init=False)
570
+
571
+ def setup(self):
572
+ self.embedding = param_with_axes(
573
+ "embedding",
574
+ self.embedding_init,
575
+ (self.num_embeddings, self.features),
576
+ self.params_dtype,
577
+ axes=("vocab", "embed"),
578
+ )
579
+
580
+ def __call__(self, inputs: Array) -> Array:
581
+ """Embeds the inputs along the last dimension.
582
+
583
+ Args:
584
+ inputs: input data, all dimensions are considered batch dimensions.
585
+
586
+ Returns:
587
+ Output which is embedded input data. The output shape follows the input,
588
+ with an additional `features` dimension appended.
589
+ """
590
+ if self.cast_input_dtype:
591
+ inputs = inputs.astype(self.cast_input_dtype)
592
+ if not jnp.issubdtype(inputs.dtype, jnp.integer):
593
+ raise ValueError("Input type must be an integer or unsigned integer.")
594
+ if self.one_hot:
595
+ iota = lax.iota(jnp.int32, self.num_embeddings)
596
+ one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
597
+ output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
598
+ else:
599
+ output = jnp.asarray(self.embedding, self.dtype)[inputs]
600
+ output = with_sharding_constraint(output, ("batch", "length", "embed"))
601
+ return output
602
+
603
+ def attend(self, query: Array) -> Array:
604
+ """Attend over the embedding using a query array.
605
+
606
+ Args:
607
+ query: array with last dimension equal the feature depth `features` of the
608
+ embedding.
609
+
610
+ Returns:
611
+ An array with final dim `num_embeddings` corresponding to the batched
612
+ inner-product of the array of query vectors against each embedding.
613
+ Commonly used for weight-sharing between embeddings and logit transform
614
+ in NLP models.
615
+ """
616
+ dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
617
+ return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
618
+
619
+
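The one_hot=True path above replaces a gather with a one-hot contraction so the embedding lookup partitions cleanly under SPMD. A small sketch showing the two paths agree, assuming a toy embedding table:

import jax.numpy as jnp
from jax import lax

table = jnp.arange(12.0).reshape(4, 3)                     # [num_embeddings, features]
ids = jnp.array([[1, 3], [0, 2]])                          # [batch, length]
iota = lax.iota(jnp.int32, 4)
one_hot = (ids[..., None] == iota).astype(table.dtype)     # [batch, length, num_embeddings]
via_matmul = jnp.dot(one_hot, table)                       # one-hot contraction
via_gather = table[ids]                                    # direct gather
assert jnp.allclose(via_matmul, via_gather)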
620
+ class RelativePositionBiases(nn.Module):
621
+ """Adds T5-style relative positional embeddings to the attention logits.
622
+
623
+ Attributes:
624
+ num_buckets: Number of buckets to bucket distances between key and query
625
+ positions into.
626
+ max_distance: Maximum distance before everything is lumped into the last
627
+ distance bucket.
628
+ num_heads: Number of heads in the attention layer. Each head will get a
629
+ different relative position weighting.
630
+ dtype: Type of arrays through this module.
631
+ embedding_init: initializer for relative embedding table.
632
+ """
633
+
634
+ num_buckets: int
635
+ max_distance: int
636
+ num_heads: int
637
+ dtype: Any
638
+ embedding_init: Callable[..., Array] = nn.linear.default_embed_init
639
+
640
+ @staticmethod
641
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
642
+ """Translate relative position to a bucket number for relative attention.
643
+
644
+ The relative position is defined as memory_position - query_position, i.e.
645
+ the distance in tokens from the attending position to the attended-to
646
+ position. If bidirectional=False, then positive relative positions are
647
+ invalid.
648
+ We use smaller buckets for small absolute relative_position and larger
649
+ buckets for larger absolute relative_positions. All relative
650
+ positions >=max_distance map to the same bucket. All relative
651
+ positions <=-max_distance map to the same bucket. This should allow for
652
+ more graceful generalization to longer sequences than the model has been
653
+ trained on.
654
+
655
+ Args:
656
+ relative_position: an int32 array
657
+ bidirectional: a boolean - whether the attention is bidirectional
658
+ num_buckets: an integer
659
+ max_distance: an integer
660
+
661
+ Returns:
662
+ a Tensor with the same shape as relative_position, containing int32
663
+ values in the range [0, num_buckets)
664
+ """
665
+ ret = 0
666
+ n = -relative_position
667
+ if bidirectional:
668
+ num_buckets //= 2
669
+ ret += (n < 0).astype(np.int32) * num_buckets
670
+ n = np.abs(n)
671
+ else:
672
+ n = np.maximum(n, 0)
673
+ # now n is in the range [0, inf)
674
+ max_exact = num_buckets // 2
675
+ is_small = n < max_exact
676
+ val_if_large = max_exact + (
677
+ np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
678
+ / np.log(max_distance / max_exact)
679
+ * (num_buckets - max_exact)
680
+ ).astype(np.int32)
681
+ val_if_large = np.minimum(val_if_large, num_buckets - 1)
682
+ ret += np.where(is_small, n, val_if_large)
683
+ return ret
684
+
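A quick numeric check of the bucketing above, assuming the defaults (bidirectional=True, 32 buckets, max_distance=128): small offsets get exact buckets, large offsets share logarithmically spaced buckets, and positive relative positions land in the second half:

import numpy as np

rel = np.array([[-3, -1, 0, 1, 50, 500]])
buckets = RelativePositionBiases._relative_position_bucket(rel)  # staticmethod defined above
# buckets -> [[3, 1, 0, 17, 29, 31]]: exact buckets for |offset| < 8,
# log-spaced buckets above that, and 500 saturates at the last bucket.
print(buckets)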
685
+ @nn.compact
686
+ def __call__(self, qlen, klen, bidirectional=True):
687
+ """Produce relative position embedding attention biases.
688
+
689
+ Args:
690
+ qlen: attention query length.
691
+ klen: attention key length.
692
+ bidirectional: whether to allow positive memory-query relative position
693
+ embeddings.
694
+
695
+ Returns:
696
+ output: `(1, num_heads, q_len, k_len)` attention bias
697
+ """
698
+ # TODO(levskaya): should we be computing this w. numpy as a program
699
+ # constant?
700
+ context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
701
+ memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
702
+ relative_position = memory_position - context_position # shape (qlen, klen)
703
+ rp_bucket = self._relative_position_bucket(
704
+ relative_position,
705
+ bidirectional=bidirectional,
706
+ num_buckets=self.num_buckets,
707
+ max_distance=self.max_distance,
708
+ )
709
+ relative_attention_bias = param_with_axes(
710
+ "rel_embedding",
711
+ self.embedding_init,
712
+ (self.num_heads, self.num_buckets),
713
+ jnp.float32,
714
+ axes=("heads", "relpos_buckets"),
715
+ )
716
+
717
+ relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
718
+ # Instead of using a slow gather, we create a leading-dimension one-hot
719
+ # array from rp_bucket and use it to perform the gather-equivalent via a
720
+ # contraction, i.e.:
721
+ # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
722
+ # This is equivalent to relative_attention_bias[:, rp_bucket]
723
+ bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
724
+ rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
725
+ # --> shape (qlen, klen, num_heads)
726
+ values = lax.dot_general(
727
+ relative_attention_bias,
728
+ rp_bucket_one_hot,
729
+ (((1,), (0,)), ((), ())), # rhs, lhs contracting dims
730
+ ) # no batched dims
731
+ # Add a singleton batch dimension.
732
+ # --> shape (1, num_heads, qlen, klen)
733
+ return values[jnp.newaxis, ...]
734
+
735
+
736
+ # ------------------------------------------------------------------------------
737
+ # T5 Layernorm - no subtraction of mean or bias.
738
+ # ------------------------------------------------------------------------------
739
+ # class LayerNorm(nn.Module):
740
+ # """T5 Layer normalization operating on the last axis of the input data."""
741
+ # epsilon: float = 1e-6
742
+ # dtype: Any = jnp.float32
743
+ # scale_init: Initializer = nn.initializers.ones
744
+
745
+ # @nn.compact
746
+ # def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
747
+ # """Applies layer normalization on the input."""
748
+ # x = jnp.asarray(x, jnp.float32)
749
+ # features = x.shape[-1]
750
+ # mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
751
+ # y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
752
+ # scale = param_with_axes(
753
+ # 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',))
754
+
755
+ # scale = jnp.asarray(scale, self.dtype)
756
+ # return y * scale
757
+
758
+
759
+ class LayerNorm(nn.Module):
760
+ """Layer normalization (https://arxiv.org/abs/1607.06450).
761
+ Operates on the last axis of the input data.
762
+ It normalizes the activations of the layer for each given example in a
763
+ batch independently, rather than across a batch like Batch Normalization.
764
+ i.e. applies a transformation that maintains the mean activation within
765
+ each example close to 0 and the activation standard deviation close to 1.
766
+ Attributes:
767
+ epsilon: A small float added to variance to avoid dividing by zero.
768
+ dtype: the dtype of the computation (default: float32).
769
+ use_bias: If True, bias (beta) is added.
770
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
771
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
772
+ by the next layer.
773
+ bias_init: Initializer for bias, by default, zero.
774
+ scale_init: Initializer for scale, by default, one.
775
+ """
776
+
777
+ epsilon: float = 1e-6
778
+ dtype: Any = jnp.float32
779
+ params_dtype: DType = jnp.float32
780
+ use_bias: bool = True
781
+ use_scale: bool = True
782
+ bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.zeros
783
+ scale_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.ones
784
+
785
+ @nn.compact
786
+ def __call__(self, x):
787
+ """Applies layer normalization on the input.
788
+ Args:
789
+ x: the inputs
790
+ Returns:
791
+ Normalized inputs (the same shape as inputs).
792
+ """
793
+ x = jnp.asarray(x, jnp.float32)
794
+ features = x.shape[-1]
795
+ mean = jnp.mean(x, axis=-1, keepdims=True)
796
+ mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
797
+ var = mean2 - lax.square(mean)
798
+ mul = lax.rsqrt(var + self.epsilon)
799
+ if self.use_scale:
800
+ scale = param_with_axes(
801
+ "scale",
802
+ self.scale_init,
803
+ (features,),
804
+ self.params_dtype,
805
+ axes=("embed",),
806
+ )
807
+ mul = mul * jnp.asarray(scale, self.dtype)
808
+ y = (x - mean) * mul
809
+ if self.use_bias:
810
+ bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
811
+ y = y + jnp.asarray(bias, self.dtype)
812
+ return jnp.asarray(y, self.dtype)
813
+
814
+
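The normalization above is the usual y = (x - mean) / sqrt(var + eps), computed over the last axis in float32 and then scaled/shifted. A quick numeric sketch, assuming scale=1 and bias=0:

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) - jnp.square(mean)
y = (x - mean) / jnp.sqrt(var + 1e-6)
assert jnp.allclose(jnp.mean(y, axis=-1), 0.0, atol=1e-6)   # zero mean per example
assert jnp.allclose(jnp.std(y, axis=-1), 1.0, atol=1e-3)    # unit std per example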
815
+ # ------------------------------------------------------------------------------
816
+ # Mask-making utility functions.
817
+ # ------------------------------------------------------------------------------
818
+ def make_attention_mask(
819
+ query_input: Array,
820
+ key_input: Array,
821
+ pairwise_fn: Callable = jnp.multiply,
822
+ extra_batch_dims: int = 0,
823
+ dtype: DType = jnp.float32,
824
+ ) -> Array:
825
+ """Mask-making helper for attention weights.
826
+
827
+ In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the
828
+ attention weights will be `[batch, heads, len_q, len_kv]` and this
829
+ function will produce `[batch, 1, len_q, len_kv]`.
830
+
831
+ Args:
832
+ query_input: a batched, flat input of query_length size
833
+ key_input: a batched, flat input of key_length size
834
+ pairwise_fn: broadcasting elementwise comparison function
835
+ extra_batch_dims: number of extra batch dims to add singleton axes for, none
836
+ by default
837
+ dtype: mask return dtype
838
+
839
+ Returns:
840
+ A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
841
+ """
842
+ # [batch, len_q, len_kv]
843
+ mask = pairwise_fn(
844
+ # [batch, len_q] -> [batch, len_q, 1]
845
+ jnp.expand_dims(query_input, axis=-1),
846
+ # [batch, len_q] -> [batch, 1, len_kv]
847
+ jnp.expand_dims(key_input, axis=-2),
848
+ )
849
+
850
+ # [batch, 1, len_q, len_kv]. This creates the head dim.
851
+ mask = jnp.expand_dims(mask, axis=-3)
852
+ mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
853
+ return mask.astype(dtype)
854
+
855
+
856
+ def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
857
+ """Make a causal mask for self-attention.
858
+
859
+ In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights
860
+ will be `[batch, heads, len, len]` and this function will produce a
861
+ causal mask of shape `[batch, 1, len, len]`.
862
+
863
+ Note that a causal mask does not depend on the values of x; it only depends on
864
+ the shape. If x has padding elements, they will not be treated in a special
865
+ manner.
866
+
867
+ Args:
868
+ x: input array of shape `[batch, len]`
869
+ extra_batch_dims: number of batch dims to add singleton axes for, none by
870
+ default
871
+ dtype: mask return dtype
872
+
873
+ Returns:
874
+ A `[batch, 1, len, len]` shaped causal mask for 1d attention.
875
+ """
876
+ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
877
+ return make_attention_mask(idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype)
878
+
879
+
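A small sketch of the helper above, assuming a batch of one sequence of length 4: the causal mask is lower-triangular over [query, key] with a singleton head dimension:

import jax.numpy as jnp

x = jnp.zeros((1, 4))            # only the shape matters
mask = make_causal_mask(x)       # helper defined above
assert mask.shape == (1, 1, 4, 4)
assert bool((mask[0, 0] == jnp.tril(jnp.ones((4, 4)))).all())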
880
+ def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
881
+ """Combine attention masks.
882
+
883
+ Args:
884
+ *masks: set of attention mask arguments to combine, some can be None.
885
+ dtype: final mask dtype
886
+
887
+ Returns:
888
+ Combined mask, reduced by logical and, returns None if no masks given.
889
+ """
890
+ masks = [m for m in masks if m is not None]
891
+ if not masks:
892
+ return None
893
+ assert all(
894
+ (x.ndim == masks[0].ndim for x in masks)
895
+ ), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
896
+ mask, *other_masks = masks
897
+ for other_mask in other_masks:
898
+ mask = jnp.logical_and(mask, other_mask)
899
+ return mask.astype(dtype)
900
+
901
+
902
+ def combine_biases(*masks: Optional[Array]):
903
+ """Combine attention biases.
904
+
905
+ Args:
906
+ *masks: set of attention bias arguments to combine, some can be None.
907
+
908
+ Returns:
909
+ Combined mask, reduced by summation, returns None if no masks given.
910
+ """
911
+ masks = [m for m in masks if m is not None]
912
+ if not masks:
913
+ return None
914
+ assert all(
915
+ (x.ndim == masks[0].ndim for x in masks)
916
+ ), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
917
+ mask, *other_masks = masks
918
+ for other_mask in other_masks:
919
+ mask = mask + other_mask
920
+ return mask
921
+
922
+
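A tiny sketch of the two combiners above, assuming two masks of the same rank: combine_masks ANDs boolean masks, combine_biases sums additive biases, and both return None when given only None:

import jax.numpy as jnp

a = jnp.array([[[[1.0, 1.0, 0.0]]]])
b = jnp.array([[[[1.0, 0.0, 1.0]]]])
assert combine_masks(a, b).tolist() == [[[[1.0, 0.0, 0.0]]]]   # logical AND
assert combine_masks(None, None) is None
assert combine_biases(a, b).tolist() == [[[[2.0, 1.0, 1.0]]]]  # elementwise sum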
923
+ def make_decoder_mask(
924
+ decoder_target_tokens: Array,
925
+ dtype: DType,
926
+ decoder_causal_attention: Optional[Array] = None,
927
+ decoder_segment_ids: Optional[Array] = None,
928
+ ) -> Array:
929
+ """Compute the self-attention mask for a decoder.
930
+
931
+ Decoder mask is formed by combining a causal mask, a padding mask and an
932
+ optional packing mask. If decoder_causal_attention is passed, it makes the
933
+ masking non-causal for positions that have value of 1.
934
+
935
+ A prefix LM is applied to a dataset which has a notion of "inputs" and
936
+ "targets", e.g., a machine translation task. The inputs and targets are
937
+ concatenated to form a new target. `decoder_target_tokens` is the concatenated
938
+ decoder output tokens.
939
+
940
+ The "inputs" portion of the concatenated sequence can attend to other "inputs"
941
+ tokens even for those at later time steps. In order to control this
942
+ behavior, `decoder_causal_attention` is necessary. This is a binary mask with
943
+ a value of 1 indicating that the position belonged to the "inputs" portion of the
944
+ original dataset.
945
+
946
+ Example:
947
+
948
+ Suppose we have a dataset with two examples.
949
+
950
+ ds = [{"inputs": [6, 7], "targets": [8]},
951
+ {"inputs": [3, 4], "targets": [5]}]
952
+
953
+ After the data preprocessing with packing, the two examples are packed into
954
+ one example with the following three fields (some fields are skipped for
955
+ simplicity).
956
+
957
+ decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
958
+ decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
959
+ decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
960
+
961
+ where each array has [batch, length] shape with batch size being 1. Then,
962
+ this function computes the following mask.
963
+
964
+ mask = [[[[1, 1, 0, 0, 0, 0, 0],
965
+ [1, 1, 0, 0, 0, 0, 0],
966
+ [1, 1, 1, 0, 0, 0, 0],
967
+ [0, 0, 0, 1, 1, 0, 0],
968
+ [0, 0, 0, 1, 1, 0, 0],
969
+ [0, 0, 0, 1, 1, 1, 0],
970
+ [0, 0, 0, 0, 0, 0, 0]]]]
971
+
972
+ mask[b, 1, :, :] represents the mask for the example `b` in the batch.
973
+ Because mask is for a self-attention layer, the mask's shape is a square of
974
+ shape [query length, key length].
975
+
976
+ mask[b, 1, i, j] = 1 means that the query token at position i can attend to
977
+ the key token at position j.
978
+
979
+ Args:
980
+ decoder_target_tokens: decoder output tokens. [batch, length]
981
+ dtype: dtype of the output mask.
982
+ decoder_causal_attention: a binary mask indicating which position should
983
+ only attend to earlier positions in the sequence. Others will attend
984
+ bidirectionally. [batch, length]
985
+ decoder_segment_ids: decoder segmentation info for packed examples. [batch,
986
+ length]
987
+
988
+ Returns:
989
+ the combined decoder mask.
990
+ """
991
+ masks = []
992
+ # The same mask is applied to all attention heads. So the head dimension is 1,
993
+ # i.e., the mask will be broadcast along the heads dim.
994
+ # [batch, 1, length, length]
995
+ causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
996
+
997
+ # Positions with value 1 in `decoder_causal_attention` can attend
998
+ # bidirectionally.
999
+ if decoder_causal_attention is not None:
1000
+ # [batch, 1, length, length]
1001
+ inputs_mask = make_attention_mask(
1002
+ decoder_causal_attention,
1003
+ decoder_causal_attention,
1004
+ jnp.logical_and,
1005
+ dtype=dtype,
1006
+ )
1007
+ masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
1008
+ else:
1009
+ masks.append(causal_mask)
1010
+
1011
+ # Padding mask.
1012
+ masks.append(make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
1013
+
1014
+ # Packing mask
1015
+ if decoder_segment_ids is not None:
1016
+ masks.append(make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
1017
+
1018
+ return combine_masks(*masks, dtype=dtype)
1019
+
1020
+
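The worked example in the docstring above can be reproduced directly; a minimal sketch reusing the same packed values:

import jax.numpy as jnp

decoder_target_tokens = jnp.array([[6, 7, 8, 3, 4, 5, 0]])
decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]])
decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]])
mask = make_decoder_mask(
    decoder_target_tokens,
    jnp.float32,
    decoder_causal_attention=decoder_causal_attention,
    decoder_segment_ids=decoder_segment_ids,
)
assert mask.shape == (1, 1, 7, 7)
# mask[0, 0] matches the 7x7 matrix shown in the docstring above.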
1021
+ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
1022
+ """ "Canonicalizes conv padding to a jax.lax supported format."""
1023
+ if isinstance(padding, str):
1024
+ return padding
1025
+ if isinstance(padding, int):
1026
+ return [(padding, padding)] * rank
1027
+ if isinstance(padding, Sequence) and len(padding) == rank:
1028
+ new_pad = []
1029
+ for p in padding:
1030
+ if isinstance(p, int):
1031
+ new_pad.append((p, p))
1032
+ elif isinstance(p, tuple) and len(p) == 2:
1033
+ new_pad.append(p)
1034
+ else:
1035
+ break
1036
+ if len(new_pad) == rank:
1037
+ return new_pad
1038
+ raise ValueError(
1039
+ f"Invalid padding format: {padding}, should be str, int,"
1040
+ f" or a sequence of len {rank} where each element is an"
1041
+ " int or pair of ints."
1042
+ )
1043
+
1044
+
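A few toy calls to the padding helper above, assuming rank 1 (a 1D convolution):

assert canonicalize_padding("SAME", 1) == "SAME"       # strings pass through
assert canonicalize_padding(2, 1) == [(2, 2)]          # an int pads both sides equally
assert canonicalize_padding([(1, 0)], 1) == [(1, 0)]   # explicit (low, high) pairs are kept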
1045
+ def _conv_dimension_numbers(input_shape):
1046
+ """Computes the dimension numbers based on the input shape."""
1047
+ ndim = len(input_shape)
1048
+ lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
1049
+ rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
1050
+ out_spec = lhs_spec
1051
+ return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
1052
+
1053
+
1054
+ class _Conv(nn.Module):
1055
+ """Convolution Module wrapping `lax.conv_general_dilated[_local]`.
1056
+
1057
+ Attributes:
1058
+ features: number of convolution filters.
1059
+ kernel_size: shape of the convolutional kernel. For 1D convolution,
1060
+ the kernel size can be passed as an integer. For all other cases, it must
1061
+ be a sequence of integers.
1062
+ strides: an integer or a sequence of `n` integers, representing the
1063
+ inter-window strides (default: 1).
1064
+ padding: either the string `'SAME'`, the string `'VALID'`, the string
1065
+ `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
1066
+ high)` integer pairs that give the padding to apply before and after each
1067
+ spatial dimension. A single int is interpreted as applying the same padding
1068
+ in all dims and passing a single int in a sequence causes the same padding
1069
+ to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
1070
+ left-pad the convolution axis, resulting in same-sized output.
1071
+ input_dilation: an integer or a sequence of `n` integers, giving the
1072
+ dilation factor to apply in each spatial dimension of `inputs`
1073
+ (default: 1). Convolution with input dilation `d` is equivalent to
1074
+ transposed convolution with stride `d`.
1075
+ kernel_dilation: an integer or a sequence of `n` integers, giving the
1076
+ dilation factor to apply in each spatial dimension of the convolution
1077
+ kernel (default: 1). Convolution with kernel dilation
1078
+ is also known as 'atrous convolution'.
1079
+ feature_group_count: integer, default 1. If specified divides the input
1080
+ features into groups.
1081
+ use_bias: whether to add a bias to the output (default: True).
1082
+ mask: Optional mask for the weights during masked convolution. The mask must
1083
+ be the same shape as the convolution weight matrix.
1084
+ dtype: the dtype of the computation (default: infer from input and params).
1085
+ params_dtype: the dtype passed to parameter initializers (default: float32).
1086
+ precision: numerical precision of the computation see `jax.lax.Precision`
1087
+ for details.
1088
+ kernel_init: initializer for the convolutional kernel.
1089
+ bias_init: initializer for the bias.
1090
+ """
1091
+
1092
+ features: int
1093
+ kernel_size: Sequence[int]
1094
+ strides: Union[None, int, Sequence[int]] = 1
1095
+ padding: PaddingLike = "SAME"
1096
+ input_dilation: Union[None, int, Sequence[int]] = 1
1097
+ kernel_dilation: Union[None, int, Sequence[int]] = 1
1098
+ feature_group_count: int = 1
1099
+ use_bias: bool = True
1100
+ mask: Optional[Array] = None
1101
+ dtype: Optional[DType] = None
1102
+ params_dtype: DType = jnp.float32
1103
+ precision: PrecisionLike = None
1104
+ kernel_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.lecun_normal()
1105
+ bias_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.zeros
1106
+ conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
1107
+ kernel_axes: Tuple[str, ...] = ()
1108
+
1109
+ @property
1110
+ def shared_weights(self) -> bool: # type: ignore
1111
+ """Defines whether weights are shared or not between different pixels.
1112
+
1113
+ Returns:
1114
+ `True` to use shared weights in convolution (regular convolution).
1115
+ `False` to use different weights at different pixels, a.k.a.
1116
+ "locally connected layer", "unshared convolution", or "local convolution".
1117
+
1118
+ """
1119
+ ...
1120
+
1121
+ @nn.compact
1122
+ def __call__(self, inputs: Array) -> Array:
1123
+ """Applies a (potentially unshared) convolution to the inputs.
1124
+
1125
+ Args:
1126
+ inputs: input data with dimensions (*batch_dims, spatial_dims...,
1127
+ features). This is the channels-last convention, i.e. NHWC for a 2d
1128
+ convolution and NDHWC for a 3D convolution. Note: this is different from
1129
+ the input convention used by `lax.conv_general_dilated`, which puts the
1130
+ spatial dimensions last.
1131
+ Note: If the input has more than 1 batch dimension, all batch dimensions
1132
+ are flattened into a single dimension for the convolution and restored
1133
+ before returning. In some cases directly vmap'ing the layer may yield
1134
+ better performance than this default flattening approach. If the input
1135
+ lacks a batch dimension it will be added for the convolution and removed
1136
+ on return, an allowance made to enable writing single-example code.
1137
+
1138
+ Returns:
1139
+ The convolved data.
1140
+ """
1141
+
1142
+ if isinstance(self.kernel_size, int):
1143
+ raise TypeError(
1144
+ "Expected Conv kernel_size to be a"
1145
+ " tuple/list of integers (eg.: [3, 3]) but got"
1146
+ f" {self.kernel_size}."
1147
+ )
1148
+ else:
1149
+ kernel_size = tuple(self.kernel_size)
1150
+
1151
+ def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]:
1152
+ if x is None:
1153
+ # backward compatibility with using None as sentinel for
1154
+ # broadcast 1
1155
+ x = 1
1156
+ if isinstance(x, int):
1157
+ return (x,) * len(kernel_size)
1158
+ return tuple(x)
1159
+
1160
+ # Combine all input batch dimensions into a single leading batch axis.
1161
+ num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
1162
+ if num_batch_dimensions != 1:
1163
+ input_batch_shape = inputs.shape[:num_batch_dimensions]
1164
+ total_batch_size = int(np.prod(input_batch_shape))
1165
+ flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
1166
+ inputs = jnp.reshape(inputs, flat_input_shape)
1167
+
1168
+ # self.strides or (1,) * (inputs.ndim - 2)
1169
+ strides = maybe_broadcast(self.strides)
1170
+ input_dilation = maybe_broadcast(self.input_dilation)
1171
+ kernel_dilation = maybe_broadcast(self.kernel_dilation)
1172
+
1173
+ padding_lax = canonicalize_padding(self.padding, len(kernel_size))
1174
+ if padding_lax == "CIRCULAR":
1175
+ kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)]
1176
+ zero_pad: List[Tuple[int, int]] = [(0, 0)]
1177
+ pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
1178
+ inputs = jnp.pad(inputs, pads, mode="wrap")
1179
+ padding_lax = "VALID"
1180
+ elif padding_lax == "CAUSAL":
1181
+ if len(kernel_size) != 1:
1182
+ raise ValueError("Causal padding is only implemented for 1D convolutions.")
1183
+ left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
1184
+ pads = [(0, 0), (left_pad, 0), (0, 0)]
1185
+ inputs = jnp.pad(inputs, pads)
1186
+ padding_lax = "VALID"
1187
+
1188
+ dimension_numbers = _conv_dimension_numbers(inputs.shape)
1189
+ in_features = jnp.shape(inputs)[-1]
1190
+
1191
+ if self.shared_weights:
1192
+ # One shared convolutional kernel for all pixels in the output.
1193
+ assert in_features % self.feature_group_count == 0
1194
+ kernel_shape = kernel_size + (
1195
+ in_features // self.feature_group_count,
1196
+ self.features,
1197
+ )
1198
+
1199
+ else:
1200
+ if self.feature_group_count != 1:
1201
+ raise NotImplementedError(
1202
+ "`lax.conv_general_dilated_local` does not support "
1203
+ f"`feature_group_count != 1`, got `{self.feature_group_count}`."
1204
+ )
1205
+
1206
+ # Need to know the spatial output shape of a standard convolution to
1207
+ # create the unshared convolution kernel.
1208
+ conv_output_shape = jax.eval_shape(
1209
+ lambda lhs, rhs: self.conv_general_dilated( # pylint: disable=g-long-lambda
1210
+ lhs=lhs,
1211
+ rhs=rhs,
1212
+ window_strides=strides,
1213
+ padding=padding_lax,
1214
+ dimension_numbers=dimension_numbers,
1215
+ lhs_dilation=input_dilation,
1216
+ rhs_dilation=kernel_dilation,
1217
+ ),
1218
+ inputs,
1219
+ jax.ShapedArray(kernel_size + (in_features, self.features), inputs.dtype),
1220
+ ).shape
1221
+
1222
+ # One (unshared) convolutional kernel per each pixel in the output.
1223
+ kernel_shape = conv_output_shape[1:-1] + (
1224
+ np.prod(kernel_size) * in_features,
1225
+ self.features,
1226
+ )
1227
+
1228
+ if self.mask is not None and self.mask.shape != kernel_shape:
1229
+ raise ValueError(
1230
+ "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
1231
+ )
1232
+
1233
+ kernel = param_with_axes(
1234
+ "kernel",
1235
+ self.kernel_init,
1236
+ kernel_shape,
1237
+ self.params_dtype,
1238
+ axes=self.kernel_axes,
1239
+ )
1240
+
1241
+ if self.mask is not None:
1242
+ kernel *= self.mask
1243
+
1244
+ if self.use_bias:
1245
+ if self.shared_weights:
1246
+ # One bias weight per output channel, shared between pixels.
1247
+ bias_shape = (self.features,)
1248
+ else:
1249
+ # One bias weight per output entry, unshared between pixels.
1250
+ bias_shape = conv_output_shape[1:]
1251
+
1252
+ bias = param_with_axes(
1253
+ "bias",
1254
+ self.bias_init,
1255
+ bias_shape,
1256
+ self.params_dtype,
1257
+ axes=(self.kernel_axes[-1],),
1258
+ )
1259
+ else:
1260
+ bias = None
1261
+
1262
+ inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
1263
+ if self.shared_weights:
1264
+ y = self.conv_general_dilated(
1265
+ inputs,
1266
+ kernel,
1267
+ strides,
1268
+ padding_lax,
1269
+ lhs_dilation=input_dilation,
1270
+ rhs_dilation=kernel_dilation,
1271
+ dimension_numbers=dimension_numbers,
1272
+ feature_group_count=self.feature_group_count,
1273
+ precision=self.precision,
1274
+ )
1275
+ else:
1276
+ y = lax.conv_general_dilated_local(
1277
+ lhs=inputs,
1278
+ rhs=kernel,
1279
+ window_strides=strides,
1280
+ padding=padding_lax,
1281
+ filter_shape=kernel_size,
1282
+ lhs_dilation=input_dilation,
1283
+ rhs_dilation=kernel_dilation,
1284
+ dimension_numbers=dimension_numbers,
1285
+ precision=self.precision,
1286
+ )
1287
+
1288
+ if self.use_bias:
1289
+ bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
1290
+ y += bias
1291
+
1292
+ if num_batch_dimensions != 1:
1293
+ output_shape = input_batch_shape + y.shape[1:]
1294
+ y = jnp.reshape(y, output_shape)
1295
+ return y
1296
+
1297
+
1298
+ class Conv(_Conv):
1299
+ """Convolution Module wrapping `lax.conv_general_dilated`.
1300
+
1301
+ Attributes:
1302
+ features: number of convolution filters.
1303
+ kernel_size: shape of the convolutional kernel. For 1D convolution,
1304
+ the kernel size can be passed as an integer. For all other cases, it must
1305
+ be a sequence of integers.
1306
+ strides: an integer or a sequence of `n` integers, representing the
1307
+ inter-window strides (default: 1).
1308
+ padding: either the string `'SAME'`, the string `'VALID'`, the string
1309
+ `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
1310
+ high)` integer pairs that give the padding to apply before and after each
1311
+ spatial dimension. A single int is interpreted as applying the same padding
1312
+ in all dims and passing a single int in a sequence causes the same padding
1313
+ to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
1314
+ left-pad the convolution axis, resulting in same-sized output.
1315
+ input_dilation: an integer or a sequence of `n` integers, giving the
1316
+ dilation factor to apply in each spatial dimension of `inputs`
1317
+ (default: 1). Convolution with input dilation `d` is equivalent to
1318
+ transposed convolution with stride `d`.
1319
+ kernel_dilation: an integer or a sequence of `n` integers, giving the
1320
+ dilation factor to apply in each spatial dimension of the convolution
1321
+ kernel (default: 1). Convolution with kernel dilation
1322
+ is also known as 'atrous convolution'.
1323
+ feature_group_count: integer, default 1. If specified divides the input
1324
+ features into groups.
1325
+ use_bias: whether to add a bias to the output (default: True).
1326
+ mask: Optional mask for the weights during masked convolution. The mask must
1327
+ be the same shape as the convolution weight matrix.
1328
+ dtype: the dtype of the computation (default: infer from input and params).
1329
+ params_dtype: the dtype passed to parameter initializers (default: float32).
1330
+ precision: numerical precision of the computation see `jax.lax.Precision`
1331
+ for details.
1332
+ kernel_init: initializer for the convolutional kernel.
1333
+ bias_init: initializer for the bias.
1334
+ """
1335
+
1336
+ @property
1337
+ def shared_weights(self) -> bool:
1338
+ return True
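The 'CAUSAL' branch handled above left-pads the convolution axis by kernel_dilation * (kernel_size - 1), so a VALID convolution preserves the sequence length while each position only sees current and past inputs. A toy sketch of that padding arithmetic, assuming kernel_size=3 and dilation=1:

import jax.numpy as jnp

kernel_size, dilation = 3, 1
left_pad = dilation * (kernel_size - 1)                  # = 2
x = jnp.ones((1, 5, 4))                                  # [batch, length, features]
x_padded = jnp.pad(x, [(0, 0), (left_pad, 0), (0, 0)])
assert x_padded.shape == (1, 7, 4)                       # VALID conv with kernel 3 gives length 5 again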
flax/distil_whisper/modeling_flax_whisper.py ADDED
@@ -0,0 +1,2135 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax whisper model."""
16
+
17
+ import random
18
+ from functools import partial
19
+ from typing import Dict, Optional, Tuple, Union
20
+
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
25
+ from flax.linen import combine_masks, make_causal_mask
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from flax.linen.partitioning import remat, scan_with_axes
28
+ from flax.traverse_util import flatten_dict, unflatten_dict
29
+ from jax import lax
30
+ from jax.random import PRNGKey
31
+ from transformers import WhisperConfig
32
+ from transformers.generation.flax_logits_process import (
33
+ FlaxLogitsProcessor,
34
+ FlaxLogitsProcessorList,
35
+ FlaxWhisperTimeStampLogitsProcessor,
36
+ )
37
+ from transformers.modeling_flax_outputs import (
38
+ FlaxBaseModelOutput,
39
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
40
+ FlaxCausalLMOutputWithCrossAttentions,
41
+ FlaxSeq2SeqLMOutput,
42
+ FlaxSeq2SeqModelOutput,
43
+ )
44
+ from transformers.modeling_flax_utils import (
45
+ ACT2FN,
46
+ FlaxPreTrainedModel,
47
+ append_call_sample_docstring,
48
+ append_replace_return_docstrings,
49
+ overwrite_call_docstring,
50
+ )
51
+ from transformers.utils import (
52
+ add_start_docstrings,
53
+ add_start_docstrings_to_model_forward,
54
+ logging,
55
+ replace_return_docstrings,
56
+ )
57
+
58
+ from .layers import Conv, DenseGeneral, Embed, LayerNorm, with_sharding_constraint
59
+
60
+
61
+ logger = logging.get_logger(__name__)
62
+
63
+
64
+ _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
65
+ _CONFIG_FOR_DOC = "WhisperConfig"
66
+
67
+
68
+ WHISPER_START_DOCSTRING = r"""
69
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
70
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
71
+ etc.) This model is also a Flax Linen
72
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
73
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
74
+ Finally, this model supports inherent JAX features such as:
75
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
76
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
77
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
78
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
79
+
80
+ Parameters:
81
+ config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
82
+ Initializing with a config file does not load the weights associated with the model, only the
83
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
84
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
85
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
86
+ `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
87
+ inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`.
88
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
89
+ parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
90
+ and [`~FlaxPreTrainedModel.to_bf16`].
91
+ """
92
+
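A minimal sketch of the computation-dtype vs. parameter-dtype distinction described above; the checkpoint name is an assumption, and loading PyTorch-only weights may additionally require `from_pt=True`:

```python
import jax.numpy as jnp
from distil_whisper import FlaxWhisperForConditionalGeneration

model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", dtype=jnp.bfloat16  # bfloat16 *computation*
)
model.params = model.to_bf16(model.params)     # bfloat16 *parameters*
```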
93
+ WHISPER_INPUTS_DOCSTRING = r"""
94
+ Args:
95
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
96
+ Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
97
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
98
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
99
+ [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
100
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
101
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
103
+ is not used. By default, the silence in the input log mel spectrogram is ignored.
104
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
105
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
106
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
107
+ [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
108
+ the starting token for `decoder_input_ids` generation.
109
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
110
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
111
+ be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
112
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
113
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
114
+ Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
115
+ use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
116
+ spectrogram is ignored.
117
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
118
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
119
+ range `[0, config.max_position_embeddings - 1]`.
120
+ output_attentions (`bool`, *optional*):
121
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
122
+ tensors for more detail.
123
+ output_hidden_states (`bool`, *optional*):
124
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
125
+ more detail.
126
+ return_dict (`bool`, *optional*):
127
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
128
+ """
129
+
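A minimal sketch of preparing `input_features` as described above; the checkpoint name and the dummy 16 kHz waveform are assumptions:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
waveform = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="np")
print(inputs.input_features.shape)  # (1, num_mel_bins, 2 * max_source_positions)
```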
130
+ WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
131
+ Args:
132
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
133
+ Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
134
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
135
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
136
+ [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
137
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
138
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
139
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
140
+ is not used. By default, the silence in the input log mel spectrogram is ignored.
141
+ output_attentions (`bool`, *optional*):
142
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
143
+ tensors for more detail.
144
+ output_hidden_states (`bool`, *optional*):
145
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
146
+ more detail.
147
+ return_dict (`bool`, *optional*):
148
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
149
+ """
150
+
151
+ WHISPER_DECODE_INPUTS_DOCSTRING = r"""
152
+ Args:
153
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
154
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
155
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
156
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
157
+ encoder_outputs (`tuple(tuple(numpy.ndarray))`):
158
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
159
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
160
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
161
+ encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
162
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
163
+ but it is not used. By default, the silence in the input log mel spectrogram is ignored.
164
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
165
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
166
+ be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
167
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
168
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
169
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
170
+ range `[0, config.max_position_embeddings - 1]`.
171
+ past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
172
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
173
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
174
+ output_attentions (`bool`, *optional*):
175
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
176
+ tensors for more detail.
177
+ output_hidden_states (`bool`, *optional*):
178
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
179
+ more detail.
180
+ return_dict (`bool`, *optional*):
181
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
182
+ """
183
+
184
+
185
+ class FlaxStaticForceTokensLogitsProcessor(FlaxLogitsProcessor):
186
+ r"""
187
+ [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
188
+ token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
189
+ to `-inf` so that they are sampled at their corresponding index. This is a static version of the `transformers` logit
190
+ processor [`FlaxForceTokensLogitsProcessor`] that is compatible with sharded forced tokens.
191
+
192
+ Args:
193
+ force_token_map (`list`):
194
+ Map giving token ids and indices where they will be forced to be sampled.
195
+ """
196
+
197
+ def __init__(self, force_token_map):
198
+ # The generic `transformers` logit processor builds `force_token_array` as a dictionary - this is not a valid
199
+ # JAX type, and so we switch to using a JAX array instead
200
+ force_token_map = jnp.array(force_token_map)
201
+ # Converts the array of format [[index, token]] containing the tokens to be forced to an array, where the
202
+ # index of the array corresponds to the index of the token to be forced. For XLA compatibility,
203
+ # indexes without forced tokens will have a negative value. Note that the last token we ever need to force in
204
+ # Whisper is at position 3, so we only construct an array up to this index. The native version constructs a tensor
205
+ # dynamically according to the length of the `force_token_map`. Array shapes need to be concrete for XLA compatibility,
206
+ # so this is not permitted here.
207
+ force_token_array = jnp.ones(3, dtype=jnp.int32) * -1
208
+ for index, token in force_token_map:
209
+ force_token_array = force_token_array.at[index].set(token)
210
+ self.force_token_array = jnp.int32(force_token_array)
211
+
212
+ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
213
+ def _force_token(generation_idx):
214
+ batch_size = scores.shape[0]
215
+ current_token = self.force_token_array[generation_idx]
216
+
217
+ new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
218
+ updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
219
+ new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
220
+ return new_scores
221
+
222
+ scores = lax.cond(
223
+ cur_len >= self.force_token_array.shape[0],
224
+ # If the current length is greater than or equal to the length of force_token_array, the processor does nothing.
225
+ lambda: scores,
226
+ # Otherwise, it may force a certain token.
227
+ lambda: lax.cond(
228
+ self.force_token_array[cur_len] >= 0,
229
+ # Only valid (positive) tokens are forced
230
+ lambda: _force_token(cur_len),
231
+ # Otherwise, the processor does nothing.
232
+ lambda: scores,
233
+ ),
234
+ )
235
+ return scores
236
+
237
+
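A small illustration of the mapping the processor builds in `__init__`; the token ids below are placeholders, not taken from a real generation config:

```python
import jax.numpy as jnp

force_token_map = [[1, 50259], [2, 50359]]  # (generation index, token id) pairs
force_token_array = jnp.ones(3, dtype=jnp.int32) * -1
for index, token in force_token_map:
    force_token_array = force_token_array.at[index].set(token)
print(force_token_array)  # [-1, 50259, 50359]; -1 marks "no forced token"
```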
238
+ class FlaxWhisperAttention(nn.Module):
239
+ config: WhisperConfig
240
+ embed_dim: int
241
+ num_heads: int
242
+ dropout: float = 0.0
243
+ causal: bool = False
244
+ bias: bool = True
245
+ dtype: jnp.dtype = jnp.float32
246
+ params_dtype: jnp.dtype = jnp.float32
247
+
248
+ def setup(self) -> None:
249
+ self.head_dim = self.embed_dim // self.num_heads
250
+ if self.head_dim * self.num_heads != self.embed_dim:
251
+ raise ValueError(
252
+ "embed_dim must be divisible by num_heads (got `embed_dim`:"
253
+ f" {self.embed_dim} and `num_heads`: {self.num_heads})."
254
+ )
255
+
256
+ dense = partial(
257
+ DenseGeneral,
258
+ self.embed_dim,
259
+ axis=-1,
260
+ dtype=self.dtype,
261
+ params_dtype=self.params_dtype,
262
+ kernel_axes=("embed", "joined_kv"),
263
+ )
264
+
265
+ self.q_proj = dense(use_bias=self.bias)
266
+ self.k_proj = dense(use_bias=False)
267
+ self.v_proj = dense(use_bias=self.bias)
268
+
269
+ self.out_proj = DenseGeneral(
270
+ self.embed_dim,
271
+ axis=-1,
272
+ dtype=self.dtype,
273
+ params_dtype=self.params_dtype,
274
+ kernel_axes=("joined_kv", "embed"),
275
+ use_bias=self.bias,
276
+ )
277
+
278
+ if self.causal:
279
+ self.causal_mask = make_causal_mask(
280
+ jnp.ones((1, self.config.max_target_positions), dtype="bool"),
281
+ dtype="bool",
282
+ )
283
+
284
+ def __call__(
285
+ self,
286
+ hidden_states: jnp.ndarray,
287
+ key_value_states: Optional[jnp.ndarray] = None,
288
+ attention_mask: Optional[jnp.ndarray] = None,
289
+ init_cache: bool = False,
290
+ deterministic: bool = True,
291
+ ) -> Tuple[jnp.ndarray]:
292
+ is_cross_attention = key_value_states is not None
293
+ batch_size = hidden_states.shape[0]
294
+
295
+ query_states = self.q_proj(hidden_states)
296
+
297
+ if is_cross_attention:
298
+ key_states = self.k_proj(key_value_states)
299
+ value_states = self.v_proj(key_value_states)
300
+ else:
301
+ key_states = self.k_proj(hidden_states)
302
+ value_states = self.v_proj(hidden_states)
303
+
304
+ query_states = self._split_heads(query_states)
305
+ key_states = self._split_heads(key_states)
306
+ value_states = self._split_heads(value_states)
307
+
308
+ query_states = with_sharding_constraint(query_states, ("batch", "length", "heads", "kv"))
309
+ key_states = with_sharding_constraint(key_states, ("batch", "length", "heads", "kv"))
310
+ value_states = with_sharding_constraint(value_states, ("batch", "length", "heads", "kv"))
311
+
312
+ if self.causal:
313
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
314
+ if self.has_variable("cache", "cached_key"):
315
+ mask_shift = self.variables["cache"]["cache_index"]
316
+ # max_length of cached_key is last dim
317
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[-1]
318
+ causal_mask = lax.dynamic_slice(
319
+ self.causal_mask,
320
+ (0, 0, mask_shift, 0),
321
+ (1, 1, query_length, max_decoder_length),
322
+ )
323
+ else:
324
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
325
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
326
+
327
+ # combine masks if needed
328
+ if attention_mask is not None and self.causal:
329
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
330
+ attention_mask = combine_masks(attention_mask, causal_mask)
331
+ elif self.causal:
332
+ attention_mask = causal_mask
333
+ elif attention_mask is not None:
334
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
335
+
336
+ # During fast autoregressive decoding, we feed one position at a time,
337
+ # and cache the keys and values step by step.
338
+
339
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
340
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
341
+ key_states, value_states, query_states, attention_mask
342
+ )
343
+
344
+ # Convert the boolean attention mask to an attention bias.
345
+ if attention_mask is not None:
346
+ # attention mask in the form of attention bias
347
+ attention_bias = lax.select(
348
+ attention_mask > 0,
349
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
350
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
351
+ )
352
+ else:
353
+ attention_bias = None
354
+
355
+ dropout_rng = None
356
+ if not deterministic and self.dropout > 0.0:
357
+ dropout_rng = self.make_rng("dropout")
358
+
359
+ attn_weights = dot_product_attention_weights(
360
+ query_states,
361
+ key_states,
362
+ bias=attention_bias,
363
+ dropout_rng=dropout_rng,
364
+ dropout_rate=self.dropout,
365
+ broadcast_dropout=True,
366
+ deterministic=deterministic,
367
+ dtype=self.dtype,
368
+ precision=None,
369
+ )
370
+
371
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
372
+ attn_output = self._merge_heads(attn_output)
373
+ attn_output = self.out_proj(attn_output)
374
+
375
+ return attn_output, attn_weights
376
+
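A minimal sketch of the boolean-mask to additive-bias conversion performed in `__call__` above, using float32 and a toy mask:

```python
import jax.numpy as jnp
from jax import lax

mask = jnp.array([[1, 1, 0, 0]])  # 1 = attend, 0 = masked out
bias = lax.select(
    mask > 0,
    jnp.full(mask.shape, 0.0),
    jnp.full(mask.shape, jnp.finfo(jnp.float32).min),
)
print(bias)  # [[0. 0. -3.4e+38 -3.4e+38]]; added to the attention logits
```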
377
+ def _split_heads(self, hidden_state) -> jnp.ndarray:
378
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
379
+
380
+ def _merge_heads(self, hidden_state) -> jnp.ndarray:
381
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
382
+
383
+ @nn.compact
384
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
385
+ # The following code is largely copied from: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
386
+ is_initialized = self.has_variable("cache", "cached_key")
387
+
388
+ # The key and value have dimension [batch_size, seq_length, num_heads, head_dim],
389
+ # but we cache them as [batch_size, num_heads, head_dim, seq_length] as a TPU
390
+ # fusion optimization. This also enables the "scatter via one-hot
391
+ # broadcast" trick, which means we do a one-hot broadcast instead of a
392
+ # scatter/gather operations, resulting in a 3-4x speedup in practice.
393
+ def swap_dims(x):
394
+ return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
395
+
396
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
397
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
398
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
399
+
400
+ if is_initialized:
401
+ batch_size, num_heads, head_dim, seq_length = cached_key.value.shape
402
+ # During fast autoregressive decoding, we feed one position at a time,
403
+ # and cache the keys and values step by step.
404
+ # Sanity shape check of cached key against input query.
405
+ num_updated_cache_vectors = query.shape[1]
406
+ expected_shape = (batch_size, 1, num_heads, head_dim)
407
+ if num_updated_cache_vectors == 1 and expected_shape != query.shape:
408
+ raise ValueError(
409
+ "Autoregressive cache shape error, expected query shape"
410
+ f" {expected_shape} instead got {query.shape}"
411
+ )
412
+
413
+ # Create a one-hot encoding (OHE) of the current index. NOTE: the index is increased below.
414
+ cur_index = cache_index.value
415
+
416
+ # In order to update the key, value caches with the current key and
417
+ # value, we move the seq_length axis to the back, similar to what we did for
418
+ # the cached ones above.
419
+ # Note these are currently the key and value of a single position, since
420
+ # we feed one position at a time.
421
+ one_token_key = jnp.moveaxis(key, -3, -1)
422
+ one_token_value = jnp.moveaxis(value, -3, -1)
423
+
424
+ # Update key, value caches with our new 1d spatial slices.
425
+ # We implement an efficient scatter into the cache via one-hot
426
+ # broadcast and addition.
427
+ if num_updated_cache_vectors > 1:
428
+ indices = jnp.eye(num_updated_cache_vectors, seq_length)[None, None]
429
+ key = cached_key.value + jnp.matmul(one_token_key, indices)
430
+ value = cached_value.value + jnp.matmul(one_token_value, indices)
431
+ else:
432
+ one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype)
433
+ key = cached_key.value + one_token_key * one_hot_indices
434
+ value = cached_value.value + one_token_value * one_hot_indices
435
+
436
+ cached_key.value = key
437
+ cached_value.value = value
438
+ cache_index.value = cache_index.value + num_updated_cache_vectors
439
+
440
+ # Move the keys and values back to their original shapes.
441
+ key = jnp.moveaxis(key, -1, -3)
442
+ value = jnp.moveaxis(value, -1, -3)
443
+
444
+ # causal mask for cached decoder self-attention: our single query position should only
445
+ # attend to those key positions that have already been generated and cached, not the
446
+ # remaining zero elements.
447
+ pad_mask = jnp.broadcast_to(
448
+ jnp.arange(seq_length) < cur_index + num_updated_cache_vectors,
449
+ (batch_size,) + (1, num_updated_cache_vectors, seq_length),
450
+ )
451
+ attention_mask = combine_masks(pad_mask, attention_mask)
452
+
453
+ return key, value, attention_mask
454
+
455
+
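A toy illustration of the "scatter via one-hot broadcast" cache update from `_concatenate_to_cache`, reduced to a single vector; shapes and values are illustrative:

```python
import jax
import jax.numpy as jnp

seq_length, cur_index = 4, 2
cache = jnp.array([1.0, 2.0, 0.0, 0.0])   # positions 0 and 1 already filled
new_value = jnp.array(9.0)                 # value for the current position
one_hot = jax.nn.one_hot(cur_index, seq_length, dtype=cache.dtype)
cache = cache + new_value * one_hot        # equivalent to cache.at[2].set(9.0)
print(cache)  # [1. 2. 9. 0.]
```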
456
+ class FlaxWhisperEncoderLayer(nn.Module):
457
+ config: WhisperConfig
458
+ dtype: jnp.dtype = jnp.float32
459
+ params_dtype: jnp.dtype = jnp.float32
460
+ use_scan: bool = False
461
+
462
+ def setup(self) -> None:
463
+ self.embed_dim = self.config.d_model
464
+ self.self_attn = FlaxWhisperAttention(
465
+ config=self.config,
466
+ embed_dim=self.embed_dim,
467
+ num_heads=self.config.encoder_attention_heads,
468
+ dropout=self.config.attention_dropout,
469
+ dtype=self.dtype,
470
+ params_dtype=self.params_dtype,
471
+ )
472
+ self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
473
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
474
+ self.activation_fn = ACT2FN[self.config.activation_function]
475
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
476
+ self.fc1 = DenseGeneral(
477
+ self.config.encoder_ffn_dim,
478
+ dtype=self.dtype,
479
+ params_dtype=self.params_dtype,
480
+ kernel_axes=("embed", "mlp"),
481
+ )
482
+ self.fc2 = DenseGeneral(
483
+ self.embed_dim,
484
+ dtype=self.dtype,
485
+ params_dtype=self.params_dtype,
486
+ kernel_axes=("mlp", "embed"),
487
+ )
488
+ self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
489
+
490
+ def __call__(
491
+ self,
492
+ hidden_states: jnp.ndarray,
493
+ attention_mask: jnp.ndarray,
494
+ output_attentions: bool = True,
495
+ deterministic: bool = True,
496
+ all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
497
+ ) -> Tuple[jnp.ndarray]:
498
+ if self.use_scan:
499
+ hidden_states = hidden_states[0]
500
+
501
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
502
+
503
+ residual = hidden_states
504
+
505
+ layernorm_output = self.self_attn_layer_norm(hidden_states)
506
+ layernorm_output = with_sharding_constraint(layernorm_output, ("batch", "length", "embed"))
507
+
508
+ attn_output, attn_weights = self.self_attn(hidden_states=layernorm_output, attention_mask=attention_mask)
509
+ attn_output = self.dropout_layer(attn_output, deterministic=deterministic)
510
+ attn_output = residual + attn_output
511
+ attn_output = with_sharding_constraint(attn_output, ("batch", "length", "embed"))
512
+
513
+ residual = attn_output
514
+
515
+ post_layer_norm = self.final_layer_norm(attn_output)
516
+ post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
517
+
518
+ fc1_output = self.activation_fn(self.fc1(post_layer_norm))
519
+ fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
520
+ fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
521
+
522
+ hidden_states = self.fc2(fc1_output)
523
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
524
+ hidden_states = residual + hidden_states
525
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
526
+
527
+ outputs = (hidden_states,)
528
+
529
+ if output_attentions:
530
+ outputs += (attn_weights,)
531
+
532
+ if self.use_scan:
533
+ if all_hidden_states is not None:
534
+ all_hidden_states = all_hidden_states + (hidden_states,)
535
+ outputs = (
536
+ outputs,
537
+ all_hidden_states,
538
+ )
539
+
540
+ return outputs
541
+
542
+
543
+ class FlaxWhisperEncoderLayerCollection(nn.Module):
544
+ config: WhisperConfig
545
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
546
+ params_dtype: jnp.dtype = jnp.float32
547
+ use_scan: bool = False
548
+ gradient_checkpointing: bool = False
549
+
550
+ @nn.compact
551
+ def __call__(
552
+ self,
553
+ hidden_states,
554
+ attention_mask,
555
+ deterministic: bool = True,
556
+ output_attentions: bool = False,
557
+ output_hidden_states: bool = False,
558
+ return_dict: bool = True,
559
+ ):
560
+ all_attentions = () if output_attentions else None
561
+ all_hidden_states = () if output_hidden_states else None
562
+
563
+ FlaxWhisperEncoderCheckpointLayer = (
564
+ remat(
565
+ FlaxWhisperEncoderLayer,
566
+ static_argnums=(2, 3),
567
+ prevent_cse=not self.use_scan,
568
+ )
569
+ if self.gradient_checkpointing
570
+ else FlaxWhisperEncoderLayer
571
+ )
572
+
573
+ if self.use_scan:
574
+ if output_attentions:
575
+ raise ValueError("Cannot use `scan` with `output_attentions` set to True")
576
+
577
+ # nicest behaviour for scan is to let the compiler figure out the correct shapes for the hidden states
578
+ # so we'll just pass an empty tuple as the carry initializer and hold on to the first hidden states for later
579
+ input_hidden_states = hidden_states
580
+ hidden_states = (hidden_states,)
581
+
582
+ hidden_states, all_hidden_states = scan_with_axes(
583
+ FlaxWhisperEncoderCheckpointLayer,
584
+ variable_axes={"params": 0, "cache": 0},
585
+ split_rngs={"params": True, "dropout": True},
586
+ in_axes=(
587
+ nn.broadcast,
588
+ nn.broadcast,
589
+ nn.broadcast,
590
+ nn.broadcast,
591
+ ),
592
+ variable_carry="all_hidden_states",
593
+ length=self.config.encoder_layers,
594
+ )(
595
+ self.config,
596
+ dtype=self.dtype,
597
+ params_dtype=self.params_dtype,
598
+ use_scan=True,
599
+ name="FlaxEncoderScanLayers",
600
+ )(
601
+ hidden_states,
602
+ attention_mask,
603
+ output_attentions,
604
+ deterministic,
605
+ all_hidden_states,  # tuple initializer (or None if not using output_hidden_states)
606
+ )
607
+
608
+ # remove the scan dimension
609
+ hidden_states = hidden_states[0]
610
+
611
+ if output_hidden_states:
612
+ # if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
613
+ all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
614
+
615
+ else:
616
+ for layer_idx in range(self.config.encoder_layers):
617
+ if output_hidden_states:
618
+ all_hidden_states = all_hidden_states + (hidden_states,)
619
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
620
+ dropout_probability = random.uniform(0, 1)
621
+ if not deterministic and (dropout_probability < self.config.encoder_layerdrop): # skip the layer
622
+ layer_outputs = (None, None)
623
+ else:
624
+ layer_outputs = FlaxWhisperEncoderCheckpointLayer(
625
+ self.config,
626
+ dtype=self.dtype,
627
+ params_dtype=self.params_dtype,
628
+ name=str(layer_idx),
629
+ )(
630
+ hidden_states,
631
+ attention_mask,
632
+ output_attentions,
633
+ deterministic,
634
+ )
635
+ hidden_states = layer_outputs[0]
636
+ if output_attentions:
637
+ all_attentions = all_attentions + (layer_outputs[1],)
638
+
639
+ if output_hidden_states:
640
+ all_hidden_states += (hidden_states,)
641
+
642
+ outputs = (hidden_states, all_hidden_states, all_attentions)
643
+
644
+ if not return_dict:
645
+ return tuple(v for v in outputs if v is not None)
646
+
647
+ return FlaxBaseModelOutput(
648
+ last_hidden_state=hidden_states,
649
+ hidden_states=all_hidden_states,
650
+ attentions=all_attentions,
651
+ )
652
+
653
+
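The collection above either scans a single layer definition over `config.encoder_layers` or unrolls an explicit Python loop. A minimal sketch of the scan-over-layers idea, using the stock `flax.linen.scan` rather than the partitioning-aware `scan_with_axes` used here; the module and field names are made up:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, carry, _):
        carry = nn.gelu(nn.Dense(self.features)(carry))
        return carry, None

class ScannedStack(nn.Module):
    features: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        ScanBlock = nn.scan(
            Block,
            variable_axes={"params": 0},   # one stacked copy of params per layer
            split_rngs={"params": True},   # each layer gets its own init rng
            length=self.num_layers,
        )
        x, _ = ScanBlock(self.features)(x, None)
        return x

model = ScannedStack(features=8, num_layers=4)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
# The Dense kernel is stacked over layers: leading axis of size num_layers.
```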
654
+ class FlaxWhisperDecoderLayer(nn.Module):
655
+ config: WhisperConfig
656
+ dtype: jnp.dtype = jnp.float32
657
+ params_dtype: jnp.dtype = jnp.float32
658
+ use_scan: bool = False
659
+
660
+ def setup(self) -> None:
661
+ self.embed_dim = self.config.d_model
662
+ self.self_attn = FlaxWhisperAttention(
663
+ config=self.config,
664
+ embed_dim=self.embed_dim,
665
+ num_heads=self.config.decoder_attention_heads,
666
+ dropout=self.config.attention_dropout,
667
+ causal=True,
668
+ dtype=self.dtype,
669
+ params_dtype=self.params_dtype,
670
+ )
671
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
672
+ self.activation_fn = ACT2FN[self.config.activation_function]
673
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
674
+
675
+ self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
676
+ self.encoder_attn = FlaxWhisperAttention(
677
+ config=self.config,
678
+ embed_dim=self.embed_dim,
679
+ num_heads=self.config.decoder_attention_heads,
680
+ dropout=self.config.attention_dropout,
681
+ dtype=self.dtype,
682
+ params_dtype=self.params_dtype,
683
+ )
684
+ self.encoder_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
685
+ self.fc1 = DenseGeneral(
686
+ self.config.decoder_ffn_dim,
687
+ dtype=self.dtype,
688
+ params_dtype=self.params_dtype,
689
+ kernel_axes=("embed", "mlp"),
690
+ )
691
+ self.fc2 = DenseGeneral(
692
+ self.embed_dim,
693
+ dtype=self.dtype,
694
+ params_dtype=self.params_dtype,
695
+ kernel_axes=("mlp", "embed"),
696
+ )
697
+ self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
698
+
699
+ def __call__(
700
+ self,
701
+ hidden_states: jnp.ndarray,
702
+ attention_mask: jnp.ndarray,
703
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
704
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
705
+ init_cache: bool = False,
706
+ output_attentions: bool = True,
707
+ deterministic: bool = True,
708
+ all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
709
+ ) -> Tuple[jnp.ndarray]:
710
+ if self.use_scan:
711
+ hidden_states = hidden_states[0]
712
+
713
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
714
+
715
+ residual = hidden_states
716
+
717
+ layer_norm_output = self.self_attn_layer_norm(hidden_states)
718
+ layer_norm_output = with_sharding_constraint(layer_norm_output, ("batch", "length", "embed"))
719
+
720
+ # Self Attention
721
+ self_attn_output, self_attn_weights = self.self_attn(
722
+ hidden_states=layer_norm_output,
723
+ attention_mask=attention_mask,
724
+ init_cache=init_cache,
725
+ )
726
+ self_attn_output = self.dropout_layer(self_attn_output, deterministic=deterministic)
727
+ self_attn_output = residual + self_attn_output
728
+ self_attn_output = with_sharding_constraint(self_attn_output, ("batch", "length", "embed"))
729
+
730
+ # Cross-Attention Block
731
+ cross_attn_weights = None
732
+ if encoder_hidden_states is not None:
733
+ residual = self_attn_output
734
+
735
+ encoder_layer_norm_output = self.encoder_attn_layer_norm(self_attn_output)
736
+ encoder_layer_norm_output = with_sharding_constraint(
737
+ encoder_layer_norm_output, ("batch", "length", "embed")
738
+ )
739
+
740
+ cross_attn_output, cross_attn_weights = self.encoder_attn(
741
+ hidden_states=encoder_layer_norm_output,
742
+ key_value_states=encoder_hidden_states,
743
+ attention_mask=encoder_attention_mask,
744
+ )
745
+ cross_attn_output = self.dropout_layer(cross_attn_output, deterministic=deterministic)
746
+ cross_attn_output = residual + cross_attn_output
747
+ cross_attn_output = with_sharding_constraint(cross_attn_output, ("batch", "length", "embed"))
748
+
749
+ # Fully Connected
750
+ residual = cross_attn_output
751
+
752
+ post_layer_norm = self.final_layer_norm(cross_attn_output)
753
+ post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
754
+
755
+ fc1_output = self.activation_fn(self.fc1(post_layer_norm))
756
+ fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
757
+ fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
758
+
759
+ hidden_states = self.fc2(fc1_output)
760
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
761
+ hidden_states = residual + hidden_states
762
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
763
+
764
+ outputs = (hidden_states,)
765
+
766
+ if output_attentions:
767
+ outputs += (self_attn_weights, cross_attn_weights)
768
+
769
+ if self.use_scan:
770
+ if all_hidden_states is not None:
771
+ all_hidden_states = all_hidden_states + (hidden_states,)
772
+ outputs = (
773
+ outputs,
774
+ all_hidden_states,
775
+ )
776
+
777
+ return outputs
778
+
779
+
780
+ class FlaxWhisperDecoderLayerCollection(nn.Module):
781
+ config: WhisperConfig
782
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
783
+ params_dtype: jnp.dtype = jnp.float32
784
+ use_scan: bool = False
785
+ gradient_checkpointing: bool = False
786
+
787
+ @nn.compact
788
+ def __call__(
789
+ self,
790
+ hidden_states,
791
+ attention_mask,
792
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
793
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
794
+ deterministic: bool = True,
795
+ init_cache: bool = False,
796
+ output_attentions: bool = False,
797
+ output_hidden_states: bool = False,
798
+ return_dict: bool = True,
799
+ ):
800
+ # decoder layers
801
+ all_hidden_states = () if output_hidden_states else None
802
+ all_self_attns = () if output_attentions else None
803
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
804
+
805
+ FlaxWhisperDecoderCheckpointLayer = (
806
+ remat(
807
+ FlaxWhisperDecoderLayer,
808
+ static_argnums=(4, 5, 6),
809
+ prevent_cse=not self.use_scan,
810
+ )
811
+ if self.gradient_checkpointing
812
+ else FlaxWhisperDecoderLayer
813
+ )
814
+
815
+ if self.use_scan:
816
+ if output_attentions:
817
+ raise ValueError("Cannot use `scan` with `output_attentions` set to True")
818
+
819
+ input_hidden_states = hidden_states
820
+ hidden_states = (hidden_states,)
821
+
822
+ hidden_states, all_hidden_states = scan_with_axes(
823
+ FlaxWhisperDecoderCheckpointLayer,
824
+ variable_axes={"params": 0, "cache": 0},
825
+ split_rngs={"params": True, "dropout": True},
826
+ in_axes=(
827
+ nn.broadcast,
828
+ nn.broadcast,
829
+ nn.broadcast,
830
+ nn.broadcast,
831
+ nn.broadcast,
832
+ nn.broadcast,
833
+ nn.broadcast,
834
+ ),
835
+ variable_carry="all_hidden_states",
836
+ length=self.config.decoder_layers,
837
+ )(
838
+ self.config,
839
+ dtype=self.dtype,
840
+ params_dtype=self.params_dtype,
841
+ use_scan=True,
842
+ name="FlaxDecoderScanLayers",
843
+ )(
844
+ hidden_states,
845
+ attention_mask,
846
+ encoder_hidden_states,
847
+ encoder_attention_mask,
848
+ init_cache,
849
+ output_attentions,
850
+ deterministic,
851
+ all_hidden_states,
852
+ )
853
+ hidden_states = hidden_states[0]
854
+
855
+ if output_hidden_states:
856
+ # if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
857
+ all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
858
+
859
+ else:
860
+ for layer_idx in range(self.config.decoder_layers):
861
+ if output_hidden_states:
862
+ all_hidden_states += (hidden_states,)
863
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
864
+ dropout_probability = random.uniform(0, 1)
865
+ if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
866
+ layer_outputs = (None, None, None)
867
+ else:
868
+ layer_outputs = FlaxWhisperDecoderCheckpointLayer(
869
+ self.config,
870
+ dtype=self.dtype,
871
+ params_dtype=self.params_dtype,
872
+ name=str(layer_idx),
873
+ )(
874
+ hidden_states,
875
+ attention_mask,
876
+ encoder_hidden_states,
877
+ encoder_attention_mask,
878
+ init_cache,
879
+ output_attentions,
880
+ deterministic,
881
+ )
882
+
883
+ hidden_states = layer_outputs[0]
884
+ if output_attentions:
885
+ all_self_attns += (layer_outputs[1],)
886
+
887
+ if encoder_hidden_states is not None:
888
+ all_cross_attentions += (layer_outputs[2],)
889
+
890
+ # add hidden states from the last decoder layer
891
+ if output_hidden_states:
892
+ all_hidden_states += (hidden_states,)
893
+
894
+ outputs = [
895
+ hidden_states,
896
+ all_hidden_states,
897
+ all_self_attns,
898
+ all_cross_attentions,
899
+ ]
900
+
901
+ if not return_dict:
902
+ return tuple(v for v in outputs if v is not None)
903
+
904
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
905
+ last_hidden_state=hidden_states,
906
+ hidden_states=all_hidden_states,
907
+ attentions=all_self_attns,
908
+ cross_attentions=all_cross_attentions,
909
+ )
910
+
911
+
912
+ class FlaxWhisperEncoder(nn.Module):
913
+ config: WhisperConfig
914
+ dtype: jnp.dtype = jnp.float32
915
+ params_dtype: jnp.dtype = jnp.float32
916
+ use_scan: bool = False
917
+ gradient_checkpointing: bool = False
918
+
919
+ def setup(self) -> None:
920
+ self.conv1 = Conv(
921
+ self.config.d_model,
922
+ kernel_size=(3,),
923
+ padding=1,
924
+ dtype=self.dtype,
925
+ params_dtype=self.params_dtype,
926
+ kernel_axes=("channels", "num_mel", "embed"),
927
+ )
928
+ self.conv2 = Conv(
929
+ self.config.d_model,
930
+ kernel_size=(3,),
931
+ strides=2,
932
+ padding=1,
933
+ dtype=self.dtype,
934
+ params_dtype=self.params_dtype,
935
+ kernel_axes=("channels", "embed", "num_mel"),
936
+ )
937
+
938
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
939
+
940
+ self.layers = FlaxWhisperEncoderLayerCollection(
941
+ self.config,
942
+ dtype=self.dtype,
943
+ params_dtype=self.params_dtype,
944
+ use_scan=self.use_scan,
945
+ gradient_checkpointing=self.gradient_checkpointing,
946
+ )
947
+ self.embed_positions = Embed(
948
+ self.config.max_source_positions,
949
+ self.config.d_model,
950
+ dtype=self.dtype,
951
+ params_dtype=self.params_dtype,
952
+ )
953
+
954
+ self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
955
+
956
+ def __call__(
957
+ self,
958
+ input_features: jnp.ndarray,
959
+ output_attentions: bool = False,
960
+ output_hidden_states: bool = False,
961
+ return_dict: bool = True,
962
+ deterministic: bool = True,
963
+ ) -> Tuple[jnp.ndarray]:
964
+ if input_features.shape[1:] != (
965
+ self.config.num_mel_bins,
966
+ self.config.max_source_positions * 2,
967
+ ):
968
+ raise ValueError(
969
+ "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
970
+ " self.config.max_source_positions * 2) (got"
971
+ f" {input_features.shape[1:]}, but should be"
972
+ f" ({self.config.num_mel_bins},"
973
+ f" {self.config.max_source_positions * 2}))"
974
+ )
975
+
976
+ input_features = input_features.transpose(0, 2, 1)
977
+ hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
978
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
979
+ hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
980
+ hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
981
+
982
+ embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
983
+ # sinusoidal positional embeddings should not be trained
984
+ embed_positions = jax.lax.stop_gradient(embed_positions)
985
+ hidden_states = hidden_states + embed_positions
986
+
987
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
988
+
989
+ outputs = self.layers(
990
+ hidden_states,
991
+ attention_mask=None,
992
+ deterministic=deterministic,
993
+ output_attentions=output_attentions,
994
+ output_hidden_states=output_hidden_states,
995
+ return_dict=return_dict,
996
+ )
997
+
998
+ last_hidden_states = outputs[0]
999
+ last_hidden_states = self.layer_norm(last_hidden_states)
1000
+
1001
+ # update the last element in `hidden_states` after applying `layernorm` above
1002
+ hidden_states = None
1003
+ if output_hidden_states:
1004
+ hidden_states = outputs[1]
1005
+ if self.use_scan:
1006
+ hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
1007
+ else:
1008
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
1009
+
1010
+ if not return_dict:
1011
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
1012
+ return tuple(v for v in outputs if v is not None)
1013
+
1014
+ return FlaxBaseModelOutput(
1015
+ last_hidden_state=last_hidden_states,
1016
+ hidden_states=hidden_states,
1017
+ attentions=outputs.attentions,
1018
+ )
1019
+
1020
+
1021
+ class FlaxWhisperDecoder(nn.Module):
1022
+ config: WhisperConfig
1023
+ dtype: jnp.dtype = jnp.float32
1024
+ params_dtype: jnp.dtype = jnp.float32
1025
+ use_scan: bool = False
1026
+ gradient_checkpointing: bool = False
1027
+
1028
+ def setup(self) -> None:
1029
+ self.embed_tokens = Embed(
1030
+ self.config.vocab_size,
1031
+ self.config.d_model,
1032
+ dtype=self.dtype,
1033
+ params_dtype=self.params_dtype,
1034
+ )
1035
+ self.embed_positions = Embed(
1036
+ self.config.max_target_positions,
1037
+ self.config.d_model,
1038
+ dtype=self.dtype,
1039
+ params_dtype=self.params_dtype,
1040
+ )
1041
+
1042
+ self.layers = FlaxWhisperDecoderLayerCollection(
1043
+ self.config,
1044
+ dtype=self.dtype,
1045
+ params_dtype=self.params_dtype,
1046
+ use_scan=self.use_scan,
1047
+ gradient_checkpointing=self.gradient_checkpointing,
1048
+ )
1049
+
1050
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
1051
+
1052
+ self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-5, params_dtype=self.params_dtype)
1053
+
1054
+ def __call__(
1055
+ self,
1056
+ input_ids: jnp.ndarray,
1057
+ attention_mask: jnp.ndarray,
1058
+ position_ids: jnp.ndarray,
1059
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1060
+ init_cache: bool = False,
1061
+ output_attentions: bool = False,
1062
+ output_hidden_states: bool = False,
1063
+ return_dict: bool = True,
1064
+ deterministic: bool = True,
1065
+ ) -> Tuple[jnp.ndarray]:
1066
+ input_embeds = self.embed_tokens(input_ids)
1067
+ position_embeds = self.embed_positions(position_ids)
1068
+
1069
+ hidden_states = input_embeds + position_embeds
1070
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
1071
+
1072
+ outputs = self.layers(
1073
+ hidden_states,
1074
+ attention_mask=attention_mask,
1075
+ encoder_hidden_states=encoder_hidden_states,
1076
+ deterministic=deterministic,
1077
+ init_cache=init_cache,
1078
+ output_attentions=output_attentions,
1079
+ output_hidden_states=output_hidden_states,
1080
+ return_dict=return_dict,
1081
+ )
1082
+
1083
+ last_hidden_states = outputs[0]
1084
+ last_hidden_states = self.layer_norm(last_hidden_states)
1085
+
1086
+ # update the last element in `hidden_states` after applying `layernorm` above
1087
+ hidden_states = None
1088
+ if output_hidden_states:
1089
+ hidden_states = outputs[1]
1090
+ if self.use_scan:
1091
+ hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
1092
+ else:
1093
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
1094
+
1095
+ if not return_dict:
1096
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
1097
+ return tuple(v for v in outputs if v is not None)
1098
+
1099
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
1100
+ last_hidden_state=last_hidden_states,
1101
+ hidden_states=hidden_states,
1102
+ attentions=outputs.attentions,
1103
+ cross_attentions=outputs.cross_attentions,
1104
+ )
1105
+
1106
+
1107
+ class FlaxWhisperModule(nn.Module):
1108
+ config: WhisperConfig
1109
+ dtype: jnp.dtype = jnp.float32
1110
+ params_dtype: jnp.dtype = jnp.float32
1111
+ use_scan: bool = False
1112
+ gradient_checkpointing: bool = False
1113
+
1114
+ def setup(self) -> None:
1115
+ self.encoder = FlaxWhisperEncoder(
1116
+ self.config,
1117
+ dtype=self.dtype,
1118
+ params_dtype=self.params_dtype,
1119
+ use_scan=self.use_scan,
1120
+ gradient_checkpointing=self.gradient_checkpointing,
1121
+ )
1122
+ self.decoder = FlaxWhisperDecoder(
1123
+ self.config,
1124
+ dtype=self.dtype,
1125
+ params_dtype=self.params_dtype,
1126
+ use_scan=self.use_scan,
1127
+ gradient_checkpointing=self.gradient_checkpointing,
1128
+ )
1129
+
1130
+ def __call__(
1131
+ self,
1132
+ input_features: jnp.ndarray,
1133
+ decoder_input_ids: jnp.ndarray,
1134
+ decoder_attention_mask: jnp.ndarray,
1135
+ decoder_position_ids: jnp.ndarray,
1136
+ output_attentions: bool = False,
1137
+ output_hidden_states: bool = False,
1138
+ freeze_encoder: bool = False,
1139
+ return_dict: bool = True,
1140
+ deterministic: bool = True,
1141
+ ):
1142
+ encoder_outputs = self.encoder(
1143
+ input_features,
1144
+ output_attentions=output_attentions,
1145
+ output_hidden_states=output_hidden_states,
1146
+ return_dict=return_dict,
1147
+ deterministic=deterministic,
1148
+ )
1149
+
1150
+ encoder_hidden_states = encoder_outputs[0]
1151
+
1152
+ if freeze_encoder:
1153
+ encoder_hidden_states = jax.lax.stop_gradient(encoder_hidden_states)
1154
+
1155
+ decoder_outputs = self.decoder(
1156
+ input_ids=decoder_input_ids,
1157
+ attention_mask=decoder_attention_mask,
1158
+ position_ids=decoder_position_ids,
1159
+ encoder_hidden_states=encoder_hidden_states,
1160
+ output_attentions=output_attentions,
1161
+ output_hidden_states=output_hidden_states,
1162
+ return_dict=return_dict,
1163
+ deterministic=deterministic,
1164
+ )
1165
+
1166
+ if not return_dict:
1167
+ return decoder_outputs + encoder_outputs
1168
+
1169
+ return FlaxSeq2SeqModelOutput(
1170
+ last_hidden_state=decoder_outputs.last_hidden_state,
1171
+ decoder_hidden_states=decoder_outputs.hidden_states,
1172
+ decoder_attentions=decoder_outputs.attentions,
1173
+ cross_attentions=decoder_outputs.cross_attentions,
1174
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1175
+ encoder_hidden_states=encoder_outputs.hidden_states,
1176
+ encoder_attentions=encoder_outputs.attentions,
1177
+ )
1178
+
1179
+ def _get_encoder_module(self):
1180
+ return self.encoder
1181
+
1182
+ def _get_decoder_module(self):
1183
+ return self.decoder
1184
+
1185
+
1186
+ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
1187
+ config_class = WhisperConfig
1188
+ base_model_prefix: str = "model"
1189
+ main_input_name = "input_features"
1190
+ module_class: nn.Module = None
1191
+
1192
+ def __init__(
1193
+ self,
1194
+ config: WhisperConfig,
1195
+ input_shape: Tuple[int, int, int] = None,
1196
+ seed: int = 0,
1197
+ dtype: jnp.dtype = jnp.float32,
1198
+ params_dtype: jnp.dtype = jnp.float32,
1199
+ _do_init: bool = True,
1200
+ # Can only use_scan=True in init if loading scanned weights -> need to handle use_scan=True and unrolled weights
1201
+ use_scan: bool = False,
1202
+ gradient_checkpointing: bool = False,
1203
+ **kwargs,
1204
+ ):
1205
+ self.use_scan = use_scan
1206
+ self.gradient_checkpointing = gradient_checkpointing
1207
+
1208
+ module = self.module_class(
1209
+ config=config,
1210
+ dtype=dtype,
1211
+ params_dtype=params_dtype,
1212
+ use_scan=use_scan,
1213
+ gradient_checkpointing=gradient_checkpointing,
1214
+ **kwargs,
1215
+ )
1216
+
1217
+ if input_shape is None:
1218
+ input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
1219
+
1220
+ super().__init__(
1221
+ config,
1222
+ module,
1223
+ input_shape=input_shape,
1224
+ seed=seed,
1225
+ dtype=dtype,
1226
+ _do_init=_do_init,
1227
+ )
1228
+
1229
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
1230
+ # init input tensors
1231
+ input_features = jnp.zeros(input_shape, dtype="f4")
1232
+ input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
1233
+
1234
+ decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
1235
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1236
+
1237
+ batch_size, sequence_length = decoder_input_ids.shape
1238
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
1239
+
1240
+ params_rng, dropout_rng = jax.random.split(rng)
1241
+ rngs = {"params": params_rng, "dropout": dropout_rng}
1242
+
1243
+ random_params = self.module.init(
1244
+ rngs,
1245
+ input_features=input_features,
1246
+ decoder_input_ids=decoder_input_ids,
1247
+ decoder_attention_mask=decoder_attention_mask,
1248
+ decoder_position_ids=decoder_position_ids,
1249
+ )["params"]
1250
+
1251
+ if params is not None:
1252
+ random_params = flatten_dict(unfreeze(random_params))
1253
+ params = flatten_dict(unfreeze(params))
1254
+ for missing_key in self._missing_keys:
1255
+ params[missing_key] = random_params[missing_key]
1256
+ self._missing_keys = set()
1257
+ return freeze(unflatten_dict(params))
1258
+ else:
1259
+ return random_params
1260
+
1261
+ def enable_gradient_checkpointing(self):
1262
+ self.gradient_checkpointing = True
1263
+ self._module = self.module_class(
1264
+ config=self.config,
1265
+ dtype=self.dtype,
1266
+ use_scan=self.use_scan,
1267
+ gradient_checkpointing=self.gradient_checkpointing,
1268
+ )
1269
+
1270
+ def enable_scan(self):
1271
+ self.use_scan = True
1272
+ self._module = self.module_class(
1273
+ config=self.config,
1274
+ dtype=self.dtype,
1275
+ use_scan=self.use_scan,
1276
+ gradient_checkpointing=self.gradient_checkpointing,
1277
+ )
1278
+ init_fn = partial(self.init_weights, input_shape=self.input_shape)
1279
+ params_shape_tree = jax.eval_shape(init_fn, self.key)
1280
+
1281
+ # get the shape of the parameters
1282
+ self._params_shape_tree = params_shape_tree
1283
+
1284
+ # save required_params as set
1285
+ self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
1286
+
1287
+ # initialize the parameters
1288
+ if self._is_initialized:
1289
+ self.params = self.convert_unroll_to_scan(self.params)
1290
+
1291
+ def disable_scan(self):
1292
+ self.use_scan = False
1293
+ self._module = self.module_class(
1294
+ config=self.config,
1295
+ dtype=self.dtype,
1296
+ use_scan=self.use_scan,
1297
+ gradient_checkpointing=self.gradient_checkpointing,
1298
+ )
1299
+ init_fn = partial(self.init_weights, input_shape=self.input_shape)
1300
+ params_shape_tree = jax.eval_shape(init_fn, self.key)
1301
+
1302
+ # get the shape of the parameters
1303
+ self._params_shape_tree = params_shape_tree
1304
+
1305
+ # save required_params as set
1306
+ self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
1307
+
1308
+ # initialize the parameters
1309
+ if self._is_initialized:
1310
+ self.params = self.convert_scan_to_unroll(self.params)
1311
+
1312
+ def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]):
1313
+ r"""
1314
+ Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used
1315
+ to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not
1316
+ convert the `params` in place.
1317
+
1318
+ To illustrate the workings of this method, take the Flax BERT model. The unrolled structure for the query
1319
+ projection params is as follows:
1320
+ ('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
1321
+ 'q_proj') ... ('bert', 'encoder', 'layer', '23', 'self_attn', 'q_proj')
1322
+ This method takes each of the `q_proj` matrices for layers (0, ..., 23) and stacks them into a single 'super'
1323
+ matrix, giving a *single* block of weights for all 24 layers compatible with the scanned model:
1324
+ ('bert', 'encoder', 'layer', 'ScanLayers', 'self_attn', 'q_proj')
1325
+
1326
+ When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
1327
+ _do_init=False, it will have to be called explicitly (see example below).
1328
+
1329
+ Arguments:
1330
+ params (`Union[Dict, FrozenDict]`):
1331
+ A `PyTree` of model parameters.
1332
+
1333
+ Examples:
1334
+
1335
+ ```python
1336
+ >>> from distil_whisper import FlaxWhisperForConditionalGeneration
1337
+
1338
+ >>> # Download model and configuration from huggingface.co
1339
+ >>> model, params = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", _do_init=False)
1340
+ >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
1341
+ >>> # we'll first convert to scan format and then back to unrolled
1342
+ >>> model.enable_scan()
1343
+ >>> params = model.convert_unroll_to_scan(params)
1344
+ >>> # now convert back to unrolled
1345
+ >>> model.disable_scan()
1346
+ >>> params = model.convert_scan_to_unroll(params)
1347
+ ```"""
1348
+ if isinstance(params, FrozenDict):
1349
+ params = unfreeze(params)
1350
+
1351
+ params = flatten_dict(params, sep="/")
1352
+ keys = list(params.keys())
1353
+
1354
+ for k in keys:
1355
+ # Identify all "unrolled" layers formed as part of the FlaxBertLayerCollection
1356
+ # These params contain the identifier `layer` in their key
1357
+ if "layers/0" in k:
1358
+ if "decoder" in k:
1359
+ block_prefix = "Decoder"
1360
+ num_hidden_layers = self.config.decoder_layers
1361
+ else:
1362
+ block_prefix = "Encoder"
1363
+ num_hidden_layers = self.config.encoder_layers
1364
+
1365
+ # Squash the keys for the N unrolled layers into one single key:
1366
+ # (layer/0, ..., layer/N) -> layer/FlaxScanLayers
1367
+ scan_key = k.replace("0", f"Flax{block_prefix}ScanLayers")
1368
+ stacked_params = []
1369
+
1370
+ # Iterate over the unrolled layers (0, ..., N-1)
1371
+ for i in range(num_hidden_layers):
1372
+ # Stack the params for the N layers into one super block
1373
+ # and remove the unrolled layer params on the fly
1374
+ # -> no memory overhead for conversion!
1375
+ unrolled_layer = params.pop(k.replace("0", str(i)))
1376
+ stacked_params.append(unrolled_layer)
1377
+
1378
+ params[scan_key] = jnp.stack(stacked_params)
1379
+
1380
+ # Finally, unflatten the dict to restore the nested pytree structure
1381
+ params = unflatten_dict(params, sep="/")
1382
+ return params
1383
+
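A toy sketch of the stack/unstack trick that `convert_unroll_to_scan` (and its inverse below) performs, using plain Flax utilities; the parameter names and shapes here are made up for illustration:

```python
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

# Toy "unrolled" params: one q_proj kernel per decoder layer.
unrolled = {
    "decoder": {
        "layers": {
            "0": {"q_proj": {"kernel": jnp.zeros((4, 4))}},
            "1": {"q_proj": {"kernel": jnp.ones((4, 4))}},
        }
    }
}

flat = flatten_dict(unrolled, sep="/")
num_layers = 2

# Stack layers 0..N-1 into one scanned block, popping each unrolled layer as we go.
stacked = jnp.stack([flat.pop(f"decoder/layers/{i}/q_proj/kernel") for i in range(num_layers)])
flat["decoder/layers/FlaxDecoderScanLayers/q_proj/kernel"] = stacked

scanned = unflatten_dict(flat, sep="/")
print(scanned["decoder"]["layers"]["FlaxDecoderScanLayers"]["q_proj"]["kernel"].shape)
# (2, 4, 4): a leading layer axis that nn.scan iterates over
```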
1384
+ def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]):
1385
+ r"""
1386
+ Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be
1387
+ used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does
1388
+ not convert the `params` in place.
1389
+
1390
+ To illustrate the workings of this method, take the Flax BERT model. The scanned structure for the query
1391
+ projection (`q_proj`) params is a single, stacked matrix of parameters over all N layers:
1392
+ ('bert', 'encoder', 'layer', 'FlaxScanLayers', 'self_attn', 'q_proj')
1393
+
1394
+ This method slices each layer of the `q_proj` scanned matrix into single, standalone layers, and replaces the
1395
+ scanned matrix of parameters on the fly:
1396
+ ('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
1397
+ 'q_proj') ... ('bert', 'encoder', 'layer', 'N', 'self_attn', 'q_proj')
1398
+
1399
+ When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
1400
+ _do_init=False, it will have to be called explicitly (see example below).
1401
+
1402
+ Arguments:
1403
+ params (`Union[Dict, FrozenDict]`):
1404
+ A `PyTree` of model parameters.
1405
+
1406
+ Examples:
1407
+
1408
+ ```python
1409
+ >>> from distil_whisper import FlaxWhisperForConditionalGeneration
1410
+
1411
+ >>> # Download model and configuration from huggingface.co
1412
+ >>> model, params = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", _do_init=False)
1413
+ >>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
1414
+ >>> # we'll first convert to scan format and then back to unrolled
1415
+ >>> model.enable_scan()
1416
+ >>> params = model.convert_unroll_to_scan(params)
1417
+ >>> # now convert back to unrolled
1418
+ >>> model.disable_scan()
1419
+ >>> params = model.convert_scan_to_unroll(params)
1420
+ ```"""
1421
+
1422
+ if isinstance(params, FrozenDict):
1423
+ params = unfreeze(params)
1424
+
1425
+ params = flatten_dict(params, sep="/")
1426
+ keys = list(params.keys())
1427
+
1428
+ for k in keys:
1429
+ # Identify all "scan" layers formed as part of the FlaxBertLayerCollection
1430
+ # These params contain the identifier `FlaxScanLayers` in their key
1431
+ if "FlaxEncoderScanLayers" in k:
1432
+ # Remove the scan layer from the PyTree of params
1433
+ scan_layer = params.pop(k)
1434
+
1435
+ # Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
1436
+ # layer/FlaxScanLayers -> (layer/0, ..., layer/N)
1437
+ for i in range(self.config.encoder_layers):
1438
+ # Unstack the params for the i-th scan layer to unrolled
1439
+ # and remove corresponding scan params on the fly
1440
+ # -> no memory overhead for conversion!
1441
+ unrolled_key = k.replace("FlaxEncoderScanLayers", str(i))
1442
+ params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
1443
+
1444
+ elif "FlaxDecoderScanLayers" in k:
1445
+ # Remove the scan layer from the PyTree of params
1446
+ scan_layer = params.pop(k)
1447
+
1448
+ # Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
1449
+ # layer/FlaxScanLayers -> (layer/0, ..., layer/N)
1450
+ for i in range(self.config.decoder_layers):
1451
+ # Unstack the params for the i-th scan layer to unrolled
1452
+ # and remove corresponding scan params on the fly
1453
+ # -> no memory overhead for conversion!
1454
+ unrolled_key = k.replace("FlaxDecoderScanLayers", str(i))
1455
+ params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
1456
+
1457
+ params = unflatten_dict(params, sep="/")
1458
+ return params
1459
+
1460
+ # Copied from transformers.models.whisper.modeling_flax_whisper.FlaxWhisperPreTrainedModel.init_cache
1461
+ def init_cache(self, batch_size, max_length, encoder_outputs):
1462
+ r"""
1463
+ Args:
1464
+ batch_size (`int`):
1465
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
1466
+ max_length (`int`):
1467
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
1468
+ cache.
1469
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
1470
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
1471
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
1472
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
1473
+ cross-attention of the decoder.
1474
+ """
1475
+ # init input variables to retrieve cache
1476
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
1477
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1478
+ decoder_position_ids = jnp.broadcast_to(
1479
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
1480
+ decoder_input_ids.shape,
1481
+ )
1482
+
1483
+ def _decoder_forward(
1484
+ module,
1485
+ decoder_input_ids,
1486
+ decoder_attention_mask,
1487
+ decoder_position_ids,
1488
+ **kwargs,
1489
+ ):
1490
+ decoder_module = module._get_decoder_module()
1491
+ return decoder_module(
1492
+ decoder_input_ids,
1493
+ decoder_attention_mask,
1494
+ decoder_position_ids,
1495
+ **kwargs,
1496
+ )
1497
+
1498
+ init_variables = self.module.init(
1499
+ jax.random.PRNGKey(0),
1500
+ decoder_input_ids=decoder_input_ids,
1501
+ decoder_attention_mask=decoder_attention_mask,
1502
+ decoder_position_ids=decoder_position_ids,
1503
+ encoder_hidden_states=encoder_outputs[0],
1504
+ init_cache=True,
1505
+ method=_decoder_forward, # we only need to call the decoder to init the cache
1506
+ )
1507
+ return unfreeze(init_variables["cache"])
1508
+
1509
+ @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
1510
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
1511
+ def encode(
1512
+ self,
1513
+ input_features: jnp.ndarray,
1514
+ attention_mask: Optional[jnp.ndarray] = None,
1515
+ output_attentions: Optional[bool] = None,
1516
+ output_hidden_states: Optional[bool] = None,
1517
+ return_dict: Optional[bool] = None,
1518
+ train: bool = False,
1519
+ params: dict = None,
1520
+ dropout_rng: PRNGKey = None,
1521
+ **kwargs,
1522
+ ):
1523
+ r"""
1524
+ Returns:
1525
+
1526
+ Example:
1527
+
1528
+ ```python
1529
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1530
+ >>> from datasets import load_dataset
1531
+
1532
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1533
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1534
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1535
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1536
+ >>> input_features = inputs.input_features
1537
+ >>> encoder_outputs = model.encode(input_features=input_features)
1538
+ ```"""
1539
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1540
+ output_hidden_states = (
1541
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1542
+ )
1543
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1544
+
1545
+ # Handle any PRNG if needed
1546
+ rngs = {}
1547
+ if dropout_rng is not None:
1548
+ rngs["dropout"] = dropout_rng
1549
+
1550
+ def _encoder_forward(module, input_features, **kwargs):
1551
+ encode_module = module._get_encoder_module()
1552
+ return encode_module(input_features, **kwargs)
1553
+
1554
+ return self.module.apply(
1555
+ {"params": params or self.params},
1556
+ input_features=jnp.array(input_features, dtype="f4"),
1557
+ output_attentions=output_attentions,
1558
+ output_hidden_states=output_hidden_states,
1559
+ return_dict=return_dict,
1560
+ deterministic=not train,
1561
+ rngs=rngs,
1562
+ method=_encoder_forward,
1563
+ )
1564
+
1565
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1566
+ @replace_return_docstrings(
1567
+ output_type=FlaxBaseModelOutputWithPastAndCrossAttentions,
1568
+ config_class=WhisperConfig,
1569
+ )
1570
+ def decode(
1571
+ self,
1572
+ decoder_input_ids,
1573
+ encoder_outputs,
1574
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1575
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1576
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1577
+ past_key_values: dict = None,
1578
+ output_attentions: Optional[bool] = None,
1579
+ output_hidden_states: Optional[bool] = None,
1580
+ return_dict: Optional[bool] = None,
1581
+ train: bool = False,
1582
+ params: dict = None,
1583
+ dropout_rng: PRNGKey = None,
1584
+ ):
1585
+ r"""
1586
+ Returns:
1587
+
1588
+ Example:
1589
+
1590
+ ```python
1591
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1592
+ >>> from datasets import load_dataset
1593
+
1594
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1595
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1596
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1597
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1598
+ >>> input_features = inputs.input_features
1599
+ >>> encoder_outputs = model.encode(input_features=input_features)
1600
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1601
+
1602
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
1603
+
1604
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1605
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1606
+ ```"""
1607
+
1608
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1609
+ output_hidden_states = (
1610
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1611
+ )
1612
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1613
+
1614
+ encoder_hidden_states = encoder_outputs[0]
1615
+
1616
+ batch_size, sequence_length = decoder_input_ids.shape
1617
+ if decoder_position_ids is None:
1618
+ if past_key_values is not None:
1619
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1620
+
1621
+ if decoder_attention_mask is not None:
1622
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1623
+ else:
1624
+ decoder_position_ids = jnp.broadcast_to(
1625
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1626
+ )
1627
+
1628
+ if decoder_attention_mask is None:
1629
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1630
+
1631
+ # Handle any PRNG if needed
1632
+ rngs = {}
1633
+ if dropout_rng is not None:
1634
+ rngs["dropout"] = dropout_rng
1635
+
1636
+ inputs = {"params": params or self.params}
1637
+
1638
+ # if past_key_values are passed, the cache is already initialized; a private flag init_cache has to be
1639
+ # passed down to ensure the cache is used. The cache must also be marked as mutable so that
1640
+ # it can be changed by the FlaxWhisperAttention module
1641
+ if past_key_values:
1642
+ inputs["cache"] = past_key_values
1643
+ mutable = ["cache"]
1644
+ else:
1645
+ mutable = False
1646
+
1647
+ def _decoder_forward(
1648
+ module,
1649
+ decoder_input_ids,
1650
+ decoder_attention_mask,
1651
+ decoder_position_ids,
1652
+ **kwargs,
1653
+ ):
1654
+ decoder_module = module._get_decoder_module()
1655
+ return decoder_module(
1656
+ input_ids=decoder_input_ids,
1657
+ attention_mask=decoder_attention_mask,
1658
+ position_ids=decoder_position_ids,
1659
+ **kwargs,
1660
+ )
1661
+
1662
+ outputs = self.module.apply(
1663
+ inputs,
1664
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1665
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1666
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1667
+ encoder_hidden_states=encoder_hidden_states,
1668
+ output_attentions=output_attentions,
1669
+ output_hidden_states=output_hidden_states,
1670
+ return_dict=return_dict,
1671
+ deterministic=not train,
1672
+ rngs=rngs,
1673
+ mutable=mutable,
1674
+ method=_decoder_forward,
1675
+ )
1676
+
1677
+ # add updated cache to model output
1678
+ if past_key_values is not None and return_dict:
1679
+ outputs, past = outputs
1680
+ outputs["past_key_values"] = unfreeze(past["cache"])
1681
+ return outputs
1682
+ elif past_key_values is not None and not return_dict:
1683
+ outputs, past = outputs
1684
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1685
+
1686
+ return outputs
1687
+
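For orientation, a minimal greedy-decoding sketch showing how the cache returned by `init_cache` and the `decoder_position_ids` bookkeeping fit together; it assumes the tiny English checkpoint from the docstring examples, and the real `generate` path below handles all of this (plus prompting and timestamps) automatically:

```python
import jax.numpy as jnp
from datasets import load_dataset
from transformers import WhisperProcessor
from distil_whisper import FlaxWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], return_tensors="np")

encoder_outputs = model.encode(input_features=inputs.input_features)

max_length = 8
decoder_input_ids = jnp.array([[model.config.decoder_start_token_id]], dtype="i4")
past_key_values = model.init_cache(1, max_length, encoder_outputs)

for step in range(max_length - 1):
    outputs = model.decode(
        decoder_input_ids[:, -1:],                              # only the newest token
        encoder_outputs,
        decoder_position_ids=jnp.array([[step]], dtype="i4"),   # absolute position of that token
        decoder_attention_mask=jnp.ones((1, max_length), dtype="i4"),
        past_key_values=past_key_values,
    )
    past_key_values = outputs.past_key_values                   # carry the updated cache forward
    next_token = jnp.argmax(outputs.logits[:, -1], axis=-1)
    decoder_input_ids = jnp.concatenate([decoder_input_ids, next_token[:, None]], axis=-1)
```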
1688
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1689
+ def __call__(
1690
+ self,
1691
+ input_features: jnp.ndarray,
1692
+ decoder_input_ids: jnp.ndarray,
1693
+ attention_mask: Optional[jnp.ndarray] = None,
1694
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1695
+ position_ids: Optional[jnp.ndarray] = None,
1696
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1697
+ output_attentions: Optional[bool] = None,
1698
+ output_hidden_states: Optional[bool] = None,
1699
+ freeze_encoder: Optional[bool] = None,
1700
+ return_dict: Optional[bool] = None,
1701
+ train: bool = False,
1702
+ params: dict = None,
1703
+ dropout_rng: PRNGKey = None,
1704
+ ):
1705
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1706
+ output_hidden_states = (
1707
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1708
+ )
1709
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1710
+
1711
+ # prepare decoder inputs
1712
+ if decoder_position_ids is None:
1713
+ if decoder_attention_mask is not None:
1714
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1715
+ else:
1716
+ batch_size, sequence_length = decoder_input_ids.shape
1717
+ decoder_position_ids = jnp.broadcast_to(
1718
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1719
+ )
1720
+ if decoder_attention_mask is None:
1721
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1722
+
1723
+ # Handle any PRNG if needed
1724
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
1725
+
1726
+ return self.module.apply(
1727
+ {"params": params or self.params},
1728
+ input_features=jnp.array(input_features, dtype="f4"),
1729
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1730
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1731
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1732
+ output_attentions=output_attentions,
1733
+ output_hidden_states=output_hidden_states,
1734
+ freeze_encoder=freeze_encoder,
1735
+ return_dict=return_dict,
1736
+ deterministic=not train,
1737
+ rngs=rngs,
1738
+ )
1739
+
1740
+
1741
+ @add_start_docstrings(
1742
+ ("The bare Whisper Model transformer outputting raw hidden-states without any specific head on top."),
1743
+ WHISPER_START_DOCSTRING,
1744
+ )
1745
+ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
1746
+ config: WhisperConfig
1747
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1748
+ params_dtype: jnp.dtype = jnp.float32
1749
+ module_class = FlaxWhisperModule
1750
+
1751
+
1752
+ append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
1753
+
1754
+
1755
+ class FlaxWhisperForConditionalGenerationModule(nn.Module):
1756
+ config: WhisperConfig
1757
+ dtype: jnp.dtype = jnp.float32
1758
+ params_dtype: jnp.dtype = jnp.float32
1759
+ use_scan: bool = False
1760
+ gradient_checkpointing: bool = False
1761
+
1762
+ def setup(self) -> None:
1763
+ self.model = FlaxWhisperModule(
1764
+ config=self.config,
1765
+ dtype=self.dtype,
1766
+ params_dtype=self.params_dtype,
1767
+ use_scan=self.use_scan,
1768
+ gradient_checkpointing=self.gradient_checkpointing,
1769
+ )
1770
+ self.lm_head = DenseGeneral(
1771
+ self.config.vocab_size,
1772
+ use_bias=False,
1773
+ dtype=self.dtype,
1774
+ params_dtype=self.params_dtype,
1775
+ kernel_axes=("embed", "vocab"),
1776
+ )
1777
+
1778
+ def _get_encoder_module(self):
1779
+ return self.model.encoder
1780
+
1781
+ def _get_decoder_module(self):
1782
+ return self.model.decoder
1783
+
1784
+ def __call__(
1785
+ self,
1786
+ input_features,
1787
+ decoder_input_ids,
1788
+ decoder_attention_mask: jnp.ndarray = None,
1789
+ decoder_position_ids: jnp.ndarray = None,
1790
+ position_ids: jnp.ndarray = None,
1791
+ attention_mask: jnp.ndarray = None,
1792
+ output_attentions: bool = False,
1793
+ output_hidden_states: bool = False,
1794
+ freeze_encoder: bool = False,
1795
+ return_dict: bool = True,
1796
+ deterministic: bool = True,
1797
+ ):
1798
+ outputs = self.model(
1799
+ input_features=input_features,
1800
+ decoder_input_ids=decoder_input_ids,
1801
+ decoder_attention_mask=decoder_attention_mask,
1802
+ decoder_position_ids=decoder_position_ids,
1803
+ output_attentions=output_attentions,
1804
+ output_hidden_states=output_hidden_states,
1805
+ freeze_encoder=freeze_encoder,
1806
+ return_dict=return_dict,
1807
+ deterministic=deterministic,
1808
+ )
1809
+
1810
+ hidden_states = outputs[0]
1811
+
1812
+ if self.config.tie_word_embeddings:
1813
+ shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
1814
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1815
+ else:
1816
+ lm_logits = self.lm_head(hidden_states)
1817
+
1818
+ if not return_dict:
1819
+ output = (lm_logits,) + outputs[1:]
1820
+ return output
1821
+
1822
+ return FlaxSeq2SeqLMOutput(
1823
+ logits=lm_logits,
1824
+ decoder_hidden_states=outputs.decoder_hidden_states,
1825
+ decoder_attentions=outputs.decoder_attentions,
1826
+ cross_attentions=outputs.cross_attentions,
1827
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1828
+ encoder_hidden_states=outputs.encoder_hidden_states,
1829
+ encoder_attentions=outputs.encoder_attentions,
1830
+ )
1831
+
1832
+
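The `tie_word_embeddings` branch above re-uses the decoder's token embedding as the output projection by passing the transposed embedding in as the head's kernel. A minimal sketch of that mechanism with a plain `flax.linen.Dense` standing in for the `DenseGeneral` head (toy sizes, illustrative only):

```python
import jax.numpy as jnp
import flax.linen as nn

vocab_size, hidden_size = 11, 4
embedding = jnp.full((vocab_size, hidden_size), 0.1)     # stand-in for decoder.embed_tokens
hidden_states = jnp.ones((1, 3, hidden_size))            # (batch, seq, hidden)

lm_head = nn.Dense(features=vocab_size, use_bias=False)

# Tie weights: the (hidden, vocab) kernel is just the transposed embedding matrix,
# so no separate output-projection parameters are stored.
logits = lm_head.apply({"params": {"kernel": embedding.T}}, hidden_states)
print(logits.shape)  # (1, 3, 11)
```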
1833
+ @add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
1834
+ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
1835
+ module_class = FlaxWhisperForConditionalGenerationModule
1836
+
1837
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1838
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
1839
+ def decode(
1840
+ self,
1841
+ decoder_input_ids,
1842
+ encoder_outputs,
1843
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1844
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1845
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1846
+ past_key_values: dict = None,
1847
+ output_attentions: Optional[bool] = None,
1848
+ output_hidden_states: Optional[bool] = None,
1849
+ return_dict: Optional[bool] = None,
1850
+ train: bool = False,
1851
+ params: dict = None,
1852
+ dropout_rng: PRNGKey = None,
1853
+ ):
1854
+ r"""
1855
+ Returns:
1856
+
1857
+ Example:
1858
+
1859
+ ```python
1860
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1861
+ >>> from datasets import load_dataset
1862
+
1863
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1864
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1865
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1866
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1867
+ >>> input_features = inputs.input_features
1868
+ >>> encoder_outputs = model.encode(input_features=input_features)
1869
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1870
+
1871
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
1872
+
1873
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1874
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1875
+ ```"""
1876
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1877
+ output_hidden_states = (
1878
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1879
+ )
1880
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1881
+
1882
+ encoder_hidden_states = encoder_outputs[0]
1883
+
1884
+ batch_size, sequence_length = decoder_input_ids.shape
1885
+ if decoder_position_ids is None:
1886
+ if past_key_values is not None:
1887
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1888
+
1889
+ if decoder_attention_mask is not None:
1890
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1891
+ else:
1892
+ decoder_position_ids = jnp.broadcast_to(
1893
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1894
+ )
1895
+ if decoder_attention_mask is None:
1896
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")
1897
+
1898
+ # Handle any PRNG if needed
1899
+ rngs = {}
1900
+ if dropout_rng is not None:
1901
+ rngs["dropout"] = dropout_rng
1902
+
1903
+ inputs = {"params": params or self.params}
1904
+
1905
+ # if past_key_values are passed, the cache is already initialized; a private flag init_cache has to be
1906
+ # passed down to ensure the cache is used. The cache must also be marked as mutable so that
1907
+ # it can be changed by the FlaxWhisperAttention module
1908
+ if past_key_values:
1909
+ inputs["cache"] = past_key_values
1910
+ mutable = ["cache"]
1911
+ else:
1912
+ mutable = False
1913
+
1914
+ def _decoder_forward(
1915
+ module,
1916
+ decoder_input_ids,
1917
+ decoder_attention_mask,
1918
+ decoder_position_ids,
1919
+ **kwargs,
1920
+ ):
1921
+ decoder_module = module._get_decoder_module()
1922
+ outputs = decoder_module(
1923
+ input_ids=decoder_input_ids,
1924
+ attention_mask=decoder_attention_mask,
1925
+ position_ids=decoder_position_ids,
1926
+ **kwargs,
1927
+ )
1928
+ hidden_states = outputs[0]
1929
+
1930
+ if self.config.tie_word_embeddings:
1931
+ shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
1932
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1933
+ else:
1934
+ lm_logits = module.lm_head(hidden_states)
1935
+
1936
+ return lm_logits, outputs
1937
+
1938
+ outputs = self.module.apply(
1939
+ inputs,
1940
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1941
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1942
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1943
+ encoder_hidden_states=encoder_hidden_states,
1944
+ output_attentions=output_attentions,
1945
+ output_hidden_states=output_hidden_states,
1946
+ return_dict=return_dict,
1947
+ deterministic=not train,
1948
+ rngs=rngs,
1949
+ mutable=mutable,
1950
+ method=_decoder_forward,
1951
+ )
1952
+
1953
+ if past_key_values is None:
1954
+ lm_logits, decoder_outputs = outputs
1955
+ else:
1956
+ (lm_logits, decoder_outputs), past = outputs
1957
+
1958
+ if return_dict:
1959
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
1960
+ logits=lm_logits,
1961
+ hidden_states=decoder_outputs.hidden_states,
1962
+ attentions=decoder_outputs.attentions,
1963
+ cross_attentions=decoder_outputs.cross_attentions,
1964
+ )
1965
+ else:
1966
+ outputs = (lm_logits,) + decoder_outputs[1:]
1967
+
1968
+ # add updated cache to model output
1969
+ if past_key_values is not None and return_dict:
1970
+ outputs["past_key_values"] = unfreeze(past["cache"])
1971
+ return outputs
1972
+ elif past_key_values is not None and not return_dict:
1973
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1974
+
1975
+ return outputs
1976
+
1977
+ def generate(
1978
+ self,
1979
+ input_features,
1980
+ generation_config=None,
1981
+ logits_processor=None,
1982
+ return_timestamps=None,
1983
+ task=None,
1984
+ language=None,
1985
+ is_multilingual=None,
1986
+ **kwargs,
1987
+ ):
1988
+ if generation_config is None:
1989
+ generation_config = self.generation_config
1990
+
1991
+ if return_timestamps is not None:
1992
+ generation_config.return_timestamps = return_timestamps
1993
+
1994
+ if task is not None:
1995
+ generation_config.task = task
1996
+
1997
+ if is_multilingual is not None:
1998
+ generation_config.is_multilingual = is_multilingual
1999
+
2000
+ if language is not None:
2001
+ generation_config.language = language
2002
+
2003
+ if kwargs is not None and "decoder_input_ids" in kwargs:
2004
+ decoder_input_length = len(kwargs["decoder_input_ids"])
2005
+ else:
2006
+ decoder_input_length = 1
2007
+
2008
+ forced_decoder_ids = []
2009
+
2010
+ if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
2011
+ if hasattr(generation_config, "language"):
2012
+ forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
2013
+ else:
2014
+ forced_decoder_ids.append((1, None))
2015
+
2016
+ if hasattr(generation_config, "task"):
2017
+ forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
2018
+ else:
2019
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
2020
+
2021
+ if (
2022
+ hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
2023
+ ) or return_timestamps:
2024
+ logits_processor = [
2025
+ FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
2026
+ ]
2027
+ else:
2028
+ if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
2029
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
2030
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
2031
+
2032
+ if len(forced_decoder_ids) > 0:
2033
+ generation_config.forced_decoder_ids = forced_decoder_ids
2034
+
2035
+ return super().generate(
2036
+ input_features,
2037
+ generation_config,
2038
+ logits_processor=logits_processor,
2039
+ **kwargs,
2040
+ )
2041
+
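The prompt layout assembled above (language token at position 1, task token at position 2, `<|notimestamps|>` appended when timestamps are off) can be summarised with a small helper. The token ids below are placeholders, not the ids of any particular checkpoint:

```python
# Hypothetical ids; the real ones live in the model's generation config.
lang_to_id = {"<|en|>": 50259}
task_to_id = {"transcribe": 50359, "translate": 50358}
no_timestamps_token_id = 50363

def build_forced_decoder_ids(language=None, task=None, return_timestamps=False):
    forced = [
        (1, lang_to_id[language] if language is not None else None),
        (2, task_to_id[task] if task is not None else task_to_id["transcribe"]),
    ]
    if not return_timestamps:
        forced.append((forced[-1][0] + 1, no_timestamps_token_id))
    return forced

print(build_forced_decoder_ids(language="<|en|>", task="transcribe"))
# [(1, 50259), (2, 50359), (3, 50363)]
```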
2042
+ def pipeline_generate(
2043
+ self,
2044
+ input_features,
2045
+ forced_decoder_ids,
2046
+ return_timestamps=False,
2047
+ generation_config=None,
2048
+ **kwargs,
2049
+ ):
2050
+ if generation_config is None:
2051
+ generation_config = self.generation_config
2052
+
2053
+ # override the generation config forced decoder ids in preference of the ones we have set
2054
+ generation_config.forced_decoder_ids = None
2055
+
2056
+ logits_processor = FlaxLogitsProcessorList()
2057
+ logits_processor.append(FlaxStaticForceTokensLogitsProcessor(forced_decoder_ids))
2058
+
2059
+ if hasattr(generation_config, "return_timestamps") and return_timestamps:
2060
+ logits_processor.append(FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, 1))
2061
+
2062
+ return super().generate(
2063
+ input_features,
2064
+ generation_config,
2065
+ logits_processor=logits_processor,
2066
+ **kwargs,
2067
+ )
2068
+
2069
+ def prepare_inputs_for_generation(
2070
+ self,
2071
+ decoder_input_ids,
2072
+ max_length,
2073
+ attention_mask: Optional[jax.Array] = None,
2074
+ decoder_attention_mask: Optional[jax.Array] = None,
2075
+ encoder_outputs=None,
2076
+ **kwargs,
2077
+ ):
2078
+ # initializing the cache
2079
+ batch_size, seq_length = decoder_input_ids.shape
2080
+
2081
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
2082
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
2083
+ # But since the decoder uses a causal mask, those positions are masked anyways.
2084
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
2085
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
2086
+ if decoder_attention_mask is not None:
2087
+ position_ids = decoder_attention_mask.cumsum(-1) - 1
2088
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
2089
+ else:
2090
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
2091
+
2092
+ return {
2093
+ "past_key_values": past_key_values,
2094
+ "encoder_outputs": encoder_outputs,
2095
+ "encoder_attention_mask": attention_mask,
2096
+ "decoder_attention_mask": extended_attention_mask,
2097
+ "decoder_position_ids": position_ids,
2098
+ }
2099
+
2100
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
2101
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
2102
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
2103
+ return model_kwargs
2104
+
2105
+
2106
+ FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
2107
+ Returns:
2108
+
2109
+ Transcription example:
2110
+
2111
+ ```python
2112
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
2113
+ >>> from datasets import load_dataset
2114
+
2115
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
2116
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
2117
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
2118
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
2119
+ >>> input_features = inputs.input_features
2120
+ >>> generated_ids = model.generate(input_ids=input_features)
2121
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
2122
+ >>> transcription
2123
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
2124
+ ```
2125
+ """
2126
+
2127
+ overwrite_call_docstring(
2128
+ FlaxWhisperForConditionalGeneration,
2129
+ WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING,
2130
+ )
2131
+ append_replace_return_docstrings(
2132
+ FlaxWhisperForConditionalGeneration,
2133
+ output_type=FlaxSeq2SeqLMOutput,
2134
+ config_class=_CONFIG_FOR_DOC,
2135
+ )
flax/distil_whisper/partitioner.py ADDED
@@ -0,0 +1,965 @@
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for partitioning."""
16
+
17
+ import abc
18
+ import collections
19
+ import dataclasses
20
+ import typing
21
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
22
+
23
+ import cached_property
24
+ import jax
25
+ import numpy as np
26
+ from absl import logging
27
+ from flax import traverse_util
28
+ from flax.linen import partitioning as flax_partitioning
29
+ from jax import numpy as jnp
30
+ from jax import random
31
+ from jax.experimental import multihost_utils
32
+ from jax.experimental.mesh_utils import create_hybrid_device_mesh
33
+ from jax.experimental.pjit import pjit as jax_pjit
34
+ from jax.sharding import Mesh, PartitionSpec
35
+
36
+
37
+ JaxDevice = Any
38
+ TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores).
39
+ OtherMesh = Tuple[int, int]
40
+ HardwareMesh = Union[TpuMesh, OtherMesh]
41
+ PyTreeDef = type(jax.tree_util.tree_structure(None))
42
+ TrainState = Any
43
+ LogicalAxisRules = Sequence[Tuple[str, Optional[str]]]
44
+
45
+ if typing.TYPE_CHECKING: # See b/163639353
46
+ cached_property = property # pylint: disable=invalid-name
47
+ else:
48
+ cached_property = cached_property.cached_property
49
+
50
+
51
+ class AxisNames(tuple):
52
+ """Tuple of strings specifying name for each axis.
53
+
54
+ We create a separate class for this so JAX's pytree utilities can distinguish
55
+ it from a tuple that should be treated as a pytree, instead treating it as a
56
+ leaf.
57
+ """
58
+
59
+ def __new__(cls, *names):
60
+ return tuple.__new__(AxisNames, names)
61
+
62
+ def __repr__(self):
63
+ return "AxisNames%s" % tuple.__repr__(self)
64
+
65
+
66
+ # pjit wrappers for cpu fallback.
67
+ # ----------------------------------------------------------------------------
68
+ # TODO(levskaya): This function is now no different than jax_pjit, but callers
69
+ # currently depend on `backend` argument
70
+ def pjit(
71
+ fun: Callable, # pylint: disable=g-bare-generic
72
+ in_axis_resources,
73
+ out_axis_resources,
74
+ static_argnums: Union[int, Sequence[int]] = (),
75
+ donate_argnums: Union[int, Sequence[int]] = (),
76
+ backend: Optional[str] = None,
77
+ ):
78
+ """Wrapper for pjit."""
79
+ del backend
80
+ return jax_pjit(
81
+ fun,
82
+ in_axis_resources,
83
+ out_axis_resources,
84
+ static_argnums=static_argnums,
85
+ donate_argnums=donate_argnums,
86
+ )
87
+
88
+
89
+ # pjit wrappers for cpu fallback.
90
+ # -----------------------------------------------------------------------------
91
+ # TODO(levskaya): upstream this fallback behavior to jax pjit.
92
+ def pjit_with_cpu_fallback(
93
+ fun: Callable, # pylint: disable=g-bare-generic
94
+ in_axis_resources,
95
+ out_axis_resources,
96
+ static_argnums: Union[int, Sequence[int]] = (),
97
+ donate_argnums: Union[int, Sequence[int]] = (),
98
+ backend: Optional[str] = None,
99
+ ):
100
+ """Wrapper for pjit that calls normal jit on cpu."""
101
+ if jax.devices(backend)[0].platform == "cpu":
102
+ return jax.jit(fun, static_argnums=static_argnums, donate_argnums=donate_argnums)
103
+ else:
104
+ return jax_pjit(
105
+ fun,
106
+ in_axis_resources,
107
+ out_axis_resources,
108
+ static_argnums=static_argnums,
109
+ donate_argnums=donate_argnums,
110
+ )
111
+
112
+
113
+ def with_sharding_constraint(x, axis_resources):
114
+ """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
115
+ if jax.devices()[0].platform == "cpu" or not global_mesh_defined():
116
+ return x
117
+ else:
118
+ return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
119
+
120
+
121
+ # pjit Mesh creation functions.
122
+ # -----------------------------------------------------------------------------
123
+ def bounds_from_last_device(last_device: JaxDevice) -> HardwareMesh:
124
+ """Get the bound from the given last device."""
125
+ # Must be passed the device at the highest-coordinate corner of the
126
+ # relevant mesh, which is a requirement we know is satisfied by the last
127
+ # device in jax.devices().
128
+ if hasattr(last_device, "coords"):
129
+ x, y, z = last_device.coords
130
+ return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
131
+ else:
132
+ # On non-TPU platforms, the "mesh" is hosts x devices per host in order
133
+ # to take advantage of faster within-host interconnect.
134
+ return jax.host_count(), jax.local_device_count()
135
+
136
+
137
+ def get_coords(device: JaxDevice) -> HardwareMesh:
138
+ """Returns the coordinates of the given device."""
139
+ if hasattr(device, "coords"):
140
+ return (*device.coords, device.core_on_chip)
141
+ return (device.process_index, device.id % jax.local_device_count())
142
+
143
+
144
+ def global_mesh_defined():
145
+ """Checks if global xmap/pjit mesh resource environment is defined."""
146
+ maps_env = jax.experimental.maps.thread_resources.env
147
+ return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
148
+
149
+
150
+ def get_mesh(
151
+ model_parallel_submesh: HardwareMesh,
152
+ input_devices: Sequence[JaxDevice] = (),
153
+ input_local_devices: Sequence[JaxDevice] = (),
154
+ tile_by_host_if_needed: bool = True,
155
+ backend: Optional[str] = None,
156
+ ) -> Mesh:
157
+ """Construct an xmap/pjit Mesh for the given model-parallel submesh.
158
+
159
+ The resulting mesh has two resource axes: 'model', with the provided submesh
160
+ shape, and 'data', which covers the rest of the mesh.
161
+
162
+ Args:
163
+ model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for
164
+ a single model-parallel replica's "tile" in the physical device mesh. The
165
+ first three elements (`x`, `y`, and `z`) should be factors of the pod
166
+ slice; e.g., if you are using df_4x8, then `x` should be a factor of 4
167
+ (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z`
168
+ must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4
169
+ (and maybe later TPUs) that allow 3D slices. `core` is the number of cores
170
+ to use from each TPU node. As communication is usually fastest inside the
171
+ same node, if you need a tile of more than 1 core, then
172
+ you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better
173
+ than (2,1,1,1). To pick a good spec, try a few possible values until you
174
+ get high TPU utilization.
175
+ input_devices: the devices to use, will use jax.devices() if this is not
176
+ set.
177
+ input_local_devices: the local devices to use, will use jax.local_devices()
178
+ if this is not set.
179
+ tile_by_host_if_needed: JAX currently requires that the parts of any sharded
180
+ array that are located on one host's local devices form a single
181
+ contiguous slice. A best effort will be made to achieve this without
182
+ "tiling" the device assignment over hosts (which can reduce XLA collective
183
+ performance). If this flag is True, then the device assignment will be
184
+ tiled over hosts if necessary to satisfy this constraint and create a
185
+ buildable mesh; if false, mesh construction will fail instead.
186
+ backend: get devices from the pinned backend, if specified. This is
187
+ useful for explicitly specifying the devices other than relying on
188
+ jax_platform_name.
189
+
190
+ Returns:
191
+ A xmap / pjit Mesh containing the virtual device mesh with data, model axes.
192
+ """
193
+ input_devices = input_devices or jax.devices(backend)
194
+ input_local_devices = input_local_devices or jax.local_devices(0, backend)
195
+ # Sort input_devices based on coords, as backends might not return devices
196
+ # in order.
197
+ last_device = sorted(input_devices, key=get_coords)[-1]
198
+ last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1]
199
+ logging.info(
200
+ "last device coords : %r\nlast local device coords: %r",
201
+ get_coords(last_device),
202
+ get_coords(last_input_local_devices),
203
+ )
204
+ global_hardware_mesh = bounds_from_last_device(last_device)
205
+ mesh_ndim = len(global_hardware_mesh)
206
+ local_hardware_mesh = bounds_from_last_device(last_input_local_devices)
207
+ mesh_err = (
208
+ f"each dimension of the model parallel submesh {model_parallel_submesh} "
209
+ "must be a factor of the corresponding dimension of the global device "
210
+ f"mesh {global_hardware_mesh}"
211
+ )
212
+ assert not any(g % m for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err
213
+ assert not any(g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh))
214
+ devices = np.empty(global_hardware_mesh, dtype=object)
215
+ for device in input_devices:
216
+ device_coords = get_coords(device)
217
+ devices[device_coords] = device
218
+ tile_by_host = tile_by_host_if_needed
219
+ if len(global_hardware_mesh) == 4:
220
+ # enable contiguous local chunks without host tiling by making Z major
221
+ global_hardware_mesh = typing.cast(Tuple[int, int, int, int], global_hardware_mesh)
222
+ model_parallel_submesh = typing.cast(Tuple[int, int, int, int], model_parallel_submesh)
223
+ gx, gy, gz, gc = global_hardware_mesh
224
+ mx, my, mz, mc = model_parallel_submesh
225
+ if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and mz == gz > 1):
226
+ logging.info("ensuring YZ plane has a Z-major device order")
227
+ # YZ should be ZY
228
+ assert mc == gc, (mc, gc)
229
+ global_hardware_mesh = gx, gz, gy, gc
230
+ model_parallel_submesh = mx, mz, my, mc
231
+ devices = devices.swapaxes(1, 2)
232
+ tile_by_host = False
233
+ if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and mz == gz > 1):
234
+ logging.info("ensuring XZ plane has a Z-major device order")
235
+ # XZ should be ZX
236
+ assert mc == gc, (mc, gc)
237
+ global_hardware_mesh = gz, gy, gx, gc
238
+ model_parallel_submesh = mz, my, mx, mc
239
+ devices = devices.swapaxes(0, 2)
240
+ tile_by_host = False
241
+ if tile_by_host:
242
+ logging.warning(
243
+ "Tiling device assignment mesh by hosts, which may lead to "
244
+ "reduced XLA collective performance. To avoid this, modify "
245
+ "the model parallel submesh or run with more tasks per host."
246
+ )
247
+ tile_err = (
248
+ "to tile the mesh by hosts, each dimension of the model parallel "
249
+ "submesh must be either a factor or a multiple of the corresponding "
250
+ "dimension of the per-host submesh"
251
+ )
252
+
253
+ def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
254
+ """Split a global mesh dimension into four tiling components.
255
+
256
+ Args:
257
+ g: global mesh bounds dimension size
258
+ m: model-parallel submesh bounds dimension size
259
+ l: local submesh bounds dimension size
260
+
261
+ Returns:
262
+ The resulting tuple divides the dimension into the hosts component of
263
+ the data-parallel submesh, the devices component of the data-parallel
264
+ submesh, the hosts component of the model-parallel submesh, and the
265
+ devices component of the model-parallel submesh.
266
+ """
267
+ d = g // m
268
+ if m >= l:
269
+ assert not m % l, tile_err
270
+ return (d, 1, m // l, l)
271
+ else:
272
+ assert not l % m, tile_err
273
+ return (d // (l // m), l // m, 1, m)
274
+
275
+ # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...]
276
+ dh_dd_mh_md_tups = map(
277
+ dh_dd_mh_md,
278
+ global_hardware_mesh,
279
+ model_parallel_submesh,
280
+ local_hardware_mesh,
281
+ )
282
+ # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...)
283
+ devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension
284
+ # TODO(jekbradbury): reorder local subgroups for ring locality
285
+ # Transpose to [data_host], [data_device], [model_host], [model_device]
286
+ # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...)
287
+ devices = devices.transpose(
288
+ *(4 * i for i in range(mesh_ndim)),
289
+ *(4 * i + 1 for i in range(mesh_ndim)),
290
+ *(4 * i + 2 for i in range(mesh_ndim)),
291
+ *(4 * i + 3 for i in range(mesh_ndim)),
292
+ )
293
+ else:
294
+ # e.g. [(x_data, x_model), (y_data, y_model), ...]
295
+ model_data_tups = [(g // m, m) for g, m in zip(global_hardware_mesh, model_parallel_submesh)]
296
+ # reshape to e.g. (x_data, x_model, y_data, y_model...)
297
+ devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension
298
+ # TODO(jekbradbury): reorder small subgroups for ring locality
299
+ # transpose to e.g. (x_data, y_data, ..., x_model, ...)
300
+ devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), *(2 * i + 1 for i in range(mesh_ndim)))
301
+ # reshape to (data, model)
302
+ devices = devices.reshape(-1, np.prod(model_parallel_submesh))
303
+ global_mesh = Mesh(devices, ["data", "model"])
304
+ logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
305
+ logging.info("global_mesh devices: %s", global_mesh.devices)
306
+ logging.info("global_mesh devices shape: %s", global_mesh.devices.shape)
307
+ return global_mesh
308
+
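On a single host the mesh construction above collapses to a simple reshape of `jax.devices()`; a minimal sketch with every device on the 'data' axis and no model parallelism:

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec

devices = np.asarray(jax.devices()).reshape(-1, 1)   # (num_devices, 1) grid
mesh = Mesh(devices, ("data", "model"))

with mesh:
    spec = PartitionSpec("data", None)  # shard the batch axis, replicate the rest
    print(dict(mesh.shape), spec)
```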
309
+
310
+ def get_cpu_mesh() -> Mesh:
311
+ """Trivial mesh for CPU Testing."""
312
+ devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object)
313
+ for device in jax.devices():
314
+ devices[device.process_index, device.id % jax.local_device_count()] = device
315
+ return Mesh(devices, ["data", "model"])
316
+
317
+
318
+ def get_gpu_mesh(num_partitions: int) -> Mesh:
319
+ """Mesh for GPUs that preferentially places 'model' on NVLink."""
320
+ nvlink_size = jax.local_device_count()
321
+ dcn_size = jax.process_count()
322
+ nvlink_mp = min(num_partitions, nvlink_size)
323
+ nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
324
+ dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
325
+ assert not (
326
+ extra1 or extra2
327
+ ), "number of partitions on GPU must be a factor or multiple of the number of local devices"
328
+ dcn_dp = dcn_size // dcn_mp
329
+
330
+ devices = create_hybrid_device_mesh(
331
+ mesh_shape=[nvlink_dp, nvlink_mp],
332
+ dcn_mesh_shape=[dcn_dp, dcn_mp],
333
+ process_is_granule=True,
334
+ )
335
+
336
+ global_mesh = Mesh(devices, ["data", "model"])
337
+ logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
338
+ logging.info("global_mesh devices: %s", global_mesh.devices)
339
+ return global_mesh
340
+
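The NVLink/DCN factorisation performed by `get_gpu_mesh` is plain integer arithmetic; a sketch with hypothetical cluster sizes (8 GPUs per node, 4 nodes, 16-way model parallelism requested):

```python
nvlink_size, dcn_size, num_partitions = 8, 4, 16

nvlink_mp = min(num_partitions, nvlink_size)        # 8-way model parallel inside a node
nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)  # 1 data-parallel replica per node
dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)  # remaining 2-way model parallel over DCN
dcn_dp = dcn_size // dcn_mp                         # 2 data-parallel groups of nodes
assert not (extra1 or extra2)

print((nvlink_dp, nvlink_mp), (dcn_dp, dcn_mp))     # (1, 8) (2, 2)
```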
341
+
342
+ def default_mesh(
343
+ num_partitions: int,
344
+ model_parallel_submesh: Optional[HardwareMesh] = None,
345
+ backend: Optional[str] = None,
346
+ ) -> Mesh:
347
+ """Attempt to return a default mesh for simple cases.
348
+
349
+ Args:
350
+ num_partitions: number of partitions to use, will be ignored if
351
+ model_parallel_submesh is provided.
352
+ model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as
353
+ the model-parallel device tile.
354
+ backend: get devices from the pinned backend, if specified. This is useful
355
+ for explicitly specifying the devices other than relying on
356
+ jax_platform_name.
357
+
358
+ Returns:
359
+ xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
360
+ """
361
+ last_device = jax.devices(backend)[-1]
362
+ platform = last_device.platform
363
+ device_kind = last_device.device_kind
364
+ bounds = bounds_from_last_device(last_device)
365
+
366
+ if model_parallel_submesh:
367
+ return get_mesh(model_parallel_submesh, backend=backend)
368
+
369
+ if platform == "cpu":
370
+ return get_cpu_mesh()
371
+ elif platform == "gpu":
372
+ return get_gpu_mesh(num_partitions)
373
+
374
+ mps = None
375
+ if device_kind in ("TPU v2", "TPU v3"):
376
+ if num_partitions == 1:
377
+ mps = (1, 1, 1, 1)
378
+ elif num_partitions == 2:
379
+ mps = (1, 1, 1, 2)
380
+ elif num_partitions == 4:
381
+ mps = (2, 1, 1, 2)
382
+ elif num_partitions == 8:
383
+ mps = (2, 2, 1, 2)
384
+ elif num_partitions == 16:
385
+ mps = (4, 2, 1, 2)
386
+ # assume the use of megacore on TPU v4
387
+ elif (device_kind == "TPU v4" or device_kind == "TPU v4 lite") and bounds[3] == 1:
388
+ if num_partitions == 1:
389
+ mps = (1, 1, 1, 1)
390
+ elif num_partitions == 2:
391
+ mps = (1, 2, 1, 1)
392
+ elif num_partitions == 4:
393
+ if bounds[0] >= 4:
394
+ mps = (4, 1, 1, 1)
395
+ else:
396
+ mps = (2, 2, 1, 1)
397
+ elif num_partitions == 8:
398
+ if bounds[2] >= 8:
399
+ mps = (1, 1, 8, 1)
400
+ else:
401
+ mps = (4, 2, 1, 1)
402
+ elif num_partitions == 16:
403
+ if bounds[2] >= 16:
404
+ mps = (1, 1, 16, 1)
405
+ elif bounds[0] >= 8:
406
+ mps = (8, 2, 1, 1)
407
+ elif bounds[0] >= 4:
408
+ mps = (4, 4, 1, 1)
409
+ else:
410
+ mps = (2, 2, 4, 1)
411
+
412
+ if mps is None:
413
+ raise ValueError(
414
+ "No default mesh for this configuration: specify " "config.model_parallel_submesh explicitly."
415
+ )
416
+ return get_mesh(mps, backend=backend)
417
+
418
+
419
+ # Data chunking helper.
420
+ # -----------------------------------------------------------------------------
421
+ @dataclasses.dataclass
422
+ class LocalChunkInfo:
423
+ # The logical slice of an array located on this host's local devices.
424
+ slice: Tuple[slice, ...]
425
+ # A unique index for this host/local chunk among chunks with the same slice.
426
+ replica_id: int
427
+
428
+
429
+ class LocalChunker:
430
+ """Utility class to aid chunking of sharded arrays in multihost settings."""
431
+
432
+ def __init__(self, global_mesh: Mesh):
433
+ self.global_mesh = global_mesh
434
+ local_mesh = global_mesh.local_mesh
435
+ first_local_device = local_mesh.devices.reshape(-1)[0]
436
+ host_location = collections.OrderedDict(
437
+ zip(
438
+ global_mesh.shape.keys(),
439
+ list(zip(*np.nonzero(global_mesh.devices == first_local_device)))[0],
440
+ )
441
+ )
442
+ self.num_chunks = collections.OrderedDict()
443
+ self.chunk_ids = collections.OrderedDict()
444
+ self.mesh_axes = list(global_mesh.shape.keys())
445
+ for mesh_axis in self.mesh_axes:
446
+ num_devices_per_chunk = local_mesh.shape[mesh_axis]
447
+ self.num_chunks[mesh_axis] = global_mesh.shape[mesh_axis] // num_devices_per_chunk
448
+ self.chunk_ids[mesh_axis] = host_location[mesh_axis] // num_devices_per_chunk
449
+
450
+ def get_local_chunk_info(
451
+ self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
452
+ ) -> LocalChunkInfo:
453
+ """Get the local chunk info for a given array shape and sharded axes.
454
+
455
+ Args:
456
+ global_shape: the global, unsharded shape of the array to chunk.
457
+ mesh_axes: a sequence of names (or None) of equal rank to `global_shape`
458
+ that specifies which mesh dimensions the array is sharded along.
459
+
460
+ Returns:
461
+ LocalChunkInfo containing the logical slices of the array found on this
462
+ host's local devices, as well as the replica index for this chunk among
463
+ chunks with the same slice. The latter is used to determine which
464
+ host should write this chunk during checkpointing.
465
+ """
466
+ local_slice = [slice(None) for dim in global_shape]
467
+ sharded_mesh_axes = set()
468
+ for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)):
469
+ if not mesh_axis:
470
+ continue
471
+ sharded_mesh_axes.add(mesh_axis)
472
+ if not isinstance(mesh_axis, str):
473
+ raise NotImplementedError("TODO(jekbradbury)")
474
+ chunk_id = self.chunk_ids[mesh_axis]
475
+ chunk_size = size // self.num_chunks[mesh_axis]
476
+ local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)
477
+
478
+ replicated_mesh_axes = [mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes]
479
+ replica_id = 0
480
+ for mesh_axis in replicated_mesh_axes:
481
+ chunk_id = self.chunk_ids[mesh_axis]
482
+ replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id
483
+
484
+ return LocalChunkInfo(tuple(local_slice), replica_id)
485
+
486
+
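The slice arithmetic inside `get_local_chunk_info` can be reproduced by hand; a toy sketch with assumed chunk counts and chunk ids:

```python
# A (8, 512) array sharded over the 'data' mesh axis, split into 4 chunks;
# the host holding chunk_id == 2 owns rows 4:6 and all columns.
global_shape = (8, 512)
mesh_axes = ("data", None)
num_chunks = {"data": 4}
chunk_ids = {"data": 2}

local_slice = []
for mesh_axis, size in zip(mesh_axes, global_shape):
    if mesh_axis is None:
        local_slice.append(slice(None))             # replicated dimension: keep it whole
    else:
        chunk_size = size // num_chunks[mesh_axis]  # 8 // 4 == 2 rows per chunk
        start = chunk_ids[mesh_axis] * chunk_size
        local_slice.append(slice(start, start + chunk_size))

print(tuple(local_slice))  # (slice(4, 6, None), slice(None, None, None))
```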
487
+ def standard_logical_axis_rules(
488
+ activation_partitioning_dims: int = 1,
489
+ parameter_partitioning_dims: int = 1,
490
+ additional_rules: Optional[LogicalAxisRules] = None,
491
+ ) -> LogicalAxisRules:
492
+ """Default sharding rules for T5X model in terms of logical axis names.
493
+
494
+ Args:
495
+ activation_partitioning_dims: enables 2-D activation sharding when set to 2.
496
+ parameter_partitioning_dims: enables 2-D parameter sharding when set to 2.
497
+ additional_rules: additional rules (a sequence of tuples) that will be
498
+ appended to the standard rules.
499
+
500
+ Returns:
501
+ Sequence of logical axis rules
502
+ """
503
+ logging.info(
504
+ "`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d",
505
+ activation_partitioning_dims,
506
+ parameter_partitioning_dims,
507
+ )
508
+
509
+ if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1:
510
+ rules = [
511
+ ("batch", "data"),
512
+ ("vocab", "model"),
513
+ ("embed", None),
514
+ ("mlp", "model"),
515
+ ("heads", "model"),
516
+ ("kv", None),
517
+ ("joined_kv", "model"), # joined heads+kv dim in 2D attn param layouts
518
+ ]
519
+ elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1:
520
+ rules = [
521
+ ("batch", "data"),
522
+ ("vocab", "model"),
523
+ ("mlp", "model"),
524
+ ("heads", "model"),
525
+ ("kv", None),
526
+ ("joined_kv", "model"),
527
+ ("embed", "model"),
528
+ ]
529
+ elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2:
530
+ rules = [
531
+ ("batch", "data"),
532
+ ("vocab", "model"),
533
+ ("mlp", "model"),
534
+ ("heads", "model"),
535
+ ("kv", None),
536
+ ("joined_kv", "model"),
537
+ ("embed", "data"),
538
+ ]
539
+ elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2:
540
+ rules = [
541
+ ("batch", "data"),
542
+ ("vocab", "model"),
543
+ ("mlp", "model"),
544
+ ("heads", "model"),
545
+ ("kv", None),
546
+ ("joined_kv", "model"),
547
+ ("embed", "model"),
548
+ ("embed", "data"),
549
+ ]
550
+ else:
551
+ raise ValueError(
552
+ f"`activation_partitioning_dims` = {activation_partitioning_dims} "
553
+ f"`parameter_partitioning_dims` = {parameter_partitioning_dims} "
554
+ "is not supported."
555
+ )
556
+
557
+ # Add the common rules for the replicated logical axes names.
558
+ replicated_rules = [
559
+ ("relpos_buckets", None),
560
+ ("abspos_buckets", None),
561
+ ("length", None),
562
+ ("layers", None),
563
+ ("stack", None),
564
+ ("mlp_activations", None),
565
+ ]
566
+ rules.extend(replicated_rules)
567
+
568
+ if additional_rules:
569
+ rules.extend(additional_rules)
570
+
571
+ return rules
572
+
573
+
574
+ # NB: This needs to be top-level for the jax compilation cache.
575
+ def _id_fn(x, ix):
576
+ """Identity function for copying parameters to the devices, sharded."""
577
+ # A pure identity such as `lambda x, *: x` can get optimized away, so we
578
+ # include a random.split as a cheap function that cannot be optimized away.
579
+ y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32)))
580
+ return x, y
581
+
582
+
583
+ @dataclasses.dataclass
584
+ class DataLayout:
585
+ """Represents data layout for the partitioned model."""
586
+
587
+ batch_size: int
588
+ shard_id: int
589
+ num_shards: int
590
+ is_first_host_in_replica_set: bool
591
+
592
+
593
+ PartitionedCallable = Callable[..., Any]
594
+ CompiledPartitionedCallable = Callable[..., Any]
595
+
596
+
597
+ class BasePartitioner(metaclass=abc.ABCMeta):
598
+ """Interface for partitioning computations across hardware devices."""
599
+
600
+ def __init__(
601
+ self,
602
+ num_partitions: Optional[int] = None,
603
+ model_parallel_submesh: Optional[HardwareMesh] = None,
604
+ params_on_devices: bool = True,
605
+ backend: Optional[str] = None,
606
+ ):
607
+ """Configures the partitioner.
608
+
609
+ Args:
610
+ num_partitions: the number of partitions to use. Ignored if
611
+ `model_parallel_submesh` is provided.
612
+ model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
613
+ as the model-parallel device tile. This submesh is used for the larger
614
+ of the two parameter dimensions, and, if 2-D activation sharding is
615
+ enabled, for the model dimension of activations. The rest of the mesh is
616
+ used for data parallelism and, if 2-D parameter sharding is enabled, the
617
+ other parameter dimension.
618
+ params_on_devices: whether to keep the params on devices, if False -
619
+ params stay in the host memory. Note that some partitioners might ignore
620
+ this setting, for example if they don't support storing all params on
621
+ device memory.
622
+ backend: get devices from the pinned backend, if specified. This is useful
623
+ for explicitly specifying the devices other than relying on
624
+ jax_platform_name.
625
+ """
626
+
627
+ if not num_partitions and not model_parallel_submesh:
628
+ raise ValueError("At least one of `num_partitions` or " "`model_parallel_submesh` must be set.")
629
+
630
+ if model_parallel_submesh is not None and len(model_parallel_submesh) != 4:
631
+ logging.error(
632
+ (
633
+ "`model_parallel_submesh` must be either None or a 4-tuple. Got"
634
+ " `model_parallel_submesh`=%s. A ValueError will be raised"
635
+ " beginning March 1, 2022."
636
+ ),
637
+ model_parallel_submesh,
638
+ )
639
+
640
+ if bool(num_partitions) and bool(model_parallel_submesh):
641
+ logging.error(
642
+ (
643
+ "At most one of `num_partitions` or `model_parallel_submesh` can be"
644
+ " set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A"
645
+ " ValueError will be raised beginning March 21, 2022."
646
+ ),
647
+ num_partitions,
648
+ model_parallel_submesh,
649
+ )
650
+
651
+ self._num_partitions = num_partitions
652
+ self._model_parallel_submesh = model_parallel_submesh
653
+ self._params_on_devices = params_on_devices
654
+ self._data_axis = "data"
655
+ self._backend = backend
656
+
657
+ @property
658
+ def mesh(self) -> Mesh:
659
+ raise NotImplementedError
660
+
661
+ @property
662
+ def data_partition_spec(self) -> PartitionSpec:
663
+ return PartitionSpec(self._data_axis)
664
+
665
+ def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout:
666
+ """Returns filled `DataLayout` based on the partitioned model layout.
667
+
668
+ Args:
669
+ batch_size: if set, indicates the requested batch size. The exception will
670
+ be raised if this batch size is not compatible with the layout. If not
671
+ set, the batch size is inferred from the layout.
672
+ host_index: indicates the host index to use for the calculations, if not
673
+ set - use JAX-provided one. Should be in [0, num_hosts) interval and the
674
+ order should match the order of corresponding CPU devices in
675
+ `jax.devices()`.
676
+
677
+ Returns:
678
+ Filled `DataLayout` structure.
679
+ """
680
+ if host_index is not None:
681
+ raise NotImplementedError("Explicit host_index is not yet implemented.")
682
+ if self._data_axis is None:
683
+ return DataLayout(
684
+ batch_size=batch_size,
685
+ shard_id=0,
686
+ num_shards=1,
687
+ is_first_host_in_replica_set=(jax.process_index() == 0),
688
+ )
689
+ mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
690
+ batch_size = batch_size or mesh_size
691
+ if batch_size % mesh_size:
692
+ raise ValueError(
693
+ f"Batch size ({batch_size}) must be divisible by corresponding " f"mesh size ({mesh_size})."
694
+ )
695
+ num_shards = self._local_chunker.num_chunks[self._data_axis]
696
+ if batch_size % num_shards:
697
+ raise ValueError(f"Batch size ({batch_size}) must be divisible by number of " f"replicas ({num_shards}).")
698
+ replica_id = self._local_chunker.get_local_chunk_info((batch_size,), [self._data_axis]).replica_id
699
+ return DataLayout(
700
+ batch_size=int(batch_size),
701
+ shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
702
+ num_shards=int(num_shards),
703
+ is_first_host_in_replica_set=(replica_id == 0),
704
+ )
705
+
706
+ def get_local_chunk_info(
707
+ self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
708
+ ) -> LocalChunkInfo:
709
+ """Returns the local chunk info for a given array shape and sharded axes."""
710
+ return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes)
711
+
712
+ @property
713
+ def params_on_devices(self):
714
+ return self._params_on_devices
715
+
716
+ def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
717
+ """Moves the optimizer parameters to devices."""
718
+ p_id_fn = self.partition(
719
+ _id_fn,
720
+ in_axis_resources=(train_state_axes, None),
721
+ out_axis_resources=(train_state_axes, None),
722
+ donate_argnums=(0,),
723
+ )
724
+ if jax.config.jax_array and jax.process_count() > 1:
725
+ train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
726
+ train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
727
+ return train_state
728
+
729
+ @property
730
+ @abc.abstractmethod
731
+ def _local_chunker(self):
732
+ """Returns the chunker that matches the parameters of this partitioner."""
733
+ raise NotImplementedError
734
+
735
+ def get_logical_axes(self, train_state: TrainState) -> TrainState:
736
+ """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
737
+ # By default, return None for the logical axes.
738
+ return train_state.restore_state(jax.tree_map(lambda x: None, train_state.state_dict()))
739
+
740
+ def get_mesh_axes(self, train_state: TrainState) -> TrainState:
741
+ """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
742
+ raise NotImplementedError
743
+
744
+ @abc.abstractmethod
745
+ def partition(
746
+ self,
747
+ fn: Callable, # pylint: disable=g-bare-generic
748
+ in_axis_resources,
749
+ out_axis_resources,
750
+ static_argnums: Union[int, Sequence[int]] = (),
751
+ donate_argnums: Union[int, Sequence[int]] = (),
752
+ ) -> PartitionedCallable:
753
+ """Partitions the computation using partitioner-specific implementation.
754
+
755
+ Args:
756
+ fn: the function to partition.
757
+ in_axis_resources: Pytree of structure matching that of arguments to `fn`,
758
+ with all actual arguments replaced by resource assignment
759
+ specifications. It is also valid to specify a pytree prefix (e.g. one
760
+ value in place of a whole subtree), in which case the leaves get
761
+ broadcast to all values in that subtree.
762
+ The valid resource assignment specifications are:
763
+ `None`: in which case the value will be replicated on all devices
764
+ `PartitionSpec`: a tuple of length at most equal to the rank of the
765
+ partitioned value. Each element can be a `None`, a mesh axis or a
766
+ tuple of mesh axes, and specifies the set of resources assigned to
767
+ partition the value's dimension matching its position in the spec.
768
+ out_axis_resources: Like `in_axis_resources`, but specifies resource
769
+ assignment for function outputs.
770
+ static_argnums: an optional int or collection of ints that specify which
771
+ positional arguments to treat as static (compile-time constant) in the
772
+ partitioned function.
773
+ donate_argnums: an optional int or collection of ints that specify which
774
+ argument buffers are "donated" to the computation. It is safe to donate
775
+ argument buffers if you no longer need them once the computation has
776
+ finished.
777
+
778
+ Returns:
779
+ A partitioned version of the input function.
780
+ """
781
+ raise NotImplementedError
782
+
783
+ @abc.abstractmethod
784
+ def compile(self, partitioned_fn: PartitionedCallable, *args) -> CompiledPartitionedCallable:
785
+ """Compiles and returns the partitioned function, or the original.
786
+
787
+ Args:
788
+ partitioned_fn: The partitioned function.
789
+ *args: Sample arguments to the partitioned function matching the input
790
+ shapes that will be passed to the compiled function.
791
+
792
+ Returns:
793
+ The compiled function, or the original if this partitioner does not
794
+ support compilation.
795
+ """
796
+ raise NotImplementedError
797
+
798
+
799
+ class PjittedFnWithContext(PartitionedCallable):
800
+ """Wraps pjitted function to apply the appropriate contexts."""
801
+
802
+ def __init__(
803
+ self,
804
+ pjitted_fn,
805
+ partition_mesh: Mesh,
806
+ logical_axis_rules: flax_partitioning.LogicalRules = (),
807
+ ):
808
+ self._pjitted_fn = pjitted_fn
809
+ self._mesh = partition_mesh
810
+ self._logical_axis_rules = logical_axis_rules
811
+
812
+ def __call__(self, *args):
813
+ with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
814
+ return self._pjitted_fn(*args)
815
+
816
+ def lower(self, *args):
817
+ with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
818
+ return self._pjitted_fn.lower(*args)
819
+
820
+
821
+ class BasePjitPartitioner(BasePartitioner):
822
+ """Partitioner that uses T5X version of jax.pjit."""
823
+
824
+ @cached_property
825
+ def _local_chunker(self) -> LocalChunker:
826
+ return LocalChunker(self.mesh)
827
+
828
+ @cached_property
829
+ def mesh(self) -> Mesh:
830
+ return default_mesh(self._num_partitions, self._model_parallel_submesh, self._backend)
831
+
832
+ def partition(
833
+ self,
834
+ fn: Callable, # pylint: disable=g-bare-generic
835
+ in_axis_resources,
836
+ out_axis_resources,
837
+ static_argnums: Union[int, Sequence[int]] = (),
838
+ donate_argnums: Union[int, Sequence[int]] = (),
839
+ ) -> PjittedFnWithContext:
840
+ pjitted = pjit(
841
+ fn,
842
+ in_axis_resources=in_axis_resources,
843
+ out_axis_resources=out_axis_resources,
844
+ static_argnums=static_argnums,
845
+ donate_argnums=donate_argnums,
846
+ backend=self._backend,
847
+ )
848
+
849
+ return PjittedFnWithContext(pjitted, self.mesh)
850
+
851
+ def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable:
852
+ return partitioned_fn.lower(*args).compile()
853
+
854
+
855
+ class PjitPartitioner(BasePjitPartitioner):
856
+ """Partitioner that uses named axes and jax.pjit."""
857
+
858
+ def __init__(
859
+ self,
860
+ num_partitions: Optional[int] = None,
861
+ model_parallel_submesh: Optional[HardwareMesh] = None,
862
+ params_on_devices: bool = True,
863
+ backend: Optional[str] = None,
864
+ logical_axis_rules: Optional[LogicalAxisRules] = None,
865
+ use_cpu_pjit: Optional[bool] = False,
866
+ ):
867
+ """PjitPartitioner constructor.
868
+
869
+ See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details.
870
+
871
+ Args:
872
+ num_partitions: an integer that specifies the size of the model parallel
873
+ submesh to be automatically selected for the current topology. See
874
+ `model_parallel_submesh` for details on how this submesh is used.
875
+ Mutually exclusive with `model_parallel_submesh`.
876
+ model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)`
877
+ submesh model-parallel device tile, an axis of accelerator parallelism
878
+ orthogonal to data parallelism. Array axes in a model's parameters or
879
+ activations can be sharded over this submesh using axis rules (see
880
+ `logical_axis_rules`) that map them to 'model'. The effective number of
881
+ model sub-partitions is equal to `np.prod(model_parallel_submesh)` and
882
+ must evenly divide the total number of devices (i.e.,
883
+ `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest
884
+ of the TPU mesh is the data parallel submesh, providing
885
+ `jax.device_count() // np.prod(model_parallel_submesh)` partitions. It
886
+ is used for data (batch) parallelism and to shard other array axes that
887
+ are mapped to 'data'. This argument is mutually exclusive with
888
+ `num_partitions`.
889
+ params_on_devices: whether to keep the params on devices, if False -
890
+ params stay in the host memory. Note that some partitioners might ignore
891
+ this setting, for example if they don't support storing all params on
892
+ device memory.
893
+ backend: get devices from the pinned backend, if specified. This is
894
+ useful for explicitly specifying the devices other than relying on
895
+ jax_platform_name.
896
+ logical_axis_rules: a priority-ordered sequence of KV tuples that maps
897
+ logical axis names to either `None` (not sharded), 'model' (to shard
898
+ across the model-parallel submesh), or 'data' (to shard across the
899
+ data-parallel submesh).
900
+ use_cpu_pjit: enables wrapper function for pjit which just jits the
901
+ function if using CPU backend.
902
+ """
903
+ super().__init__(
904
+ num_partitions=num_partitions,
905
+ model_parallel_submesh=model_parallel_submesh,
906
+ params_on_devices=params_on_devices,
907
+ backend=backend,
908
+ )
909
+ if logical_axis_rules is None:
910
+ logical_axis_rules = standard_logical_axis_rules()
911
+ self._logical_axis_rules = tuple(logical_axis_rules)
912
+ (self._data_axis,) = flax_partitioning.logical_to_mesh_axes(["batch"], logical_axis_rules)
913
+ self._use_cpu_pjit = use_cpu_pjit
914
+
915
+ def partition(
916
+ self,
917
+ fn: Callable, # pylint: disable=g-bare-generic
918
+ in_axis_resources,
919
+ out_axis_resources,
920
+ static_argnums: Union[int, Sequence[int]] = (),
921
+ donate_argnums: Union[int, Sequence[int]] = (),
922
+ ) -> PjittedFnWithContext:
923
+ """Partitions the function using jax.pjit."""
924
+ if self._use_cpu_pjit:
925
+ pjit_fn = pjit_with_cpu_fallback
926
+ else:
927
+ pjit_fn = pjit
928
+ pjitted = pjit_fn(
929
+ fn,
930
+ in_axis_resources=in_axis_resources,
931
+ out_axis_resources=out_axis_resources,
932
+ static_argnums=static_argnums,
933
+ donate_argnums=donate_argnums,
934
+ backend=self._backend,
935
+ )
936
+
937
+ return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules)
938
+
939
+ @property
940
+ def logical_axis_rules(self):
941
+ """Returns the logical axis rules."""
942
+ return self._logical_axis_rules
943
+
944
+ def get_logical_axes(self, train_state: TrainState) -> TrainState:
945
+ """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
946
+ return train_state.as_logical_axes()
947
+
948
+ def get_mesh_axes(self, train_state: TrainState) -> TrainState:
949
+ """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
950
+ logical_axes = self.get_logical_axes(train_state)
951
+
952
+ def _logical_to_mesh_axes(param_name, logical_axes):
953
+ if logical_axes is None:
954
+ return None
955
+ elif logical_axes is traverse_util.empty_node:
956
+ return traverse_util.empty_node
957
+ try:
958
+ return flax_partitioning.logical_to_mesh_axes(logical_axes, self._logical_axis_rules)
959
+ except ValueError as e:
960
+ raise ValueError(f"Failed to map logical axes for {param_name}") from e
961
+
962
+ flat_logical_axes = traverse_util.flatten_dict(logical_axes.state_dict(), keep_empty_nodes=True, sep="/")
963
+ flat_mesh_axes = {k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()}
964
+
965
+ return logical_axes.restore_state(traverse_util.unflatten_dict(flat_mesh_axes, sep="/"))
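For context, a minimal sketch of how the partitioner defined above is typically driven (not part of this commit; the import path and `loss_fn`/`batch`/`params` are illustrative assumptions):

# Hedged sketch -- assumes the module is importable as `distil_whisper.partitioner`
# and that `loss_fn(batch, params)` is a user-defined, jit-able function.
from distil_whisper.partitioner import PjitPartitioner, PartitionSpec

partitioner = PjitPartitioner(num_partitions=1)  # pure data parallelism, default axis rules

# Shard the batch over the 'data' mesh axis, replicate params and the scalar loss.
p_loss_fn = partitioner.partition(
    loss_fn,
    in_axis_resources=(PartitionSpec("data"), None),
    out_axis_resources=None,
)
loss = p_loss_fn(batch, params)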
flax/distil_whisper/pipeline.py ADDED
@@ -0,0 +1,527 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Whisper JAX pipeline compatible with Distil Whisper checkpoints. Copied from https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py"""
17
+
18
+ import math
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ import requests
24
+ import torch
25
+ from flax import jax_utils
26
+ from flax.core.frozen_dict import freeze
27
+ from flax.training.common_utils import shard
28
+ from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
29
+ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
30
+ from transformers.pipelines.audio_utils import ffmpeg_read
31
+ from transformers.utils import logging
32
+
33
+ from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ class FlaxWhisperFeatureExtractor(WhisperFeatureExtractor):
40
+ def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
41
+ """
42
+ Compute the log-mel spectrogram of the provided audio using torch filters. Using the torch implementation
43
+ computes stft filter banks approx 5x faster than its numpy counterpart, which is the native implementation
44
+ in transformers, and matches to within 1e-5 abs tolerance.
45
+ """
46
+ waveform = torch.from_numpy(waveform).type(torch.float32)
47
+
48
+ window = torch.hann_window(self.n_fft)
49
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
50
+ magnitudes = stft[..., :-1].abs() ** 2
51
+
52
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
53
+ mel_spec = mel_filters.T @ magnitudes
54
+
55
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
56
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
57
+ log_spec = (log_spec + 4.0) / 4.0
58
+ return log_spec.numpy()
59
+
60
+
61
+ class FlaxWhisperPipeline:
62
+ def __init__(
63
+ self,
64
+ checkpoint="openai/whisper-large-v2",
65
+ dtype=jnp.float32,
66
+ batch_size=None,
67
+ max_length=None,
68
+ **kwargs,
69
+ ):
70
+ """
71
+ Args
72
+ checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"`):
73
+ The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub
74
+ with Flax weights.
75
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
76
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
77
+ `jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs.
78
+ If specified all the computation will be performed with the given `dtype`. **Note that this only
79
+ specifies the dtype of the computation and does not influence the dtype of model parameters.**
80
+ batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
81
+ The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
82
+ a batch size in the `__init__` method will be superseded by any batch size passed to the `__call__` method.
83
+ max_length (`int`, *optional*):
84
+ The maximum numbers of tokens to generate. Defaults to `model.config.max_length`.
85
+ """
86
+ self.checkpoint = checkpoint
87
+ self.dtype = dtype
88
+
89
+ self.feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(self.checkpoint)
90
+ self.tokenizer = WhisperTokenizerFast.from_pretrained(self.checkpoint)
91
+
92
+ self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained(
93
+ self.checkpoint,
94
+ _do_init=False,
95
+ dtype=self.dtype,
96
+ **kwargs,
97
+ )
98
+
99
+ self.max_length = max_length if max_length is not None else self.model.generation_config.max_length
100
+ self.min_batch_size = jax.local_device_count()
101
+ self.batch_size = (
102
+ batch_size if batch_size is not None else self.min_batch_size
103
+ ) # we need a minimum of 1 batch per-device
104
+
105
+ def generate(
106
+ params,
107
+ input_features,
108
+ forced_decoder_ids,
109
+ return_timestamps,
110
+ num_beams,
111
+ length_penalty,
112
+ do_sample,
113
+ top_k,
114
+ temperature,
115
+ ):
116
+ output_ids = self.model.pipeline_generate(
117
+ input_features,
118
+ params=params,
119
+ forced_decoder_ids=forced_decoder_ids,
120
+ return_timestamps=return_timestamps,
121
+ max_length=self.max_length,
122
+ num_beams=num_beams,
123
+ length_penalty=length_penalty,
124
+ do_sample=do_sample,
125
+ top_k=top_k,
126
+ temperature=temperature,
127
+ )
128
+ return output_ids
129
+
130
+ self.params = jax_utils.replicate(self.params)
131
+ self.p_generate = jax.pmap(
132
+ generate,
133
+ "input_features",
134
+ in_axes=(0, 0, None, None, None, None, None, None, None),
135
+ static_broadcasted_argnums=(
136
+ 3,
137
+ 4,
138
+ 5,
139
+ 6,
140
+ 7,
141
+ 8,
142
+ ),
143
+ )
144
+
145
+ def generate(
146
+ self,
147
+ input_features,
148
+ language=None,
149
+ task=None,
150
+ return_timestamps=False,
151
+ num_beams=1,
152
+ length_penalty=1.0,
153
+ do_sample=False,
154
+ top_k=50,
155
+ temperature=1.0,
156
+ ):
157
+ forced_decoder_ids = self.get_forced_decoder_ids(
158
+ language=language, task=task, return_timestamps=return_timestamps
159
+ )
160
+ # if we're using pmap we need to manually replicate the input data across devices and gather the output tokens
161
+ output_ids = self.p_generate(
162
+ freeze(self.params),
163
+ shard(input_features),
164
+ forced_decoder_ids,
165
+ return_timestamps,
166
+ num_beams,
167
+ length_penalty,
168
+ do_sample,
169
+ top_k,
170
+ temperature,
171
+ ).sequences
172
+ output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
173
+ return output_ids
174
+
175
+ def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
176
+ if generation_config is None:
177
+ generation_config = self.model.generation_config
178
+
179
+ if hasattr(generation_config, "is_multilingual"):
180
+ is_multilingual = generation_config.is_multilingual
181
+ else:
182
+ is_multilingual = None
183
+
184
+ forced_decoder_ids = []
185
+
186
+ if is_multilingual:
187
+ if language is not None:
188
+ language = language.lower()
189
+ if language in generation_config.lang_to_id.keys():
190
+ language_token = language
191
+ elif language in TO_LANGUAGE_CODE.values():
192
+ language_token = f"<|{language}|>"
193
+ elif language in TO_LANGUAGE_CODE.keys():
194
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
195
+ else:
196
+ if len(language) == 2:
197
+ # ISO 639-1 language code
198
+ acceptable_languages = list(TO_LANGUAGE_CODE.values())
199
+ elif "<" in language or "|" in language or ">" in language:
200
+ # generation config language code
201
+ acceptable_languages = list(generation_config.lang_to_id.keys())
202
+ else:
203
+ # language passed as a string
204
+ acceptable_languages = list(TO_LANGUAGE_CODE.keys())
205
+ raise ValueError(
206
+ f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
207
+ )
208
+ forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
209
+
210
+ if task is not None:
211
+ forced_decoder_ids.append((2, generation_config.task_to_id[task]))
212
+ else:
213
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
214
+
215
+ if not return_timestamps:
216
+ if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
217
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
218
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
219
+ else:
220
+ forced_decoder_ids.append((1, generation_config.no_timestamps_token_id))
221
+
222
+ return forced_decoder_ids
223
+
224
+ def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
225
+ inputs_len = inputs.shape[0]
226
+ step = chunk_len - stride_left - stride_right
227
+
228
+ all_chunk_start_idx = np.arange(0, inputs_len, step)
229
+ num_samples = len(all_chunk_start_idx)
230
+
231
+ num_batches = math.ceil(num_samples / batch_size)
232
+ batch_idx = np.array_split(np.arange(num_samples), num_batches)
233
+
234
+ for idx in batch_idx:
235
+ chunk_start_idx = all_chunk_start_idx[idx]
236
+
237
+ chunk_end_idx = chunk_start_idx + chunk_len
238
+
239
+ chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
240
+ processed = self.feature_extractor(
241
+ chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
242
+ )
243
+
244
+ _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
245
+ is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
246
+ _stride_right = np.where(is_last, 0, stride_right)
247
+
248
+ chunk_lens = [chunk.shape[0] for chunk in chunks]
249
+ strides = [
250
+ (chunk_l, _stride_l, _stride_r)
251
+ for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
252
+ ]
253
+
254
+ yield {"stride": strides, **processed}
255
+
256
+ def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
257
+ if isinstance(inputs, np.ndarray):
258
+ logger.warning(
259
+ "Numpy array passed as input - no sampling rate checks will be performed. "
260
+ "It is strongly recommended to pass the input as a dictionary with an 'array' key "
261
+ "containing the numpy array representing the audio, and a 'sampling_rate' key "
262
+ "containing the sampling rate associated with the audio array. "
263
+ "Failing to do so can result in silent errors that might be hard to debug."
264
+ )
265
+
266
+ if isinstance(inputs, str):
267
+ if inputs.startswith("http://") or inputs.startswith("https://"):
268
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
269
+ # like http_huggingface_co.png
270
+ inputs = requests.get(inputs).content
271
+ else:
272
+ with open(inputs, "rb") as f:
273
+ inputs = f.read()
274
+
275
+ if isinstance(inputs, bytes):
276
+ inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
277
+
278
+ stride = None
279
+ if isinstance(inputs, dict):
280
+ stride = inputs.get("stride", None)
281
+ # Accepting `"array"` which is the key defined in `datasets` for
282
+ # better integration
283
+ if not ("sampling_rate" in inputs and "array" in inputs):
284
+ raise ValueError(
285
+ "When passing a dictionary to FlaxWhisperPipeline, the dict needs to contain an 'array' key "
286
+ "containing the numpy array representing the audio, and a 'sampling_rate' key "
287
+ "containing the sampling rate associated with the audio array."
288
+ )
289
+
290
+ in_sampling_rate = inputs.get("sampling_rate")
291
+ inputs = inputs.get("array", None)
292
+
293
+ if in_sampling_rate != self.feature_extractor.sampling_rate:
294
+ try:
295
+ import librosa
296
+ except ImportError as err:
297
+ raise ImportError(
298
+ "To support resampling audio files, please install 'librosa' and 'soundfile'."
299
+ ) from err
300
+
301
+ inputs = librosa.resample(
302
+ inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
303
+ )
304
+ ratio = self.feature_extractor.sampling_rate / in_sampling_rate
305
+ else:
306
+ ratio = 1
307
+
308
+ if not isinstance(inputs, np.ndarray):
309
+ raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
310
+ if len(inputs.shape) != 1:
311
+ raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
312
+
313
+ if stride is not None:
314
+ if stride[0] + stride[1] > inputs.shape[0]:
315
+ raise ValueError("Stride is too large for input")
316
+
317
+ # Stride needs to get the chunk length here, it's going to get
318
+ # swallowed by the `feature_extractor` later, and then batching
319
+ # can add extra data in the inputs, so we need to keep track
320
+ # of the original length in the stride so we can cut properly.
321
+ stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
322
+
323
+ if chunk_length_s:
324
+ if stride_length_s is None:
325
+ stride_length_s = chunk_length_s / 6
326
+
327
+ if isinstance(stride_length_s, (int, float)):
328
+ stride_length_s = [stride_length_s, stride_length_s]
329
+
330
+ chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
331
+ stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
332
+ stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
333
+
334
+ if chunk_len < stride_left + stride_right:
335
+ raise ValueError("Chunk length must be superior to stride length")
336
+
337
+ for item in self.chunk_iter_with_batch(
338
+ inputs,
339
+ chunk_len,
340
+ stride_left,
341
+ stride_right,
342
+ batch_size,
343
+ ):
344
+ yield item
345
+ else:
346
+ processed = self.feature_extractor(
347
+ inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
348
+ )
349
+ if stride is not None:
350
+ processed["stride"] = stride
351
+ yield processed
352
+
353
+ def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
354
+ # unpack the outputs from list(dict(list)) to list(dict)
355
+ model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
356
+
357
+ time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
358
+ # Send the chunking back to seconds, it's easier to handle in whisper
359
+ sampling_rate = self.feature_extractor.sampling_rate
360
+ for output in model_outputs:
361
+ if "stride" in output:
362
+ chunk_len, stride_left, stride_right = output["stride"]
363
+ # Go back in seconds
364
+ chunk_len /= sampling_rate
365
+ stride_left /= sampling_rate
366
+ stride_right /= sampling_rate
367
+ output["stride"] = chunk_len, stride_left, stride_right
368
+
369
+ text, optional = self.tokenizer._decode_asr(
370
+ model_outputs,
371
+ return_timestamps=return_timestamps,
372
+ return_language=return_language,
373
+ time_precision=time_precision,
374
+ )
375
+ return {"text": text, **optional}
376
+
377
+ def forward(
378
+ self,
379
+ model_inputs,
380
+ batch_size=None,
381
+ language=None,
382
+ task=None,
383
+ return_timestamps=False,
384
+ num_beams=1,
385
+ length_penalty=1.0,
386
+ do_sample=False,
387
+ top_k=50,
388
+ temperature=1.0,
389
+ ):
390
+ # We need to keep track of some additional input arguments for post-processing so need to forward these on after running generation
391
+ input_features = model_inputs.pop("input_features")
392
+ input_batch_size = input_features.shape[0]
393
+
394
+ if input_batch_size != batch_size:
395
+ padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
396
+ input_features = np.concatenate([input_features, padding])
397
+
398
+ pred_ids = self.generate(
399
+ input_features,
400
+ language=language,
401
+ task=task,
402
+ return_timestamps=return_timestamps,
403
+ num_beams=num_beams,
404
+ length_penalty=length_penalty,
405
+ do_sample=do_sample,
406
+ top_k=top_k,
407
+ temperature=temperature,
408
+ )[:input_batch_size]
409
+
410
+ # tokenizer's decode method expects an extra dim - we insert it here for convenience
411
+ out = {"tokens": pred_ids[:, None, :]}
412
+
413
+ stride = model_inputs.pop("stride", None)
414
+ if stride is not None:
415
+ out["stride"] = stride
416
+
417
+ return out
418
+
419
+ def __call__(
420
+ self,
421
+ inputs,
422
+ chunk_length_s=30.0,
423
+ stride_length_s=None,
424
+ batch_size=None,
425
+ language=None,
426
+ task=None,
427
+ return_timestamps=None,
428
+ num_beams=1,
429
+ length_penalty=1.0,
430
+ do_sample=False,
431
+ top_k=50,
432
+ temperature=1.0,
433
+ ):
434
+ """
435
+ Transcribe an audio input sequence to a text transcription, optionally with timestamps.
436
+
437
+ Args:
438
+ inputs (`np.ndarray` or `bytes` or `str` or `dict`):
439
+ The inputs is either:
440
+ - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
441
+ to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
442
+ - `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the
443
+ same way.
444
+ - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
445
+ Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling
446
+ rate check will be done.
447
+ - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
448
+ pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array":
449
+ np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to
450
+ ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in
451
+ decoding (but used at inference to provide more context to the model). In general, this additional
452
+ stride argument is not required.
453
+ chunk_length_s (`float`, *optional*, defaults to 30.0):
454
+ The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the chunk
455
+ length is set 30.0s, equal to Whisper's context window.
456
+ stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
457
+ The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
458
+ the model to *see* more context and infer letters better than without this context but the pipeline
459
+ discards the stride bits at the end to make the final reconstitution as perfect as possible.
460
+
461
+ <Tip>
462
+
463
+ For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking
464
+ blog post](https://huggingface.co/blog/asr-chunking).
465
+
466
+ </Tip>
467
+ batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
468
+ The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
469
+ a batch size in the `__call__` method will supersede any batch size passed to the `__init__`.
470
+ task (`str`, *optional*):
471
+ Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
472
+ language (`str`, *optional*):
473
+ Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`.
474
+ Defaults to `None`, meaning the language is automatically inferred from the audio input.
475
+ return_timestamps (*optional*, `bool`):
476
+ Whether to return timestamps in the prediction. Defaults to False. If set to true, the pipeline
477
+ will return two keys in the output dictionary: `"text"` containing the text transcription, and `"chunks"`
478
+ containing the transcription segments chunked by their utterance-level timestamps.
479
+ length_penalty (*optional*, `float`):
480
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an
481
+ exponent to the sequence length, which in turn is used to divide the score of the sequence. Since
482
+ the score is the log likelihood of the sequence (i.e. negative), length_penalty > 1.0 promotes
483
+ longer sequences, while length_penalty < 1.0 encourages shorter sequences.
484
+ do_sample (*optional*, `bool`):
485
+ Whether or not to use sampling ; use greedy decoding otherwise.
486
+ top_k (*optional*, `int`):
487
+ The number of the highest probability vocabulary tokens to keep for top-k-filtering.
488
+ temperature (*optional*, `float`):
489
+ The value used to modulate the next token probabilities if sampling.
490
+
491
+ Return:
492
+ `Dict`: A dictionary with the following keys:
493
+ - **text** (`str` ) -- The recognised text.
494
+ - **chunks** (*optional*, `List[Dict]`)
495
+ When using `return_timestamps`, the `chunks` will become a list containing all the various text
496
+ chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5, 0.9)}, {"text":
497
+ "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
498
+ `"".join(chunk["text"] for chunk in output["chunks"])`.
499
+ """
500
+ batch_size = batch_size if batch_size is not None else self.batch_size
501
+ if batch_size % self.min_batch_size != 0:
502
+ raise ValueError(
503
+ f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}."
504
+ )
505
+
506
+ dataloader = self.preprocess_batch(
507
+ inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size
508
+ )
509
+ model_outputs = []
510
+ # iterate over our chunked audio samples
511
+ for batch in dataloader:
512
+ model_outputs.append(
513
+ self.forward(
514
+ batch,
515
+ batch_size=batch_size,
516
+ language=language,
517
+ task=task,
518
+ return_timestamps=return_timestamps,
519
+ num_beams=num_beams,
520
+ length_penalty=length_penalty,
521
+ do_sample=do_sample,
522
+ top_k=top_k,
523
+ temperature=temperature,
524
+ )
525
+ )
526
+ post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps)
527
+ return post_processed
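A minimal usage sketch for the pipeline above (not part of the commit; the checkpoint name and audio path are placeholders, and `ffmpeg` is assumed to be installed for file decoding):

# Hedged sketch: instantiate the JAX pipeline and transcribe a local audio file.
import jax.numpy as jnp

from distil_whisper.pipeline import FlaxWhisperPipeline

pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16)

# The first call triggers pmap compilation; subsequent calls with the same shapes are fast.
outputs = pipeline("audio.mp3", chunk_length_s=30.0, return_timestamps=True)
print(outputs["text"])
print(outputs.get("chunks"))  # utterance-level segments when return_timestamps=True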
flax/distil_whisper/train_state.py ADDED
@@ -0,0 +1,118 @@
1
+ from typing import Any, Mapping, MutableMapping, Optional, Tuple
2
+
3
+ import flax.core
4
+ import flax.serialization
5
+ import flax.struct
6
+ import jax.numpy as jnp
7
+ from flax import traverse_util
8
+ from flax.core import scope as flax_scope
9
+ from flax.linen import partitioning as flax_partitioning
10
+
11
+
12
+ EMPTY_DICT = flax.core.freeze({})
13
+ FrozenDict = flax_scope.FrozenDict
14
+ FrozenVariableDict = flax_scope.FrozenVariableDict
15
+ MutableVariableDict = flax_scope.MutableVariableDict
16
+ VariableDict = flax_scope.VariableDict
17
+
18
+
19
+ def _validate_params_axes(params_axes, params):
20
+ axis_names = flax_partitioning.get_axis_names(params_axes)
21
+ missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set(
22
+ traverse_util.flatten_dict(axis_names, sep="/")
23
+ )
24
+ if missing_params_axes:
25
+ raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
26
+
27
+
28
+ def _split_variables_and_axes(
29
+ variables_and_axes: FrozenVariableDict,
30
+ ) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
31
+ """Splits `variables_and_axes` into two separate dicts with the same keys."""
32
+ # For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`.
33
+ variables = {}
34
+ axes = {}
35
+ for k, v in variables_and_axes.items():
36
+ if k.endswith("_axes"):
37
+ axes[k[:-5]] = v # k without "_axes".
38
+ _validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes".
39
+ else:
40
+ variables[k] = v
41
+ return flax.core.freeze(variables), flax.core.freeze(axes)
42
+
43
+
44
+ class InferenceState(flax.struct.PyTreeNode):
45
+ """State compatible with FlaxOptimTrainState without optimizer state."""
46
+
47
+ step: jnp.ndarray
48
+ params: flax_scope.FrozenVariableDict
49
+ params_axes: Optional[flax_scope.FrozenVariableDict] = None
50
+ flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
51
+ flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None
52
+
53
+ @classmethod
54
+ def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
55
+ other_variables, params = model_variables.pop("params")
56
+ if "params_axes" in other_variables:
57
+ other_variables, params_axes = other_variables.pop("params_axes")
58
+ _validate_params_axes(params_axes, params)
59
+ else:
60
+ params_axes = None
61
+
62
+ # Split other_variables into mutables and their corresponding axes.
63
+ flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
64
+ flax_mutables_axes = flax_mutables_axes or None
65
+ return InferenceState(
66
+ step=jnp.array(0),
67
+ params=params,
68
+ params_axes=params_axes,
69
+ flax_mutables=flax_mutables,
70
+ flax_mutables_axes=flax_mutables_axes,
71
+ )
72
+
73
+ @property
74
+ def param_states(self) -> FrozenVariableDict:
75
+ """The optimizer states of the parameters as a PyTree."""
76
+ raise NotImplementedError("InferenceState has no optimizer states.")
77
+
78
+ def apply_gradient(self, *args, **kwargs) -> "InferenceState":
79
+ raise NotImplementedError("InferenceState does not support `apply_gradient`.")
80
+
81
+ def state_dict(self) -> MutableMapping[str, Any]:
82
+ state_dict = {
83
+ "target": flax.core.unfreeze(self.params),
84
+ "state": {"step": self.step},
85
+ }
86
+ if self.flax_mutables:
87
+ state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
88
+ return state_dict
89
+
90
+ def replace_step(self, step: jnp.ndarray) -> "InferenceState":
91
+ return self.replace(step=step)
92
+
93
+ def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
94
+ return self.replace(params=params)
95
+
96
+ def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
97
+ return self.replace(flax_mutables=flax_mutables)
98
+
99
+ def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
100
+ return self.replace(
101
+ params=flax.core.freeze(state_dict["target"]),
102
+ step=state_dict["state"]["step"],
103
+ flax_mutables=(
104
+ flax.core.freeze(state_dict["flax_mutables"]) if "flax_mutables" in state_dict else EMPTY_DICT
105
+ ),
106
+ )
107
+
108
+ def as_logical_axes(self) -> "InferenceState":
109
+ # Set step to None so that when the logical axes are processed by the
110
+ # flax.partitioning.logical_to_mesh_axes function, it will be skipped
111
+ # because jax.tree_map will short circuit and never call the function on the
112
+ # step.
113
+ flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
114
+ return InferenceState(
115
+ step=None,
116
+ params=flax_partitioning.get_axis_names(self.params_axes),
117
+ flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
118
+ )
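A short sketch of how `InferenceState` is typically constructed and serialised (illustrative only; `model`, `rng` and `dummy_inputs` are assumed to exist, with `model` being a partitioned Flax module that emits `params_axes` metadata):

# Hedged sketch: build an InferenceState from freshly initialised Flax variables
# and round-trip it through its checkpoint-friendly state dict.
from distil_whisper.train_state import InferenceState

variables = model.init(rng, *dummy_inputs)  # FrozenDict with 'params' (and 'params_axes')
state = InferenceState.create(variables)

ckpt = state.state_dict()        # {'target': params, 'state': {'step': ...}}
restored = state.restore_state(ckpt)  # inverse of state_dict()

# Logical axis names for partitioning rules; requires 'params_axes' to be present.
axes = state.as_logical_axes()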
flax/distillation_scripts/run_32_2_pt.sh ADDED
@@ -0,0 +1,38 @@
1
+ #!/bin/bash
2
+
3
+ accelerate launch --multi_gpu --mixed_precision=bf16 --num_processes=2 run_distillation_pt.py \
4
+ --model_name_or_path distil-whisper/large-32-2 \
5
+ --teacher_model_name_or_path openai/whisper-large-v2 \
6
+ --train_dataset_config_name all+all+all+l \
7
+ --train_dataset_samples 2.9+10.4+14.9+226.6 \
8
+ --train_dataset_name librispeech_asr+librispeech_asr+librispeech_asr+gigaspeech-l \
9
+ --train_split_name train.clean.100+train.clean.360+train.other.500+train \
10
+ --eval_dataset_name librispeech_asr+librispeech_asr+gigaspeech-l \
11
+ --eval_dataset_config_name all+all+l \
12
+ --eval_split_name validation.clean+validation.other+validation \
13
+ --eval_text_column_name text+text+text \
14
+ --eval_steps 2500 \
15
+ --save_steps 2500 \
16
+ --warmup_steps 50 \
17
+ --learning_rate 0.0001 \
18
+ --lr_scheduler_type constant_with_warmup \
19
+ --logging_steps 25 \
20
+ --save_total_limit 1 \
21
+ --max_steps 10000 \
22
+ --wer_threshold 10 \
23
+ --per_device_train_batch_size 64 \
24
+ --gradient_accumulation_steps 2 \
25
+ --per_device_eval_batch_size 64 \
26
+ --dataloader_num_workers 16 \
27
+ --cache_dir /fsx/sanchit/cache \
28
+ --dataset_cache_dir /fsx/sanchit/cache \
29
+ --dtype bfloat16 \
30
+ --output_dir ./ \
31
+ --wandb_project distil-whisper-training \
32
+ --do_train \
33
+ --do_eval \
34
+ --gradient_checkpointing \
35
+ --overwrite_output_dir \
36
+ --predict_with_generate \
37
+ --freeze_encoder \
38
+ --streaming
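As a sanity check on the launcher above, the effective batch size implied by its flags can be computed as follows (the process count comes from `--num_processes=2` and is an assumption about the target machine):

# Hedged sketch: effective batch size for run_32_2_pt.sh
per_device_train_batch_size = 64
gradient_accumulation_steps = 2
num_processes = 2  # from `accelerate launch --num_processes=2`

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 256 samples per optimiser step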
flax/distillation_scripts/run_bs_sweep.yaml ADDED
@@ -0,0 +1,67 @@
1
+ command:
2
+ - python3
3
+ - ${program}
4
+ - --do_train
5
+ - --use_scan
6
+ - --gradient_checkpointing
7
+ - --overwrite_output_dir
8
+ - --predict_with_generate
9
+ - --freeze_encoder
10
+ - --streaming
11
+ - --use_auth_token
12
+ - --compilation_cache
13
+ - ${args}
14
+ method: grid
15
+ metric:
16
+ goal: minimize
17
+ name: train/loss
18
+ parameters:
19
+ model_name_or_path:
20
+ value: distil-whisper/large-32-2
21
+ teacher_model_name_or_path:
22
+ value: openai/whisper-large-v2
23
+ train_dataset_name:
24
+ value: librispeech_asr
25
+ train_dataset_config_name:
26
+ value: all
27
+ train_split_name:
28
+ value: train.other.500
29
+ train_dataset_samples:
30
+ value: 100
31
+ cache_dir:
32
+ value: /fsx/sanchitgandhi/cache
33
+ dataset_cache_dir:
34
+ value: /fsx/sanchitgandhi/cache
35
+ output_dir:
36
+ value: ./
37
+ per_device_train_batch_size:
38
+ values:
39
+ - 128
40
+ - 256
41
+ - 512
42
+ precision:
43
+ values:
44
+ - "full_mixed"
45
+ - "half_mixed"
46
+ dtype:
47
+ value: bfloat16
48
+ do_eval:
49
+ value: false
50
+ learning_rate:
51
+ value: 3e-4
52
+ lr_scheduler_type:
53
+ value: constant_with_warmup
54
+ warmup_steps:
55
+ value: 30
56
+ max_steps:
57
+ value: 30
58
+ save_steps:
59
+ value: 51 # don't save checkpoints during sweep
60
+ dataloader_num_workers:
61
+ value: 48
62
+ logging_steps:
63
+ value: 5
64
+ wer_threshold:
65
+ value: 100
66
+ program: run_distillation.py
67
+ project: distil-whisper-sweeps
flax/distillation_scripts/run_dataset_sweep.yaml ADDED
@@ -0,0 +1,77 @@
1
+ command:
2
+ - python3
3
+ - ${program}
4
+ - --do_train
5
+ - --do_eval
6
+ - --use_scan
7
+ - --gradient_checkpointing
8
+ - --overwrite_output_dir
9
+ - --predict_with_generate
10
+ - --freeze_encoder
11
+ - --streaming
12
+ - --use_auth_token
13
+ - ${args}
14
+ method: grid
15
+ metric:
16
+ goal: minimize
17
+ name: gigaspeech-l/validation/wer
18
+ parameters:
19
+ model_name_or_path:
20
+ value: distil-whisper/large-32-2
21
+ teacher_model_name_or_path:
22
+ value: openai/whisper-large-v2
23
+ max_train_samples:
24
+ values:
25
+ - 109876
26
+ - 219752
27
+ - 439504
28
+ - 879008
29
+ - 1758015
30
+ - 3516030
31
+ - 7032061
32
+ train_dataset_name:
33
+ value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted
34
+ train_dataset_config_name:
35
+ value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3
36
+ train_split_name:
37
+ value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train
38
+ train_dataset_samples:
39
+ value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8
40
+ eval_dataset_name:
41
+ value: librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs
42
+ eval_dataset_config_name:
43
+ value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us
44
+ eval_split_name:
45
+ value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation
46
+ eval_text_column_name:
47
+ value: text+text+text+text+text+text+text+text+text+text+text+text+transcription
48
+ cache_dir:
49
+ value: /home/sanchitgandhi/.cache
50
+ dataset_cache_dir:
51
+ value: /home/sanchitgandhi/.cache
52
+ output_dir:
53
+ value: ./
54
+ per_device_train_batch_size:
55
+ value: 64
56
+ per_device_eval_batch_size:
57
+ value: 64
58
+ dtype:
59
+ value: bfloat16
60
+ learning_rate:
61
+ value: 1e-4
62
+ lr_scheduler_type:
63
+ value: constant_with_warmup
64
+ warmup_steps:
65
+ value: 50
66
+ max_steps:
67
+ value: 10000
68
+ save_steps:
69
+ value: 10001 # don't save checkpoints during sweep
70
+ dataloader_num_workers:
71
+ value: 48
72
+ logging_steps:
73
+ value: 25
74
+ wer_threshold:
75
+ value: 10
76
+ program: run_distillation.py
77
+ project: distil-whisper-sweeps
flax/distillation_scripts/run_decoder_sweep.yaml ADDED
@@ -0,0 +1,72 @@
1
+ command:
2
+ - python3
3
+ - ${program}
4
+ - --do_train
5
+ - --do_eval
6
+ - --use_scan
7
+ - --gradient_checkpointing
8
+ - --overwrite_output_dir
9
+ - --predict_with_generate
10
+ - --freeze_encoder
11
+ - --streaming
12
+ - --use_auth_token
13
+ - ${args}
14
+ method: grid
15
+ metric:
16
+ goal: minimize
17
+ name: gigaspeech-l/validation/wer
18
+ parameters:
19
+ model_name_or_path:
20
+ values:
21
+ - distil-whisper/large-32-16
22
+ - distil-whisper/large-32-8
23
+ - distil-whisper/large-32-4
24
+ - distil-whisper/large-32-2
25
+ teacher_model_name_or_path:
26
+ value: openai/whisper-large-v2
27
+ train_dataset_name:
28
+ value: librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted
29
+ train_dataset_config_name:
30
+ value: all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3
31
+ train_split_name:
32
+ value: train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train
33
+ train_dataset_samples:
34
+ value: 2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8
35
+ eval_dataset_name:
36
+ value: librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech+chime4+google/fleurs
37
+ eval_dataset_config_name:
38
+ value: all+all+en+en+ihm+sdm+clean+release3+all+l+L+1-channel+en_us
39
+ eval_split_name:
40
+ value: validation.clean+validation.other+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation+validation
41
+ eval_text_column_name:
42
+ value: text+text+text+text+text+text+text+text+text+text+text+text+transcription
43
+ cache_dir:
44
+ value: /home/sanchitgandhi/.cache
45
+ dataset_cache_dir:
46
+ value: /home/sanchitgandhi/.cache
47
+ output_dir:
48
+ value: ./
49
+ per_device_train_batch_size:
50
+ value: 64
51
+ per_device_eval_batch_size:
52
+ value: 64
53
+ dtype:
54
+ value: bfloat16
55
+ learning_rate:
56
+ value: 1e-4
57
+ lr_scheduler_type:
58
+ value: constant_with_warmup
59
+ warmup_steps:
60
+ value: 50
61
+ max_steps:
62
+ value: 10000
63
+ save_steps:
64
+ value: 10001 # don't save checkpoints during sweep
65
+ dataloader_num_workers:
66
+ value: 48
67
+ logging_steps:
68
+ value: 25
69
+ wer_threshold:
70
+ value: 10
71
+ program: run_distillation.py
72
+ project: distil-whisper-sweeps
flax/distillation_scripts/run_distillation_12_2_timestamped.sh ADDED
@@ -0,0 +1,42 @@
+ #!/usr/bin/env bash
+
+ TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
+ --model_name_or_path "distil-whisper/small-12-2" \
+ --teacher_model_name_or_path "openai/whisper-medium.en" \
+ --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+all+all+all+release3" \
+ --train_dataset_samples "2.9+10.4+14.9+89+18.2+10.9+10.9+288+26.8+371.2+226.6+2.9+10.4+14.9+26.8" \
+ --train_dataset_name "librispeech_asr-timestamped+librispeech_asr-timestamped+librispeech_asr-timestamped+common_voice_13_0-timestamped+voxpopuli-timestamped+ami-ihm-timestamped+ami-sdm-timestamped+peoples_speech-clean-timestamped+tedlium-timestamped+switchboard-data+gigaspeech-l-timestamped+librispeech_asr-prompted+librispeech_asr-prompted+librispeech_asr-prompted+tedlium-prompted" \
+ --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train.clean.100+train.clean.360+train.other.500+train" \
+ --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
+ --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
+ --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
+ --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
+ --eval_steps 5000 \
+ --save_steps 5000 \
+ --warmup_steps 500 \
+ --learning_rate 0.0001 \
+ --logging_steps 25 \
+ --save_total_limit 1 \
+ --max_steps 80000 \
+ --wer_threshold 10 \
+ --per_device_train_batch_size 64 \
+ --per_device_eval_batch_size 64 \
+ --dtype "bfloat16" \
+ --dataloader_num_workers 16 \
+ --cache_dir "/home/sanchitgandhi/.cache" \
+ --dataset_cache_dir "/home/sanchitgandhi/.cache" \
+ --output_dir "./" \
+ --timestamp_probability 0.2 \
+ --wandb_name "small-12-2-tpu-timestamped-prob-0.2" \
+ --wandb_dir "/home/sanchitgandhi/.cache" \
+ --wandb_project "distil-whisper" \
+ --do_train \
+ --do_eval \
+ --use_scan \
+ --gradient_checkpointing \
+ --overwrite_output_dir \
+ --predict_with_generate \
+ --freeze_encoder \
+ --streaming \
+ --use_auth_token \
+ --push_to_hub
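In the launcher above, the --train_dataset_name, --train_dataset_config_name, --train_split_name and --train_dataset_samples flags are positionally aligned, '+'-separated lists: the i-th entry of each flag describes the i-th component of the training mixture. The sketch below shows how such flags could be unpacked; it is illustrative only, and the actual parsing lives in run_distillation.py and may differ.

# Hypothetical parser for the '+'-separated multi-dataset flags shown above.
def parse_dataset_flags(names: str, configs: str, splits: str, samples: str):
    names, configs = names.split("+"), configs.split("+")
    splits, samples = splits.split("+"), samples.split("+")
    if not len(names) == len(configs) == len(splits) == len(samples):
        raise ValueError("all '+'-separated flags must list the same number of datasets")
    return [
        {"name": n, "config": c, "split": s, "samples": float(x)}
        for n, c, s, x in zip(names, configs, splits, samples)
    ]

# e.g. the first component of the mixture above would be recovered as:
# {"name": "librispeech_asr-timestamped", "config": "all", "split": "train.clean.100", "samples": 2.9}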
flax/distillation_scripts/run_distillation_15s_context.sh ADDED
@@ -0,0 +1,43 @@
+ #!/usr/bin/env bash
+
+ TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
+ --model_name_or_path "distil-whisper/large-32-2-15s-context" \
+ --teacher_model_name_or_path "openai/whisper-large-v2" \
+ --feature_extractor_name "openai/whisper-large-v2" \
+ --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \
+ --train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \
+ --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \
+ --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \
+ --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
+ --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
+ --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
+ --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
+ --eval_steps 5000 \
+ --save_steps 5000 \
+ --warmup_steps 500 \
+ --learning_rate 0.0001 \
+ --lr_scheduler_type "linear" \
+ --logging_steps 25 \
+ --save_total_limit 1 \
+ --max_steps 80000 \
+ --wer_threshold 10 \
+ --per_device_train_batch_size 64 \
+ --per_device_eval_batch_size 64 \
+ --max_duration_in_seconds 15 \
+ --dataloader_num_workers 16 \
+ --cache_dir "/home/sanchitgandhi/.cache" \
+ --dataset_cache_dir "/home/sanchitgandhi/.cache" \
+ --dtype "bfloat16" \
+ --output_dir "./" \
+ --wandb_name "large-32-2-ts-28k-wer-10-context-15s" \
+ --wandb_dir "/home/sanchitgandhi/.cache" \
+ --wandb_project "distil-whisper" \
+ --do_train \
+ --do_eval \
+ --use_scan \
+ --gradient_checkpointing \
+ --overwrite_output_dir \
+ --predict_with_generate \
+ --streaming \
+ --use_auth_token \
+ --push_to_hub
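This run additionally caps --max_duration_in_seconds at 15 to match the reduced 15-second context window, so utterances longer than that would be excluded before feature extraction. A hedged sketch of such a filter follows; the helper name and the exact mechanism used inside run_distillation.py are assumptions.

# Hypothetical duration filter for a 15-second context window.
MAX_DURATION_IN_SECONDS = 15.0

def is_within_max_duration(example: dict) -> bool:
    audio = example["audio"]
    duration = len(audio["array"]) / audio["sampling_rate"]
    return duration <= MAX_DURATION_IN_SECONDS

# e.g. with a Hugging Face dataset: dataset = dataset.filter(is_within_max_duration)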
flax/distillation_scripts/run_distillation_16_2.sh ADDED
@@ -0,0 +1,41 @@
+ #!/usr/bin/env bash
+
+ TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=10000000000 python3 run_distillation.py \
+ --model_name_or_path "distil-whisper/large-16-2" \
+ --teacher_model_name_or_path "openai/whisper-large-v2" \
+ --train_dataset_config_name "all+all+all+en+en+ihm+sdm+clean+release3+all+l+L" \
+ --train_dataset_samples "100+360+500+2300+450+90+90+12000+450+3600+2500+5000" \
+ --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr+common_voice_13_0+voxpopuli+ami-ihm+ami-sdm+peoples_speech-clean+tedlium+switchboard-data+gigaspeech-l+spgispeech" \
+ --train_split_name "train.clean.100+train.clean.360+train.other.500+train+train+train+train+train+train+train+train+train" \
+ --eval_dataset_name "distil-whisper/gigaspeech-l+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset+esb/diagnostic-dataset" \
+ --eval_dataset_config_name "l+librispeech+librispeech+common_voice+common_voice+voxpopuli+voxpopuli+tedlium+tedlium+spgispeech+spgispeech+ami+ami" \
+ --eval_split_name "validation+clean+other+clean+other+clean+other+clean+other+clean+other+clean+other" \
+ --eval_text_column_name "text+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript+ortho_transcript" \
+ --eval_steps 5000 \
+ --save_steps 5000 \
+ --warmup_steps 500 \
+ --learning_rate 0.0001 \
+ --lr_scheduler_type "linear" \
+ --logging_steps 25 \
+ --save_total_limit 1 \
+ --max_steps 80000 \
+ --wer_threshold 10 \
+ --per_device_eval_batch_size 64 \
+ --per_device_train_batch_size 64 \
+ --dataloader_num_workers 16 \
+ --cache_dir "/home/sanchitgandhi/.cache" \
+ --dataset_cache_dir "/home/sanchitgandhi/.cache" \
+ --dtype "bfloat16" \
+ --output_dir "./" \
+ --wandb_name "large-16-2-ts-28k-wer-10" \
+ --wandb_dir "/home/sanchitgandhi/.cache" \
+ --wandb_project "distil-whisper" \
+ --do_train \
+ --do_eval \
+ --use_scan \
+ --gradient_checkpointing \
+ --overwrite_output_dir \
+ --predict_with_generate \
+ --streaming \
+ --use_auth_token \
+ --push_to_hub
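All three launchers pass --wer_threshold 10, i.e. a 10% word error rate cut-off for filtering the pseudo-labelled training data against the ground-truth transcriptions. The sketch below shows such a check using the `evaluate` library; the function is illustrative and not the code path taken in run_distillation.py.

# Hypothetical WER filter: keep an example only if the pseudo-label is within
# 10% WER of the ground-truth transcription.
import evaluate  # requires the `evaluate` and `jiwer` packages

wer_metric = evaluate.load("wer")
WER_THRESHOLD = 10.0

def passes_wer_filter(ground_truth: str, pseudo_label: str) -> bool:
    wer = 100 * wer_metric.compute(references=[ground_truth], predictions=[pseudo_label])
    return wer <= WER_THRESHOLD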