AndrewMcDowell commited on
Commit
f7f72aa
1 Parent(s): e39fb91

Training in progress, step 1000

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ checkpoint-*/
.ipynb_checkpoints/run_speech_recognition_ctc_bnb-checkpoint.py ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
17
+
18
+ import functools
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ import warnings
25
+ from dataclasses import dataclass, field
26
+ from typing import Dict, List, Optional, Union
27
+
28
+ import datasets
29
+ import numpy as np
30
+ import torch
31
+ from datasets import DatasetDict, load_dataset, load_metric
32
+
33
+ import bitsandbytes as bnb
34
+ import transformers
35
+ from transformers import (
36
+ AutoConfig,
37
+ AutoFeatureExtractor,
38
+ AutoModelForCTC,
39
+ AutoProcessor,
40
+ AutoTokenizer,
41
+ HfArgumentParser,
42
+ Trainer,
43
+ TrainingArguments,
44
+ Wav2Vec2Processor,
45
+ set_seed,
46
+ )
47
+ from transformers.trainer_pt_utils import get_parameter_names
48
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
49
+ from transformers.utils import check_min_version
50
+ from transformers.utils.versions import require_version
51
+
52
+
53
+
54
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
55
+ check_min_version("4.16.0.dev0")
56
+
57
+ require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
58
+
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ def list_field(default=None, metadata=None):
64
+ return field(default_factory=lambda: default, metadata=metadata)
65
+
66
+
67
+ @dataclass
68
+ class ModelArguments:
69
+ """
70
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
71
+ """
72
+
73
+ model_name_or_path: str = field(
74
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
75
+ )
76
+ tokenizer_name_or_path: Optional[str] = field(
77
+ default=None,
78
+ metadata={"help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models"},
79
+ )
80
+ cache_dir: Optional[str] = field(
81
+ default=None,
82
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
83
+ )
84
+ freeze_feature_encoder: bool = field(
85
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
86
+ )
87
+ attention_dropout: float = field(
88
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
89
+ )
90
+ activation_dropout: float = field(
91
+ default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
92
+ )
93
+ feat_proj_dropout: float = field(default=0.0, metadata={"help": "The dropout ratio for the projected features."})
94
+ hidden_dropout: float = field(
95
+ default=0.0,
96
+ metadata={
97
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
98
+ },
99
+ )
100
+ final_dropout: float = field(
101
+ default=0.0,
102
+ metadata={"help": "The dropout probability for the final projection layer."},
103
+ )
104
+ mask_time_prob: float = field(
105
+ default=0.05,
106
+ metadata={
107
+ "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
108
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
109
+ "vectors will be masked along the time axis."
110
+ },
111
+ )
112
+ mask_time_length: int = field(
113
+ default=10,
114
+ metadata={"help": "Length of vector span to mask along the time axis."},
115
+ )
116
+ mask_feature_prob: float = field(
117
+ default=0.0,
118
+ metadata={
119
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
120
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
121
+ },
122
+ )
123
+ mask_feature_length: int = field(
124
+ default=10,
125
+ metadata={"help": "Length of vector span to mask along the feature axis."},
126
+ )
127
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
128
+ ctc_loss_reduction: Optional[str] = field(
129
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
130
+ )
131
+
132
+
133
+ @dataclass
134
+ class DataTrainingArguments:
135
+ """
136
+ Arguments pertaining to what data we are going to input our model for training and eval.
137
+
138
+ Using `HfArgumentParser` we can turn this class
139
+ into argparse arguments to be able to specify them on
140
+ the command line.
141
+ """
142
+
143
+ dataset_name: str = field(
144
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
145
+ )
146
+ dataset_config_name: str = field(
147
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
148
+ )
149
+ train_split_name: str = field(
150
+ default="train+validation",
151
+ metadata={
152
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
153
+ },
154
+ )
155
+ eval_split_name: str = field(
156
+ default="test",
157
+ metadata={
158
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
159
+ },
160
+ )
161
+ audio_column_name: str = field(
162
+ default="audio",
163
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
164
+ )
165
+ text_column_name: str = field(
166
+ default="text",
167
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
168
+ )
169
+ overwrite_cache: bool = field(
170
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
171
+ )
172
+ preprocessing_num_workers: Optional[int] = field(
173
+ default=None,
174
+ metadata={"help": "The number of processes to use for the preprocessing."},
175
+ )
176
+ max_train_samples: Optional[int] = field(
177
+ default=None,
178
+ metadata={
179
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
180
+ "value if set."
181
+ },
182
+ )
183
+ max_eval_samples: Optional[int] = field(
184
+ default=None,
185
+ metadata={
186
+ "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
187
+ "value if set."
188
+ },
189
+ )
190
+ chars_to_ignore: Optional[List[str]] = list_field(
191
+ default=None,
192
+ metadata={"help": "A list of characters to remove from the transcripts."},
193
+ )
194
+ eval_metrics: List[str] = list_field(
195
+ default=["wer"],
196
+ metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
197
+ )
198
+ max_duration_in_seconds: float = field(
199
+ default=20.0,
200
+ metadata={
201
+ "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
202
+ },
203
+ )
204
+ min_duration_in_seconds: float = field(
205
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
206
+ )
207
+ preprocessing_only: bool = field(
208
+ default=False,
209
+ metadata={
210
+ "help": "Whether to only do data preprocessing and skip training. "
211
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
212
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
213
+ "so that the cached datasets can consequently be loaded in distributed training"
214
+ },
215
+ )
216
+ use_auth_token: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "If :obj:`True`, will use the token generated when running"
220
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
221
+ },
222
+ )
223
+ unk_token: str = field(
224
+ default="[UNK]",
225
+ metadata={"help": "The unk token for the tokenizer"},
226
+ )
227
+ pad_token: str = field(
228
+ default="[PAD]",
229
+ metadata={"help": "The padding token for the tokenizer"},
230
+ )
231
+ word_delimiter_token: str = field(
232
+ default="|",
233
+ metadata={"help": "The word delimiter token for the tokenizer"},
234
+ )
235
+ phoneme_language: Optional[str] = field(
236
+ default=None,
237
+ metadata={
238
+ "help": "The target language that should be used be"
239
+ " passed to the tokenizer for tokenization. Note that"
240
+ " this is only relevant if the model classifies the"
241
+ " input audio to a sequence of phoneme sequences."
242
+ },
243
+ )
244
+
245
+
246
+ @dataclass
247
+ class DataCollatorCTCWithPadding:
248
+ """
249
+ Data collator that will dynamically pad the inputs received.
250
+ Args:
251
+ processor (:class:`~transformers.AutoProcessor`)
252
+ The processor used for proccessing the data.
253
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
254
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
255
+ among:
256
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
257
+ sequence if provided).
258
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
259
+ maximum acceptable input length for the model if that argument is not provided.
260
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
261
+ different lengths).
262
+ max_length (:obj:`int`, `optional`):
263
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
264
+ max_length_labels (:obj:`int`, `optional`):
265
+ Maximum length of the ``labels`` returned list and optionally padding length (see above).
266
+ pad_to_multiple_of (:obj:`int`, `optional`):
267
+ If set will pad the sequence to a multiple of the provided value.
268
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
269
+ 7.5 (Volta).
270
+ """
271
+
272
+ processor: AutoProcessor
273
+ padding: Union[bool, str] = "longest"
274
+ pad_to_multiple_of: Optional[int] = None
275
+ pad_to_multiple_of_labels: Optional[int] = None
276
+
277
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
278
+ # split inputs and labels since they have to be of different lenghts and need
279
+ # different padding methods
280
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
281
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
282
+
283
+ batch = self.processor.pad(
284
+ input_features,
285
+ padding=self.padding,
286
+ pad_to_multiple_of=self.pad_to_multiple_of,
287
+ return_tensors="pt",
288
+ )
289
+
290
+ with self.processor.as_target_processor():
291
+ labels_batch = self.processor.pad(
292
+ label_features,
293
+ padding=self.padding,
294
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
295
+ return_tensors="pt",
296
+ )
297
+
298
+ # replace padding with -100 to ignore loss correctly
299
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
300
+
301
+ batch["labels"] = labels
302
+
303
+ return batch
304
+
305
+
306
+ def create_vocabulary_from_data(
307
+ datasets: DatasetDict,
308
+ word_delimiter_token: Optional[str] = None,
309
+ unk_token: Optional[str] = None,
310
+ pad_token: Optional[str] = None,
311
+ ):
312
+ # Given training and test labels create vocabulary
313
+ def extract_all_chars(batch):
314
+ all_text = " ".join(batch["target_text"])
315
+ vocab = list(set(all_text))
316
+ return {"vocab": [vocab], "all_text": [all_text]}
317
+
318
+ vocabs = datasets.map(
319
+ extract_all_chars,
320
+ batched=True,
321
+ batch_size=-1,
322
+ keep_in_memory=True,
323
+ remove_columns=datasets["train"].column_names,
324
+ )
325
+
326
+ # take union of all unique characters in each dataset
327
+ vocab_set = functools.reduce(
328
+ lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
329
+ )
330
+
331
+ vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
332
+
333
+ # replace white space with delimiter token
334
+ if word_delimiter_token is not None:
335
+ vocab_dict[word_delimiter_token] = vocab_dict[" "]
336
+ del vocab_dict[" "]
337
+
338
+ # add unk and pad token
339
+ if unk_token is not None:
340
+ vocab_dict[unk_token] = len(vocab_dict)
341
+
342
+ if pad_token is not None:
343
+ vocab_dict[pad_token] = len(vocab_dict)
344
+
345
+ return vocab_dict
346
+
347
+
348
+ def main():
349
+ # See all possible arguments in src/transformers/training_args.py
350
+ # or by passing the --help flag to this script.
351
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
352
+
353
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
354
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
355
+ # If we pass only one argument to the script and it's the path to a json file,
356
+ # let's parse it to get our arguments.
357
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
358
+ else:
359
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
360
+
361
+ # Detecting last checkpoint.
362
+ last_checkpoint = None
363
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
364
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
365
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
366
+ raise ValueError(
367
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
368
+ "Use --overwrite_output_dir to overcome."
369
+ )
370
+ elif last_checkpoint is not None:
371
+ logger.info(
372
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
373
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
374
+ )
375
+
376
+ # Setup logging
377
+ logging.basicConfig(
378
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
379
+ datefmt="%m/%d/%Y %H:%M:%S",
380
+ handlers=[logging.StreamHandler(sys.stdout)],
381
+ )
382
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
383
+
384
+ # Log on each process the small summary:
385
+ logger.warning(
386
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
387
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
388
+ )
389
+ # Set the verbosity to info of the Transformers logger (on main process only):
390
+ if is_main_process(training_args.local_rank):
391
+ transformers.utils.logging.set_verbosity_info()
392
+ logger.info("Training/evaluation parameters %s", training_args)
393
+
394
+ # Set seed before initializing model.
395
+ set_seed(training_args.seed)
396
+
397
+ # 1. First, let's load the dataset
398
+ raw_datasets = DatasetDict()
399
+
400
+ if training_args.do_train:
401
+ raw_datasets["train"] = load_dataset(
402
+ data_args.dataset_name,
403
+ data_args.dataset_config_name,
404
+ split=data_args.train_split_name,
405
+ use_auth_token=data_args.use_auth_token,
406
+ )
407
+
408
+ if data_args.audio_column_name not in raw_datasets["train"].column_names:
409
+ raise ValueError(
410
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
411
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
412
+ f"{', '.join(raw_datasets['train'].column_names)}."
413
+ )
414
+
415
+ if data_args.text_column_name not in raw_datasets["train"].column_names:
416
+ raise ValueError(
417
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
418
+ "Make sure to set `--text_column_name` to the correct text column - one of "
419
+ f"{', '.join(raw_datasets['train'].column_names)}."
420
+ )
421
+
422
+ if data_args.max_train_samples is not None:
423
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
424
+
425
+ if training_args.do_eval:
426
+ raw_datasets["eval"] = load_dataset(
427
+ data_args.dataset_name,
428
+ data_args.dataset_config_name,
429
+ split=data_args.eval_split_name,
430
+ use_auth_token=data_args.use_auth_token,
431
+ )
432
+
433
+ if data_args.max_eval_samples is not None:
434
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
435
+
436
+ # 2. We remove some special characters from the datasets
437
+ # that make training complicated and do not help in transcribing the speech
438
+ # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
439
+ # that could be easily picked up by the model
440
+ from pykakasi import kakasi
441
+
442
+ kakasi = kakasi()
443
+ kakasi.setMode('J', 'H') #Convert from kanji to hiragana
444
+ # kakasi.setMode("K", "H") #Convert from katakana to hiragana
445
+ conv = kakasi.getConverter()
446
+
447
+ chars_to_ignore_regex = (
448
+ f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else '[\,\?\!\-\;\:\"\“\%\‘\”\�\—\’\…\–\(\,\[\]\)\(\!]'
449
+ )
450
+ text_column_name = data_args.text_column_name
451
+
452
+
453
+
454
+ def remove_special_characters(batch):
455
+ if chars_to_ignore_regex is not None:
456
+ batch["target_text"] = conv.do(re.sub(chars_to_ignore_regex, "", batch[text_column_name])) + " "
457
+ else:
458
+ batch["target_text"] = batch[text_column_name].lower() + " "
459
+ return batch
460
+
461
+ with training_args.main_process_first(desc="dataset map special characters removal"):
462
+ raw_datasets = raw_datasets.map(
463
+ remove_special_characters,
464
+ remove_columns=[text_column_name],
465
+ desc="remove special characters from datasets",
466
+ )
467
+
468
+ # save special tokens for tokenizer
469
+ word_delimiter_token = data_args.word_delimiter_token
470
+ unk_token = data_args.unk_token
471
+ pad_token = data_args.pad_token
472
+
473
+ # 3. Next, let's load the config as we might need it to create
474
+ # the tokenizer
475
+ # load config
476
+ config = AutoConfig.from_pretrained(
477
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
478
+ )
479
+
480
+ # 4. Next, if no tokenizer file is defined,
481
+ # we create the vocabulary of the model by extracting all unique characters from
482
+ # the training and evaluation datasets
483
+ # We need to make sure that only first rank saves vocabulary
484
+ # make sure all processes wait until vocab is created
485
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
486
+ tokenizer_kwargs = {}
487
+ if tokenizer_name_or_path is None:
488
+ # save vocab in training output dir
489
+ tokenizer_name_or_path = training_args.output_dir
490
+
491
+ vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")
492
+
493
+ with training_args.main_process_first():
494
+ if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
495
+ os.remove(vocab_file)
496
+
497
+ with training_args.main_process_first(desc="dataset map vocabulary creation"):
498
+ if not os.path.isfile(vocab_file):
499
+ os.makedirs(tokenizer_name_or_path, exist_ok=True)
500
+ vocab_dict = create_vocabulary_from_data(
501
+ raw_datasets,
502
+ word_delimiter_token=word_delimiter_token,
503
+ unk_token=unk_token,
504
+ pad_token=pad_token,
505
+ )
506
+
507
+ # save vocab dict to be loaded into tokenizer
508
+ with open(vocab_file, "w") as file:
509
+ json.dump(vocab_dict, file)
510
+
511
+ # if tokenizer has just been created
512
+ # it is defined by `tokenizer_class` if present in config else by `model_type`
513
+ tokenizer_kwargs = {
514
+ "config": config if config.tokenizer_class is not None else None,
515
+ "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
516
+ "unk_token": unk_token,
517
+ "pad_token": pad_token,
518
+ "word_delimiter_token": word_delimiter_token,
519
+ }
520
+
521
+ # 5. Now we can instantiate the feature extractor, tokenizer and model
522
+ # Note for distributed training, the .from_pretrained methods guarantee that only
523
+ # one local process can concurrently download model & vocab.
524
+
525
+ # load feature_extractor and tokenizer
526
+ tokenizer = AutoTokenizer.from_pretrained(
527
+ tokenizer_name_or_path,
528
+ use_auth_token=data_args.use_auth_token,
529
+ **tokenizer_kwargs,
530
+ )
531
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
532
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
533
+ )
534
+
535
+ # adapt config
536
+ config.update(
537
+ {
538
+ "feat_proj_dropout": model_args.feat_proj_dropout,
539
+ "attention_dropout": model_args.attention_dropout,
540
+ "hidden_dropout": model_args.hidden_dropout,
541
+ "final_dropout": model_args.final_dropout,
542
+ "mask_time_prob": model_args.mask_time_prob,
543
+ "mask_time_length": model_args.mask_time_length,
544
+ "mask_feature_prob": model_args.mask_feature_prob,
545
+ "mask_feature_length": model_args.mask_feature_length,
546
+ "gradient_checkpointing": training_args.gradient_checkpointing,
547
+ "layerdrop": model_args.layerdrop,
548
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
549
+ "pad_token_id": tokenizer.pad_token_id,
550
+ "vocab_size": len(tokenizer),
551
+ "activation_dropout": model_args.activation_dropout,
552
+ }
553
+ )
554
+
555
+ # create model
556
+ model = AutoModelForCTC.from_pretrained(
557
+ model_args.model_name_or_path,
558
+ cache_dir=model_args.cache_dir,
559
+ config=config,
560
+ use_auth_token=data_args.use_auth_token,
561
+ )
562
+
563
+ # freeze encoder
564
+ if model_args.freeze_feature_encoder:
565
+ model.freeze_feature_encoder()
566
+
567
+ # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
568
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
569
+ # so that we just need to set the correct target sampling rate and normalize the input
570
+ # via the `feature_extractor`
571
+
572
+ # make sure that dataset decodes audio with correct sampling rate
573
+ dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
574
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
575
+ raw_datasets = raw_datasets.cast_column(
576
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
577
+ )
578
+
579
+ # derive max & min input length for sample rate & max duration
580
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
581
+ min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
582
+ audio_column_name = data_args.audio_column_name
583
+ num_workers = data_args.preprocessing_num_workers
584
+
585
+ # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
586
+ phoneme_language = data_args.phoneme_language
587
+
588
+ # Preprocessing the datasets.
589
+ # We need to read the audio files as arrays and tokenize the targets.
590
+ def prepare_dataset(batch):
591
+ # load audio
592
+ sample = batch[audio_column_name]
593
+
594
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
595
+ batch["input_values"] = inputs.input_values[0]
596
+ batch["input_length"] = len(batch["input_values"])
597
+
598
+ # encode targets
599
+ additional_kwargs = {}
600
+ if phoneme_language is not None:
601
+ additional_kwargs["phonemizer_lang"] = phoneme_language
602
+
603
+ batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
604
+ return batch
605
+
606
+ with training_args.main_process_first(desc="dataset map preprocessing"):
607
+ vectorized_datasets = raw_datasets.map(
608
+ prepare_dataset,
609
+ remove_columns=next(iter(raw_datasets.values())).column_names,
610
+ num_proc=num_workers,
611
+ desc="preprocess datasets",
612
+ )
613
+
614
+ def is_audio_in_length_range(length):
615
+ return length > min_input_length and length < max_input_length
616
+
617
+ # filter data that is shorter than min_input_length
618
+ vectorized_datasets = vectorized_datasets.filter(
619
+ is_audio_in_length_range,
620
+ num_proc=num_workers,
621
+ input_columns=["input_length"],
622
+ )
623
+
624
+ # 7. Next, we can prepare the training.
625
+ # Let's use word error rate (WER) as our evaluation metric,
626
+ # instantiate a data collator and the trainer
627
+
628
+ # Define evaluation metrics during training, *i.e.* word error rate, character error rate
629
+ eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
630
+
631
+ # for large datasets it is advised to run the preprocessing on a
632
+ # single machine first with ``args.preprocessing_only`` since there will mostly likely
633
+ # be a timeout when running the script in distributed mode.
634
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
635
+ # cached dataset
636
+ if data_args.preprocessing_only:
637
+ logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
638
+ return
639
+
640
+ def compute_metrics(pred):
641
+ pred_logits = pred.predictions
642
+ pred_ids = np.argmax(pred_logits, axis=-1)
643
+
644
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
645
+
646
+ pred_str = tokenizer.batch_decode(pred_ids)
647
+ # we do not want to group tokens when computing the metrics
648
+ label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
649
+
650
+ metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
651
+
652
+ return metrics
653
+
654
+ # Now save everything to be able to create a single processor later
655
+ if is_main_process(training_args.local_rank):
656
+ # save feature extractor, tokenizer and config
657
+ feature_extractor.save_pretrained(training_args.output_dir)
658
+ tokenizer.save_pretrained(training_args.output_dir)
659
+ config.save_pretrained(training_args.output_dir)
660
+
661
+ try:
662
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
663
+ except (OSError, KeyError):
664
+ warnings.warn(
665
+ "Loading a processor from a feature extractor config that does not"
666
+ " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
667
+ " attribute to your `preprocessor_config.json` file to suppress this warning: "
668
+ " `'processor_class': 'Wav2Vec2Processor'`",
669
+ FutureWarning,
670
+ )
671
+ processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
672
+
673
+ # Instantiate custom data collator
674
+ data_collator = DataCollatorCTCWithPadding(processor=processor)
675
+
676
+ decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
677
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
678
+ optimizer_grouped_parameters = [
679
+ {
680
+ "params": [p for n, p in model.named_parameters() if n in decay_parameters],
681
+ "weight_decay": training_args.weight_decay,
682
+ },
683
+ {
684
+ "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
685
+ "weight_decay": 0.0,
686
+ },
687
+ ]
688
+ optimizer = bnb.optim.Adam8bit(
689
+ params=optimizer_grouped_parameters,
690
+ lr=training_args.learning_rate,
691
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
692
+ eps=training_args.adam_epsilon,
693
+ )
694
+
695
+ optimizers = (optimizer, None)
696
+
697
+ # Initialize Trainer
698
+ trainer = Trainer(
699
+ model=model,
700
+ data_collator=data_collator,
701
+ args=training_args,
702
+ compute_metrics=compute_metrics,
703
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
704
+ eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
705
+ tokenizer=feature_extractor,
706
+ optimizers=optimizers,
707
+ )
708
+
709
+ # 8. Finally, we can start training
710
+
711
+ # Training
712
+ if training_args.do_train:
713
+
714
+ # use last checkpoint if exist
715
+ if last_checkpoint is not None:
716
+ checkpoint = last_checkpoint
717
+ elif os.path.isdir(model_args.model_name_or_path):
718
+ checkpoint = model_args.model_name_or_path
719
+ else:
720
+ checkpoint = None
721
+
722
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
723
+ trainer.save_model()
724
+
725
+ metrics = train_result.metrics
726
+ max_train_samples = (
727
+ data_args.max_train_samples
728
+ if data_args.max_train_samples is not None
729
+ else len(vectorized_datasets["train"])
730
+ )
731
+ metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
732
+
733
+ trainer.log_metrics("train", metrics)
734
+ trainer.save_metrics("train", metrics)
735
+ trainer.save_state()
736
+
737
+ # Evaluation
738
+ results = {}
739
+ if training_args.do_eval:
740
+ logger.info("*** Evaluate ***")
741
+ metrics = trainer.evaluate()
742
+ max_eval_samples = (
743
+ data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
744
+ )
745
+ metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
746
+
747
+ trainer.log_metrics("eval", metrics)
748
+ trainer.save_metrics("eval", metrics)
749
+
750
+ # Write model card and (optionally) push to hub
751
+ config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
752
+ kwargs = {
753
+ "finetuned_from": model_args.model_name_or_path,
754
+ "tasks": "speech-recognition",
755
+ "tags": ["automatic-speech-recognition", data_args.dataset_name],
756
+ "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
757
+ "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
758
+ }
759
+ if "common_voice" in data_args.dataset_name:
760
+ kwargs["language"] = config_name
761
+
762
+ if training_args.push_to_hub:
763
+ trainer.push_to_hub(**kwargs)
764
+ else:
765
+ trainer.create_model_card(**kwargs)
766
+
767
+ return results
768
+
769
+
770
+ if __name__ == "__main__":
771
+ main()
.ipynb_checkpoints/run_training-checkpoint.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python run_speech_recognition_ctc_bnb.py \
2
+ --dataset_name="mozilla-foundation/common_voice_8_0" \
3
+ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
4
+ --dataset_config_name="ja" \
5
+ --output_dir="./" \
6
+ --overwrite_output_dir \
7
+ --num_train_epochs="10" \
8
+ --per_device_train_batch_size="48" \
9
+ --per_device_eval_batch_size="8" \
10
+ --learning_rate="7.5e-5" \
11
+ --warmup_steps="2000" \
12
+ --length_column_name="input_length" \
13
+ --evaluation_strategy="steps" \
14
+ --text_column_name="sentence" \
15
+ --save_steps="1000" \
16
+ --eval_steps="1000" \
17
+ --logging_steps="100" \
18
+ --layerdrop="0.0" \
19
+ --activation_dropout="0.1" \
20
+ --save_total_limit="4" \
21
+ --freeze_feature_encoder \
22
+ --feat_proj_dropout="0.0" \
23
+ --mask_time_prob="0.75" \
24
+ --mask_time_length="10" \
25
+ --mask_feature_prob="0.25" \
26
+ --mask_feature_length="64" \
27
+ --gradient_checkpointing \
28
+ --use_auth_token \
29
+ --fp16 \
30
+ --group_by_length \
31
+ --do_train --do_eval \
32
+ --push_to_hub
.ipynb_checkpoints/speech_training_notebook-checkpoint.ipynb ADDED
@@ -0,0 +1,1490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "application/vnd.jupyter.widget-view+json": {
11
+ "model_id": "b7523cd66cf343f98fd3006be918a3b6",
12
+ "version_major": 2,
13
+ "version_minor": 0
14
+ },
15
+ "text/plain": [
16
+ "Downloading: 0%| | 0.00/10.1k [00:00<?, ?B/s]"
17
+ ]
18
+ },
19
+ "metadata": {},
20
+ "output_type": "display_data"
21
+ },
22
+ {
23
+ "data": {
24
+ "application/vnd.jupyter.widget-view+json": {
25
+ "model_id": "251cac7b8968405eafd54e2d29165b40",
26
+ "version_major": 2,
27
+ "version_minor": 0
28
+ },
29
+ "text/plain": [
30
+ "Downloading: 0%| | 0.00/2.98k [00:00<?, ?B/s]"
31
+ ]
32
+ },
33
+ "metadata": {},
34
+ "output_type": "display_data"
35
+ },
36
+ {
37
+ "data": {
38
+ "application/vnd.jupyter.widget-view+json": {
39
+ "model_id": "528c6a67efea4512b04b06a32156d5b7",
40
+ "version_major": 2,
41
+ "version_minor": 0
42
+ },
43
+ "text/plain": [
44
+ "Downloading: 0%| | 0.00/53.1k [00:00<?, ?B/s]"
45
+ ]
46
+ },
47
+ "metadata": {},
48
+ "output_type": "display_data"
49
+ },
50
+ {
51
+ "name": "stdout",
52
+ "output_type": "stream",
53
+ "text": [
54
+ "Downloading and preparing dataset common_voice/ja to /workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/ja/8.0.0/b8bc4d453193c06a43269b46cd87f075c70f152ac963b7f28f7a2760c45ec3e8...\n"
55
+ ]
56
+ },
57
+ {
58
+ "data": {
59
+ "application/vnd.jupyter.widget-view+json": {
60
+ "model_id": "6c21c5f782734b3bb3f545cef5b59ee0",
61
+ "version_major": 2,
62
+ "version_minor": 0
63
+ },
64
+ "text/plain": [
65
+ "Downloading: 0%| | 0.00/958M [00:00<?, ?B/s]"
66
+ ]
67
+ },
68
+ "metadata": {},
69
+ "output_type": "display_data"
70
+ },
71
+ {
72
+ "data": {
73
+ "application/vnd.jupyter.widget-view+json": {
74
+ "model_id": "",
75
+ "version_major": 2,
76
+ "version_minor": 0
77
+ },
78
+ "text/plain": [
79
+ "0 examples [00:00, ? examples/s]"
80
+ ]
81
+ },
82
+ "metadata": {},
83
+ "output_type": "display_data"
84
+ },
85
+ {
86
+ "data": {
87
+ "application/vnd.jupyter.widget-view+json": {
88
+ "model_id": "",
89
+ "version_major": 2,
90
+ "version_minor": 0
91
+ },
92
+ "text/plain": [
93
+ "0 examples [00:00, ? examples/s]"
94
+ ]
95
+ },
96
+ "metadata": {},
97
+ "output_type": "display_data"
98
+ },
99
+ {
100
+ "data": {
101
+ "application/vnd.jupyter.widget-view+json": {
102
+ "model_id": "",
103
+ "version_major": 2,
104
+ "version_minor": 0
105
+ },
106
+ "text/plain": [
107
+ "0 examples [00:00, ? examples/s]"
108
+ ]
109
+ },
110
+ "metadata": {},
111
+ "output_type": "display_data"
112
+ },
113
+ {
114
+ "data": {
115
+ "application/vnd.jupyter.widget-view+json": {
116
+ "model_id": "",
117
+ "version_major": 2,
118
+ "version_minor": 0
119
+ },
120
+ "text/plain": [
121
+ "0 examples [00:00, ? examples/s]"
122
+ ]
123
+ },
124
+ "metadata": {},
125
+ "output_type": "display_data"
126
+ },
127
+ {
128
+ "data": {
129
+ "application/vnd.jupyter.widget-view+json": {
130
+ "model_id": "",
131
+ "version_major": 2,
132
+ "version_minor": 0
133
+ },
134
+ "text/plain": [
135
+ "0 examples [00:00, ? examples/s]"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "name": "stdout",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "Dataset common_voice downloaded and prepared to /workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/ja/8.0.0/b8bc4d453193c06a43269b46cd87f075c70f152ac963b7f28f7a2760c45ec3e8. Subsequent calls will reuse this data.\n"
146
+ ]
147
+ },
148
+ {
149
+ "name": "stderr",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/ja/8.0.0/b8bc4d453193c06a43269b46cd87f075c70f152ac963b7f28f7a2760c45ec3e8)\n"
153
+ ]
154
+ },
155
+ {
156
+ "name": "stdout",
157
+ "output_type": "stream",
158
+ "text": [
159
+ "10623\n"
160
+ ]
161
+ }
162
+ ],
163
+ "source": [
164
+ "from datasets import Audio, Dataset, load_dataset, load_metric\n",
165
+ "from transformers import AutoFeatureExtractor, pipeline\n",
166
+ "\n",
167
+ "language_code = \"ja\"\n",
168
+ "dataset_name = \"mozilla-foundation/common_voice_8_0\"\n",
169
+ "\n",
170
+ "common_voice_train = load_dataset(dataset_name, language_code, use_auth_token=True, split=\"train+validation\")\n",
171
+ "common_voice_test = load_dataset(dataset_name, language_code, use_auth_token=True, split=\"test\")\n",
172
+ "\n",
173
+ "\n",
174
+ "print(len(common_voice_train))\n",
175
+ "\n",
176
+ "# # for testing: only process the first two examples as a test\n",
177
+ "# dataset = dataset.select(range(10))\n",
178
+ "\n",
179
+ "\n"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": 1,
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stdout",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "Collecting pykakasi\n",
192
+ " Downloading pykakasi-2.2.1-py3-none-any.whl (2.4 MB)\n",
193
+ " |████████████████████████████████| 2.4 MB 9.9 MB/s \n",
194
+ "\u001b[?25hCollecting jaconv\n",
195
+ " Downloading jaconv-0.3.tar.gz (15 kB)\n",
196
+ " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
197
+ "\u001b[?25hCollecting deprecated\n",
198
+ " Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)\n",
199
+ "Collecting wrapt<2,>=1.10\n",
200
+ " Downloading wrapt-1.13.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (84 kB)\n",
201
+ " |████████████████████████████████| 84 kB 12.8 MB/s \n",
202
+ "\u001b[?25hBuilding wheels for collected packages: jaconv\n",
203
+ " Building wheel for jaconv (setup.py) ... \u001b[?25ldone\n",
204
+ "\u001b[?25h Created wheel for jaconv: filename=jaconv-0.3-py3-none-any.whl size=15553 sha256=fd764f215e4d567cb60062a7052497b66729e9e2190e2e00153e0d19734088e7\n",
205
+ " Stored in directory: /workspace/.cache/pip/wheels/73/e8/fb/b4ad8117719f79ac73bc05406d1768f845688cdbeed7aad87e\n",
206
+ "Successfully built jaconv\n",
207
+ "Installing collected packages: wrapt, jaconv, deprecated, pykakasi\n",
208
+ "Successfully installed deprecated-1.2.13 jaconv-0.3 pykakasi-2.2.1 wrapt-1.13.3\n",
209
+ "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.2 is available.\n",
210
+ "You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "!pip install pykakasi"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": 4,
221
+ "metadata": {},
222
+ "outputs": [
223
+ {
224
+ "name": "stdout",
225
+ "output_type": "stream",
226
+ "text": [
227
+ "にんじゃ ひらがな kana\n"
228
+ ]
229
+ },
230
+ {
231
+ "name": "stderr",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "/tmp/ipykernel_2159/3076271513.py:4: DeprecationWarning: Call to deprecated method setMode. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
235
+ " kakasi.setMode('J', 'H') #Convert from kanji to hiragana\n",
236
+ "/tmp/ipykernel_2159/3076271513.py:6: DeprecationWarning: Call to deprecated method getConverter. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
237
+ " conv = kakasi.getConverter()\n",
238
+ "/tmp/ipykernel_2159/3076271513.py:10: DeprecationWarning: Call to deprecated method do. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
239
+ " print(conv.do(str))\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "from pykakasi import kakasi\n",
245
+ "\n",
246
+ "kakasi = kakasi()\n",
247
+ "kakasi.setMode('J', 'H') #Convert from kanji to hiragana\n",
248
+ "# kakasi.setMode(\"K\", \"H\") #Convert from katakana to hiragana\n",
249
+ "conv = kakasi.getConverter()\n",
250
+ "\n",
251
+ "str = 'にんじゃ 平仮名 kana'\n",
252
+ "\n",
253
+ "print(conv.do(str))"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": 3,
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "repo_name = 'https://huggingface.co/AndrewMcDowell/wav2vec2-xls-r-1B-german'\n"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 4,
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "common_voice_train = common_voice_train.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
272
+ "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
273
+ "\n"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": 15,
279
+ "metadata": {},
280
+ "outputs": [
281
+ {
282
+ "data": {
283
+ "application/vnd.jupyter.widget-view+json": {
284
+ "model_id": "ad26c4d7d02948a3bc30d86a0f3527c8",
285
+ "version_major": 2,
286
+ "version_minor": 0
287
+ },
288
+ "text/plain": [
289
+ "0ex [00:00, ?ex/s]"
290
+ ]
291
+ },
292
+ "metadata": {},
293
+ "output_type": "display_data"
294
+ },
295
+ {
296
+ "name": "stderr",
297
+ "output_type": "stream",
298
+ "text": [
299
+ "/tmp/ipykernel_2159/322450745.py:5: DeprecationWarning: Call to deprecated method do. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
300
+ " batch[\"sentence\"] = conv.do(re.sub(chars_to_remove_regex, '', batch[\"sentence\"]))\n"
301
+ ]
302
+ },
303
+ {
304
+ "data": {
305
+ "application/vnd.jupyter.widget-view+json": {
306
+ "model_id": "93295f1cd50f4557a96ff1bf139c9a37",
307
+ "version_major": 2,
308
+ "version_minor": 0
309
+ },
310
+ "text/plain": [
311
+ "0ex [00:00, ?ex/s]"
312
+ ]
313
+ },
314
+ "metadata": {},
315
+ "output_type": "display_data"
316
+ }
317
+ ],
318
+ "source": [
319
+ "import re\n",
320
+ "chars_to_remove_regex = '[\\,\\?\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�\\—\\’\\…\\–\\(\\,\\[\\]\\)\\(\\!]'\n",
321
+ "# \\.\n",
322
+ "def remove_special_characters(batch):\n",
323
+ " batch[\"sentence\"] = conv.do(re.sub(chars_to_remove_regex, '', batch[\"sentence\"]))\n",
324
+ " return batch\n",
325
+ "\n",
326
+ "common_voice_train = common_voice_train.map(remove_special_characters)\n",
327
+ "common_voice_test = common_voice_test.map(remove_special_characters)"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 6,
333
+ "metadata": {},
334
+ "outputs": [
335
+ {
336
+ "name": "stdout",
337
+ "output_type": "stream",
338
+ "text": [
339
+ "Collecting num2words\n",
340
+ " Downloading num2words-0.5.10-py3-none-any.whl (101 kB)\n",
341
+ " |████████████████████████████████| 101 kB 7.9 MB/s \n",
342
+ "\u001b[?25hCollecting docopt>=0.6.2\n",
343
+ " Downloading docopt-0.6.2.tar.gz (25 kB)\n",
344
+ " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
345
+ "\u001b[?25hBuilding wheels for collected packages: docopt\n",
346
+ " Building wheel for docopt (setup.py) ... \u001b[?25ldone\n",
347
+ "\u001b[?25h Created wheel for docopt: filename=docopt-0.6.2-py2.py3-none-any.whl size=13704 sha256=7cda85e4b3980668714aad8f5d706fb5b189c2804ce1d99ca2380537fdc78031\n",
348
+ " Stored in directory: /workspace/.cache/pip/wheels/56/ea/58/ead137b087d9e326852a851351d1debf4ada529b6ac0ec4e8c\n",
349
+ "Successfully built docopt\n",
350
+ "Installing collected packages: docopt, num2words\n",
351
+ "Successfully installed docopt-0.6.2 num2words-0.5.10\n",
352
+ "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.2 is available.\n",
353
+ "You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
354
+ ]
355
+ }
356
+ ],
357
+ "source": [
358
+ "!pip install num2words"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": 7,
364
+ "metadata": {},
365
+ "outputs": [
366
+ {
367
+ "data": {
368
+ "application/vnd.jupyter.widget-view+json": {
369
+ "model_id": "0da8fd9cdae64c1fa80fbcfc412bcf9c",
370
+ "version_major": 2,
371
+ "version_minor": 0
372
+ },
373
+ "text/plain": [
374
+ "0ex [00:00, ?ex/s]"
375
+ ]
376
+ },
377
+ "metadata": {},
378
+ "output_type": "display_data"
379
+ }
380
+ ],
381
+ "source": [
382
+ "\n",
383
+ "from num2words import num2words\n",
384
+ "import regex as re\n",
385
+ "matches = []\n",
386
+ "\n",
387
+ "def replace_numbers(match):\n",
388
+ " match = match.group()\n",
389
+ " matches.append(match)\n",
390
+ " return num2words(match, lang='de')\n",
391
+ "\n",
392
+ "def replace_numbers_in_batch(batch):\n",
393
+ " batch[\"sentence\"] = re.sub(r'\\d+(?:,\\d+)?', replace_numbers, batch[\"sentence\"])\n",
394
+ " return batch\n",
395
+ "\n",
396
+ "common_voice_test_2 = common_voice_test.map(replace_numbers_in_batch)"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": 10,
402
+ "metadata": {},
403
+ "outputs": [
404
+ {
405
+ "data": {
406
+ "application/vnd.jupyter.widget-view+json": {
407
+ "model_id": "54d62ea7a0214b6abc5de1f106b330dc",
408
+ "version_major": 2,
409
+ "version_minor": 0
410
+ },
411
+ "text/plain": [
412
+ "0ex [00:00, ?ex/s]"
413
+ ]
414
+ },
415
+ "metadata": {},
416
+ "output_type": "display_data"
417
+ }
418
+ ],
419
+ "source": [
420
+ "common_voice_train_2 = common_voice_train.map(replace_numbers_in_batch)"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": 11,
426
+ "metadata": {},
427
+ "outputs": [
428
+ {
429
+ "data": {
430
+ "text/plain": [
431
+ "0"
432
+ ]
433
+ },
434
+ "execution_count": 11,
435
+ "metadata": {},
436
+ "output_type": "execute_result"
437
+ }
438
+ ],
439
+ "source": [
440
+ "len(matches)"
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "code",
445
+ "execution_count": null,
446
+ "metadata": {},
447
+ "outputs": [],
448
+ "source": [
449
+ "# def replace_accented_characters(batch):\n",
450
+ "# accented_string = u'Málaga'\n",
451
+ "# # accented_string is of type 'unicode'\n",
452
+ "# import unidecode\n",
453
+ "# unaccented_string = unidecode.unidecode(accented_string)\n",
454
+ "# batch[\"sentence\"] = re.sub('[â]', 'a', batch[\"sentence\"])\n",
455
+ "# batch[\"sentence\"] = re.sub('[î]', 'i', batch[\"sentence\"])\n",
456
+ "# batch[\"sentence\"] = re.sub('[ô]', 'o', batch[\"sentence\"])\n",
457
+ "# batch[\"sentence\"] = re.sub('[û]', 'u', batch[\"sentence\"])\n",
458
+ "# return batch\n",
459
+ "\n",
460
+ "def strip_accents(batch):\n",
461
+ " return ''.join(c for c in unicodedata.normalize('NFD', batch[\"sentence\"]) if unicodedata.category(c) != 'Mn')\n",
462
+ "\n",
463
+ "common_voice_train = common_voice_train.map(replace_accented_characters)\n",
464
+ "common_voice_test = common_voice_test.map(replace_accented_characters)"
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "execution_count": null,
470
+ "metadata": {},
471
+ "outputs": [],
472
+ "source": []
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": 6,
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "def extract_all_chars(batch):\n",
481
+ " all_text = \" \".join(batch[\"sentence\"])\n",
482
+ " vocab = list(set(all_text))\n",
483
+ " return {\"vocab\": [vocab], \"all_text\": [all_text]}"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": null,
489
+ "metadata": {},
490
+ "outputs": [],
491
+ "source": []
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": 16,
496
+ "metadata": {},
497
+ "outputs": [
498
+ {
499
+ "data": {
500
+ "application/vnd.jupyter.widget-view+json": {
501
+ "model_id": "c40f4d6b6bb74a56b2c570a3a53d7f4b",
502
+ "version_major": 2,
503
+ "version_minor": 0
504
+ },
505
+ "text/plain": [
506
+ " 0%| | 0/1 [00:00<?, ?ba/s]"
507
+ ]
508
+ },
509
+ "metadata": {},
510
+ "output_type": "display_data"
511
+ },
512
+ {
513
+ "data": {
514
+ "application/vnd.jupyter.widget-view+json": {
515
+ "model_id": "f69b6a3c0b54477ea15c56b02464bacd",
516
+ "version_major": 2,
517
+ "version_minor": 0
518
+ },
519
+ "text/plain": [
520
+ " 0%| | 0/1 [00:00<?, ?ba/s]"
521
+ ]
522
+ },
523
+ "metadata": {},
524
+ "output_type": "display_data"
525
+ }
526
+ ],
527
+ "source": [
528
+ "vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)\n",
529
+ "vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": 17,
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": [
538
+ "vocab_list = list(set(vocab_train[\"vocab\"][0]) | set(vocab_test[\"vocab\"][0]))"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "code",
543
+ "execution_count": 19,
544
+ "metadata": {
545
+ "collapsed": true,
546
+ "jupyter": {
547
+ "outputs_hidden": true
548
+ }
549
+ },
550
+ "outputs": [
551
+ {
552
+ "data": {
553
+ "text/plain": [
554
+ "[['ん',\n",
555
+ " 'ン',\n",
556
+ " 'ダ',\n",
557
+ " 'S',\n",
558
+ " 'う',\n",
559
+ " 'た',\n",
560
+ " 'ぽ',\n",
561
+ " 'P',\n",
562
+ " ':',\n",
563
+ " '々',\n",
564
+ " 'か',\n",
565
+ " 'ぞ',\n",
566
+ " 'よ',\n",
567
+ " 'や',\n",
568
+ " 'ヨ',\n",
569
+ " 'ゃ',\n",
570
+ " 'Q',\n",
571
+ " 'N',\n",
572
+ " 'だ',\n",
573
+ " 'を',\n",
574
+ " 'L',\n",
575
+ " 'h',\n",
576
+ " 'F',\n",
577
+ " 'E',\n",
578
+ " 'ピ',\n",
579
+ " 'ち',\n",
580
+ " 'ボ',\n",
581
+ " 'w',\n",
582
+ " 'リ',\n",
583
+ " 'ゲ',\n",
584
+ " 'フ',\n",
585
+ " 'あ',\n",
586
+ " 'ウ',\n",
587
+ " 'め',\n",
588
+ " 'タ',\n",
589
+ " 'ぬ',\n",
590
+ " 'せ',\n",
591
+ " 'て',\n",
592
+ " 'b',\n",
593
+ " '」',\n",
594
+ " 'す',\n",
595
+ " 'び',\n",
596
+ " 'ば',\n",
597
+ " 'ア',\n",
598
+ " 'A',\n",
599
+ " 'r',\n",
600
+ " 'ャ',\n",
601
+ " 'イ',\n",
602
+ " 'へ',\n",
603
+ " 'ぶ',\n",
604
+ " 'は',\n",
605
+ " 'u',\n",
606
+ " 'と',\n",
607
+ " '繫',\n",
608
+ " 'ぎ',\n",
609
+ " 'バ',\n",
610
+ " 'ノ',\n",
611
+ " 'I',\n",
612
+ " 'ざ',\n",
613
+ " 'R',\n",
614
+ " 'チ',\n",
615
+ " 'A',\n",
616
+ " '「',\n",
617
+ " 'G',\n",
618
+ " 'ェ',\n",
619
+ " 'く',\n",
620
+ " 'け',\n",
621
+ " 'ぇ',\n",
622
+ " '?',\n",
623
+ " '〜',\n",
624
+ " 'つ',\n",
625
+ " 'わ',\n",
626
+ " 'こ',\n",
627
+ " 'ス',\n",
628
+ " 'ズ',\n",
629
+ " 'p',\n",
630
+ " 'y',\n",
631
+ " 'ぼ',\n",
632
+ " 'し',\n",
633
+ " '、',\n",
634
+ " '!',\n",
635
+ " 'ゼ',\n",
636
+ " 's',\n",
637
+ " 'U',\n",
638
+ " 'き',\n",
639
+ " 'ゥ',\n",
640
+ " '・',\n",
641
+ " 'が',\n",
642
+ " 'も',\n",
643
+ " 'エ',\n",
644
+ " 'ク',\n",
645
+ " 'づ',\n",
646
+ " 'O',\n",
647
+ " 'グ',\n",
648
+ " 'ブ',\n",
649
+ " 'ゅ',\n",
650
+ " 'ィ',\n",
651
+ " 'ぁ',\n",
652
+ " 'd',\n",
653
+ " 't',\n",
654
+ " 'j',\n",
655
+ " 'n',\n",
656
+ " 'ロ',\n",
657
+ " 'g',\n",
658
+ " 'ー',\n",
659
+ " '/',\n",
660
+ " 'ナ',\n",
661
+ " 'ヅ',\n",
662
+ " 'の',\n",
663
+ " 'ケ',\n",
664
+ " 'ほ',\n",
665
+ " '・',\n",
666
+ " ')',\n",
667
+ " 'J',\n",
668
+ " 'D',\n",
669
+ " 'ネ',\n",
670
+ " 'お',\n",
671
+ " 'パ',\n",
672
+ " 'ム',\n",
673
+ " 'む',\n",
674
+ " 'ラ',\n",
675
+ " 'ミ',\n",
676
+ " 'い',\n",
677
+ " 'ろ',\n",
678
+ " 'c',\n",
679
+ " '=',\n",
680
+ " 'z',\n",
681
+ " 'ベ',\n",
682
+ " 'O',\n",
683
+ " 'h',\n",
684
+ " 'プ',\n",
685
+ " 'o',\n",
686
+ " 'ザ',\n",
687
+ " '&',\n",
688
+ " '『',\n",
689
+ " 'ソ',\n",
690
+ " '.',\n",
691
+ " 'ヴ',\n",
692
+ " 'l',\n",
693
+ " 'ド',\n",
694
+ " 'み',\n",
695
+ " 'v',\n",
696
+ " 'x',\n",
697
+ " 'Y',\n",
698
+ " 'ガ',\n",
699
+ " 'に',\n",
700
+ " 'ヌ',\n",
701
+ " 'ら',\n",
702
+ " 'ヘ',\n",
703
+ " 'ょ',\n",
704
+ " 'カ',\n",
705
+ " '。',\n",
706
+ " 'ギ',\n",
707
+ " 'C',\n",
708
+ " 'ぜ',\n",
709
+ " 'モ',\n",
710
+ " 'キ',\n",
711
+ " 'i',\n",
712
+ " 'j',\n",
713
+ " '.',\n",
714
+ " \"'\",\n",
715
+ " 'M',\n",
716
+ " 'ご',\n",
717
+ " 'ど',\n",
718
+ " 'ハ',\n",
719
+ " 'ね',\n",
720
+ " 'で',\n",
721
+ " 'W',\n",
722
+ " 'ぴ',\n",
723
+ " 'T',\n",
724
+ " 'ぷ',\n",
725
+ " ' ',\n",
726
+ " 'マ',\n",
727
+ " '―',\n",
728
+ " 'ビ',\n",
729
+ " 'H',\n",
730
+ " 'デ',\n",
731
+ " 'f',\n",
732
+ " 'ゾ',\n",
733
+ " '-',\n",
734
+ " 'ポ',\n",
735
+ " 'K',\n",
736
+ " 'ヤ',\n",
737
+ " 'ユ',\n",
738
+ " 'シ',\n",
739
+ " 'ペ',\n",
740
+ " 'Z',\n",
741
+ " 'ぱ',\n",
742
+ " 'ふ',\n",
743
+ " 'る',\n",
744
+ " 'べ',\n",
745
+ " 'ヒ',\n",
746
+ " 'e',\n",
747
+ " 'そ',\n",
748
+ " 'テ',\n",
749
+ " 'サ',\n",
750
+ " 'V',\n",
751
+ " 'れ',\n",
752
+ " '」',\n",
753
+ " 'じ',\n",
754
+ " 'ワ',\n",
755
+ " 'レ',\n",
756
+ " 'X',\n",
757
+ " 'ォ',\n",
758
+ " 'ュ',\n",
759
+ " 'ジ',\n",
760
+ " 'k',\n",
761
+ " 'な',\n",
762
+ " 'ニ',\n",
763
+ " 'り',\n",
764
+ " 'q',\n",
765
+ " 'U',\n",
766
+ " 'ひ',\n",
767
+ " 'げ',\n",
768
+ " '&',\n",
769
+ " 'ゆ',\n",
770
+ " 'っ',\n",
771
+ " 'ず',\n",
772
+ " 'ゴ',\n",
773
+ " '「',\n",
774
+ " 'a',\n",
775
+ " 'ぢ',\n",
776
+ " 'ル',\n",
777
+ " 'さ',\n",
778
+ " 'ぺ',\n",
779
+ " 'm',\n",
780
+ " 'ョ',\n",
781
+ " 'ト',\n",
782
+ " 'ツ',\n",
783
+ " 'ホ',\n",
784
+ " 'コ',\n",
785
+ " 'オ',\n",
786
+ " 'セ',\n",
787
+ " 'え',\n",
788
+ " 'ま',\n",
789
+ " 'メ',\n",
790
+ " 'ァ',\n",
791
+ " 'F',\n",
792
+ " 'ぐ',\n",
793
+ " 'B',\n",
794
+ " '』',\n",
795
+ " 'ッ']]"
796
+ ]
797
+ },
798
+ "execution_count": 19,
799
+ "metadata": {},
800
+ "output_type": "execute_result"
801
+ }
802
+ ],
803
+ "source": [
804
+ "# vocab_train[\"vocab\"]"
805
+ ]
806
+ },
807
+ {
808
+ "cell_type": "code",
809
+ "execution_count": 18,
810
+ "metadata": {},
811
+ "outputs": [
812
+ {
813
+ "name": "stdout",
814
+ "output_type": "stream",
815
+ "text": [
816
+ "249\n",
817
+ "['ダ', 'た', 'P', 'か', 'よ', 'や', 'Q', 'を', 'F', 'h', 'E', 'ち', 'リ', 'ゲ', 'フ', 'め', 'タ', 'せ', 'b', '」', 'ば', 'ア', 'A', 'ャ', 'イ', 'ぶ', 'は', 'u', 'と', 'ノ', 'I', 'R', '「', 'G', 'ェ', 'く', '?', '〜', 'つ', 'こ', 'S', 'ぼ', 'ゼ', 's', 'U', 'き', 'ゥ', 'が', 'も', 'エ', 'ク', 'づ', 'グ', 'ブ', 'ゅ', 'ィ', 't', 'n', 'ロ', 'ー', '/', 'の', 'ケ', '・', 'J', 'お', 'む', 'P', 'ベ', 'h', 'プ', 'o', '&', '『', 'ソ', '.', 'ヴ', 'ド', 'み', 'Y', 'ガ', 'ょ', 'カ', 'C', 'ぜ', 'j', '.', 'ご', 'ど', 'ハ', 'ね', 'W', 'j', 'T', ' ', 'マ', '―', '-', 'デ', 'ゾ', 'ポ', 'K', 'ペ', 'ぱ', 'ふ', 'べ', 'ヒ', 'e', 'サ', 'N', 'X', 'ュ', 'k', 'り', 'U', 'ひ', 'げ', 'ゆ', 'ず', 'ゴ', 'a', 'ョ', 'ツ', '〇', 'え', 'F', 'B', '』', 'ッ', 'ん', 'ン', 'S', 'う', 'ぽ', ':', '々', 'ぞ', 'N', 'ヨ', 'ゃ', 'だ', 'L', 'ピ', 'ボ', 'w', 'ウ', 'あ', 'ヶ', 'ぬ', 'て', 'す', 'び', 'r', 'へ', '繫', 'バ', 'ぎ', 'ざ', 'A', 'チ', 'け', 'ぇ', 'わ', 'ス', 'p', 'ズ', 'y', 'し', '、', '!', 'G', '・', 'O', 'ぁ', 'd', 'g', 'ナ', 'ヅ', 'ほ', ')', 'D', 'ネ', 'パ', 'ム', 'ミ', '=', 'z', 'い', 'ろ', 'c', 'O', 'ザ', 'l', 'v', 'x', 'ヌ', 'に', 'ら', 'ヘ', '。', 'ギ', 'モ', 'D', 'キ', 'i', \"'\", 'M', 'で', 'ぴ', 'ぷ', 'ビ', 'H', 'f', 'ヤ', 'ユ', 'シ', 'Z', 'る', 'そ', 'テ', 'V', 'れ', '」', 'じ', 'ワ', 'レ', 'ォ', 'ジ', 'な', 'ニ', 'q', '&', 'っ', '「', 'ぢ', 'ル', 'さ', 'ぺ', 'm', 'ト', 'ホ', 'コ', 'オ', 'セ', 'ま', 'メ', 'ァ', 'ぐ', 'ラ']\n"
818
+ ]
819
+ }
820
+ ],
821
+ "source": [
822
+ "print(len(vocab_list))\n",
823
+ "print(vocab_list)"
824
+ ]
825
+ },
826
+ {
827
+ "cell_type": "code",
828
+ "execution_count": 26,
829
+ "metadata": {},
830
+ "outputs": [],
831
+ "source": [
832
+ "j_vocab = {\"<pad>\": 0, \"<s>\": 1, \"</s>\": 2, \"<unk>\": 3, \"|\": 4, \"'\": 5, \"-\": 6, \"A\": 7, \"B\": 8, \"C\": 9, \"D\": 10, \"E\": 11, \"F\": 12, \"G\": 13, \"H\": 14, \"I\": 15, \"J\": 16, \"K\": 17, \"L\": 18, \"M\": 19, \"N\": 20, \"O\": 21, \"P\": 22, \"Q\": 23, \"R\": 24, \"S\": 25, \"T\": 26, \"U\": 27, \"V\": 28, \"W\": 29, \"X\": 30, \"Y\": 31, \"Z\": 32, \"Ä\": 33, \"Í\": 34, \"Ó\": 35, \"Ö\": 36, \"Ü\": 37}\n"
833
+ ]
834
+ },
835
+ {
836
+ "cell_type": "code",
837
+ "execution_count": 48,
838
+ "metadata": {},
839
+ "outputs": [],
840
+ "source": [
841
+ "manually_kept_values = ['ß', 'ä', 'ö', 'ü']\n",
842
+ "\n",
843
+ "punctuation = ['.', ]"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": 50,
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "name": "stdout",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "['$', '&', '(', ')', '*', '+', '.', '/', '=', '@', '[', ']', '_', '`', '¡', '§', '«', '°', '´', 'µ', '·', '»', '×', 'à', 'á', 'â', 'ã', 'å', 'æ', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï', 'ð', 'ñ', 'ò', 'ó', 'ô', 'õ', 'ø', 'ù', 'ú', 'û', 'ý', 'þ', 'ā', 'ă', 'ą', 'ć', 'č', 'ď', 'đ', 'ē', 'ė', 'ę', 'ě', 'ğ', 'ġ', 'ħ', 'ī', 'ı', 'ł', 'ń', 'ņ', 'ň', 'ō', 'ŏ', 'ő', 'œ', 'ř', 'ś', 'ş', 'š', 'ť', 'ū', 'ů', 'ź', 'ż', 'ž', 'ơ', 'ǐ', 'ǔ', 'ș', 'ț', 'ə', 'ʻ', 'ʾ', 'ʿ', '̆', '̇', '̥', 'а', 'в', 'е', 'и', 'к', 'м', 'о', 'р', 'с', 'ф', 'ч', 'ш', 'ѹ', 'א', 'ב', 'נ', 'ע', 'ש', '་', 'ན', 'ḫ', 'ṟ', 'ṣ', 'ṭ', 'ạ', 'ả', 'ắ', 'ằ', 'ế', 'ễ', 'ệ', 'ọ', 'ồ', 'ộ', 'ụ', 'ứ', '‑', '‚', '„', '‟', '′', '″', '‹', '›', '→', '−', '≡', '⟨', '⟩', 'カ', '东', '临', '乡', '关', '合', '城', '孙', '尣', '幺', '支', '比', '毛', '泽', '無', '生', '臣', '辶', '道', '镇', '黃']\n"
856
+ ]
857
+ }
858
+ ],
859
+ "source": [
860
+ "odd_values = []\n",
861
+ "for index, value in enumerate(sorted(vocab_list)):\n",
862
+ "# if :\n",
863
+ " if value not in j_vocab and not (16 <= index <= 41 or value == ' ') and value not in manually_kept_values:\n",
864
+ " odd_values.append(value)\n",
865
+ "# print(index, value)\n",
866
+ " \n",
867
+ "print(odd_values)"
868
+ ]
869
+ },
870
+ {
871
+ "cell_type": "code",
872
+ "execution_count": 63,
873
+ "metadata": {},
874
+ "outputs": [
875
+ {
876
+ "name": "stdout",
877
+ "output_type": "stream",
878
+ "text": [
879
+ "$ & ( ) * + . / = @ [ ] _ ` ¡ § « ° ´ µ · » × à á â ã å æ ç è é ê ë ì í î ï ð ñ ò ó ô õ ø ù ú û ý þ ā ă ą ć č ď đ ē ė ę ě ğ ġ ħ ī ı ł ń ņ ň ō ŏ ő œ ř ś ş š ť ū ů ź ż ž ơ ǐ ǔ ș ț ə ʻ ʾ ʿ ̆ ̇ ̥ а в е и к м о р с ф ч ш ѹ א ב נ ע ש ་ ན ḫ ṟ ṣ ṭ ạ ả ắ ằ ế ễ ệ ọ ồ ộ ụ ứ ‑ ‚ „ ‟ ′ ″ ‹ › → − ≡ ⟨ ⟩ カ 东 临 乡 关 合 城 孙 尣 幺 支 比 毛 泽 無 生 臣 辶 道 镇 黃\n"
880
+ ]
881
+ }
882
+ ],
883
+ "source": [
884
+ "print(\" \".join(odd_values))\n",
885
+ "\n",
886
+ "# for value in odd_values:\n",
887
+ "# if value not in manually_kept_values:\n",
888
+ "# print(value)"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": null,
894
+ "metadata": {},
895
+ "outputs": [],
896
+ "source": [
897
+ "$ & ( ) * + = @ [ ] _ ` ¡ § « ° ´ µ · » × à á â ã å æ ç è é ê ë ì í î ï ð ñ ò ó ô õ ø ù ú û ý þ ā ă ą ć č ď đ ē ė ę ě ğ ġ ħ ī ı ł ń ņ ň ō ŏ ő œ ř ś ş š ť ū ů ź ż ž ơ ǐ ǔ ș ț ə ʻ ʾ ʿ ̆ ̇ ̥ а в е и к м о р с ф ч ш ѹ א ב נ ע ש ་ ན ḫ ṟ ṣ ṭ ạ ả ắ ằ ế ễ ệ ọ ồ ộ ụ ứ ‑ ‚ „ ‟ ′ ″ ‹ › → − ≡ ⟨ ⟩ カ 东 临 乡 关 合 城 孙 尣 幺 支 比 毛 泽 無 生 臣 辶 道 镇 黃"
898
+ ]
899
+ },
900
+ {
901
+ "cell_type": "code",
902
+ "execution_count": 54,
903
+ "metadata": {},
904
+ "outputs": [],
905
+ "source": [
906
+ "filtered_vocab_list = [value for value in vocab_list if value not in odd_values]"
907
+ ]
908
+ },
909
+ {
910
+ "cell_type": "code",
911
+ "execution_count": 55,
912
+ "metadata": {},
913
+ "outputs": [
914
+ {
915
+ "data": {
916
+ "text/plain": [
917
+ "['ß',\n",
918
+ " 'j',\n",
919
+ " 'r',\n",
920
+ " 'h',\n",
921
+ " 'd',\n",
922
+ " 'l',\n",
923
+ " 'z',\n",
924
+ " 'n',\n",
925
+ " 'm',\n",
926
+ " 'c',\n",
927
+ " 'ä',\n",
928
+ " \"'\",\n",
929
+ " 'g',\n",
930
+ " 'e',\n",
931
+ " 'w',\n",
932
+ " 's',\n",
933
+ " 'u',\n",
934
+ " 'k',\n",
935
+ " 'o',\n",
936
+ " 'f',\n",
937
+ " ' ',\n",
938
+ " 'y',\n",
939
+ " 'v',\n",
940
+ " 'ö',\n",
941
+ " 'ü',\n",
942
+ " 'p',\n",
943
+ " 'a',\n",
944
+ " 'x',\n",
945
+ " 'b',\n",
946
+ " 'q',\n",
947
+ " 't',\n",
948
+ " 'i']"
949
+ ]
950
+ },
951
+ "execution_count": 55,
952
+ "metadata": {},
953
+ "output_type": "execute_result"
954
+ }
955
+ ],
956
+ "source": [
957
+ "filtered_vocab_list"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": 21,
963
+ "metadata": {},
964
+ "outputs": [
965
+ {
966
+ "ename": "NameError",
967
+ "evalue": "name 'word_delimiter_token' is not defined",
968
+ "output_type": "error",
969
+ "traceback": [
970
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
971
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
972
+ "Input \u001b[0;32mIn [21]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m vocab_dict \u001b[38;5;241m=\u001b[39m {v: k \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28msorted\u001b[39m(vocab_list))}\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# replace white space with delimiter token\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mword_delimiter_token\u001b[49m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 5\u001b[0m vocab_dict[word_delimiter_token] \u001b[38;5;241m=\u001b[39m vocab_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m vocab_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
973
+ "\u001b[0;31mNameError\u001b[0m: name 'word_delimiter_token' is not defined"
974
+ ]
975
+ }
976
+ ],
977
+ "source": [
978
+ "vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}\n",
979
+ "\n",
980
+ "# replace white space with delimiter token\n",
981
+ "if word_delimiter_token is not None:\n",
982
+ " vocab_dict[word_delimiter_token] = vocab_dict[\" \"]\n",
983
+ " del vocab_dict[\" \"]\n",
984
+ "\n",
985
+ "# add unk and pad token\n",
986
+ "if unk_token is not None:\n",
987
+ " vocab_dict[unk_token] = len(vocab_dict)\n",
988
+ "\n",
989
+ "if pad_token is not None:\n",
990
+ " vocab_dict[pad_token] = len(vocab_dict)"
991
+ ]
992
+ },
993
+ {
994
+ "cell_type": "code",
995
+ "execution_count": 58,
996
+ "metadata": {},
997
+ "outputs": [
998
+ {
999
+ "data": {
1000
+ "application/vnd.jupyter.widget-view+json": {
1001
+ "model_id": "59e89471ea85449ebbc709d0a9d7325c",
1002
+ "version_major": 2,
1003
+ "version_minor": 0
1004
+ },
1005
+ "text/plain": [
1006
+ " 0%| | 0/437 [00:00<?, ?ba/s]"
1007
+ ]
1008
+ },
1009
+ "metadata": {},
1010
+ "output_type": "display_data"
1011
+ },
1012
+ {
1013
+ "name": "stdout",
1014
+ "output_type": "stream",
1015
+ "text": [
1016
+ "OOV found in 421223 samples, and they were removed from training set\n",
1017
+ "The final training set size is 14947\n"
1018
+ ]
1019
+ }
1020
+ ],
1021
+ "source": [
1022
+ "vocab_set = set(filtered_vocab_list)\n",
1023
+ "train_dataset_size = len(common_voice_train)\n",
1024
+ "common_voice_train_2 = common_voice_train.filter(\n",
1025
+ " lambda example: vocab_set.issuperset(example[\"sentence\"].replace(\" \", \"\"))\n",
1026
+ ")\n",
1027
+ "print(f\"OOV found in {train_dataset_size - len(common_voice_train_2)} samples, and they were removed from training set\")\n",
1028
+ "print(f\"The final training set size is {len(common_voice_train_2)}\")"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "cell_type": "code",
1033
+ "execution_count": 38,
1034
+ "metadata": {
1035
+ "collapsed": true,
1036
+ "jupyter": {
1037
+ "outputs_hidden": true
1038
+ }
1039
+ },
1040
+ "outputs": [
1041
+ {
1042
+ "ename": "KeyboardInterrupt",
1043
+ "evalue": "",
1044
+ "output_type": "error",
1045
+ "traceback": [
1046
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1047
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1048
+ "Input \u001b[0;32mIn [38]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m odd_example_texts \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m common_voice_train:\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m odd_values:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m row[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence\u001b[39m\u001b[38;5;124m\"\u001b[39m]: \n",
1049
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/arrow_dataset.py:1664\u001b[0m, in \u001b[0;36mDataset._iter\u001b[0;34m(self, decoded)\u001b[0m\n\u001b[1;32m 1658\u001b[0m \u001b[38;5;124;03m\"\"\"Iterate through the examples.\u001b[39;00m\n\u001b[1;32m 1659\u001b[0m \n\u001b[1;32m 1660\u001b[0m \u001b[38;5;124;03mIf a formatting is set with :meth:`Dataset.set_format` rows will be returned with the\u001b[39;00m\n\u001b[1;32m 1661\u001b[0m \u001b[38;5;124;03mselected format.\u001b[39;00m\n\u001b[1;32m 1662\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1663\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m index \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_rows):\n\u001b[0;32m-> 1664\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1665\u001b[0m \u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1666\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoded\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoded\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1667\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
1050
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/arrow_dataset.py:1915\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, decoded, **kwargs)\u001b[0m\n\u001b[1;32m 1913\u001b[0m formatter \u001b[38;5;241m=\u001b[39m get_formatter(format_type, features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures, decoded\u001b[38;5;241m=\u001b[39mdecoded, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mformat_kwargs)\n\u001b[1;32m 1914\u001b[0m pa_subtable \u001b[38;5;241m=\u001b[39m query_table(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data, key, indices\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_indices \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m-> 1915\u001b[0m formatted_output \u001b[38;5;241m=\u001b[39m \u001b[43mformat_table\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1916\u001b[0m \u001b[43m \u001b[49m\u001b[43mpa_subtable\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mformatter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformatter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mformat_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformat_columns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_all_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_all_columns\u001b[49m\n\u001b[1;32m 1917\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1918\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m formatted_output\n",
1051
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:533\u001b[0m, in \u001b[0;36mformat_table\u001b[0;34m(table, key, formatter, format_columns, output_all_columns)\u001b[0m\n\u001b[1;32m 531\u001b[0m python_formatter \u001b[38;5;241m=\u001b[39m PythonFormatter(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m format_columns \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 533\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mformatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery_type\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 534\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m query_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumn\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 535\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m format_columns:\n",
1052
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:282\u001b[0m, in \u001b[0;36mFormatter.__call__\u001b[0;34m(self, pa_table, query_type)\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, pa_table: pa\u001b[38;5;241m.\u001b[39mTable, query_type: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[RowFormat, ColumnFormat, BatchFormat]:\n\u001b[1;32m 281\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m query_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrow\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mformat_row\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 283\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m query_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumn\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 284\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mformat_column(pa_table)\n",
1053
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:313\u001b[0m, in \u001b[0;36mPythonFormatter.format_row\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m 311\u001b[0m row \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpython_arrow_extractor()\u001b[38;5;241m.\u001b[39mextract_row(pa_table)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoded:\n\u001b[0;32m--> 313\u001b[0m row \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpython_features_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode_row\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m row\n",
1054
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:222\u001b[0m, in \u001b[0;36mPythonFeaturesDecoder.decode_row\u001b[0;34m(self, row)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode_row\u001b[39m(\u001b[38;5;28mself\u001b[39m, row: \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m:\n\u001b[0;32m--> 222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures \u001b[38;5;28;01melse\u001b[39;00m row\n",
1055
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/features.py:1318\u001b[0m, in \u001b[0;36mFeatures.decode_example\u001b[0;34m(self, example)\u001b[0m\n\u001b[1;32m 1308\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode_example\u001b[39m(\u001b[38;5;28mself\u001b[39m, example: \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 1309\u001b[0m \u001b[38;5;124;03m\"\"\"Decode example with custom feature decoding.\u001b[39;00m\n\u001b[1;32m 1310\u001b[0m \n\u001b[1;32m 1311\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1315\u001b[0m \u001b[38;5;124;03m :obj:`dict[str, Any]`\u001b[39;00m\n\u001b[1;32m 1316\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1318\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m 1319\u001b[0m column_name: decode_nested_example(feature, value)\n\u001b[1;32m 1320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_column_requires_decoding[column_name]\n\u001b[1;32m 1321\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m value\n\u001b[1;32m 1322\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m column_name, (feature, value) \u001b[38;5;129;01min\u001b[39;00m utils\u001b[38;5;241m.\u001b[39mzip_dict(\n\u001b[1;32m 1323\u001b[0m {key: value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m example}, example\n\u001b[1;32m 1324\u001b[0m )\n\u001b[1;32m 1325\u001b[0m }\n",
1056
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/features.py:1319\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 1308\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode_example\u001b[39m(\u001b[38;5;28mself\u001b[39m, example: \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 1309\u001b[0m \u001b[38;5;124;03m\"\"\"Decode example with custom feature decoding.\u001b[39;00m\n\u001b[1;32m 1310\u001b[0m \n\u001b[1;32m 1311\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1315\u001b[0m \u001b[38;5;124;03m :obj:`dict[str, Any]`\u001b[39;00m\n\u001b[1;32m 1316\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m 1318\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[0;32m-> 1319\u001b[0m column_name: \u001b[43mdecode_nested_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_column_requires_decoding[column_name]\n\u001b[1;32m 1321\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m value\n\u001b[1;32m 1322\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m column_name, (feature, value) \u001b[38;5;129;01min\u001b[39;00m utils\u001b[38;5;241m.\u001b[39mzip_dict(\n\u001b[1;32m 1323\u001b[0m {key: value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m example}, example\n\u001b[1;32m 1324\u001b[0m )\n\u001b[1;32m 1325\u001b[0m }\n",
1057
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/features.py:1056\u001b[0m, in \u001b[0;36mdecode_nested_example\u001b[0;34m(schema, obj)\u001b[0m\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;66;03m# Object with special decoding:\u001b[39;00m\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(schema, (Audio, Image)):\n\u001b[0;32m-> 1056\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mschema\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m obj \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj\n",
1058
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/audio.py:97\u001b[0m, in \u001b[0;36mAudio.decode_example\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAn audio sample should have one of \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpath\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbytes\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m but both are None in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvalue\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m path \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m path\u001b[38;5;241m.\u001b[39mendswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmp3\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m---> 97\u001b[0m array, sampling_rate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_decode_mp3\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfile\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 98\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file:\n",
1059
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/audio.py:183\u001b[0m, in \u001b[0;36mAudio._decode_mp3\u001b[0;34m(self, path_or_file)\u001b[0m\n\u001b[1;32m 181\u001b[0m array \u001b[38;5;241m=\u001b[39m array\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmono:\n\u001b[0;32m--> 183\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43marray\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m array, sampling_rate\n",
1060
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/numpy/core/_methods.py:154\u001b[0m, in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;66;03m# Cast bool, unsigned int, and int to float64 by default\u001b[39;00m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(arr\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mtype, (\u001b[43mnt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minteger\u001b[49m, nt\u001b[38;5;241m.\u001b[39mbool_)):\n\u001b[1;32m 155\u001b[0m dtype \u001b[38;5;241m=\u001b[39m mu\u001b[38;5;241m.\u001b[39mdtype(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf8\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(arr\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mtype, nt\u001b[38;5;241m.\u001b[39mfloat16):\n",
1061
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1062
+ ]
1063
+ }
1064
+ ],
1065
+ "source": []
1066
+ },
1067
+ {
1068
+ "cell_type": "code",
1069
+ "execution_count": null,
1070
+ "metadata": {},
1071
+ "outputs": [],
1072
+ "source": []
1073
+ },
1074
+ {
1075
+ "cell_type": "code",
1076
+ "execution_count": null,
1077
+ "metadata": {},
1078
+ "outputs": [],
1079
+ "source": []
1080
+ },
1081
+ {
1082
+ "cell_type": "code",
1083
+ "execution_count": 43,
1084
+ "metadata": {},
1085
+ "outputs": [
1086
+ {
1087
+ "data": {
1088
+ "text/plain": [
1089
+ "0"
1090
+ ]
1091
+ },
1092
+ "execution_count": 43,
1093
+ "metadata": {},
1094
+ "output_type": "execute_result"
1095
+ }
1096
+ ],
1097
+ "source": []
1098
+ },
1099
+ {
1100
+ "cell_type": "code",
1101
+ "execution_count": 22,
1102
+ "metadata": {},
1103
+ "outputs": [
1104
+ {
1105
+ "name": "stdout",
1106
+ "output_type": "stream",
1107
+ "text": [
1108
+ "0 \n",
1109
+ "1 &\n",
1110
+ "2 '\n",
1111
+ "3 .\n",
1112
+ "4 /\n",
1113
+ "5 A\n",
1114
+ "6 B\n",
1115
+ "7 C\n",
1116
+ "8 D\n",
1117
+ "9 E\n",
1118
+ "10 F\n",
1119
+ "11 G\n",
1120
+ "12 H\n",
1121
+ "13 I\n",
1122
+ "14 J\n",
1123
+ "15 K\n",
1124
+ "16 L\n",
1125
+ "17 M\n",
1126
+ "18 N\n",
1127
+ "19 O\n",
1128
+ "20 P\n",
1129
+ "21 Q\n",
1130
+ "22 R\n",
1131
+ "23 S\n",
1132
+ "24 T\n",
1133
+ "25 U\n",
1134
+ "26 V\n",
1135
+ "27 W\n",
1136
+ "28 X\n",
1137
+ "29 Y\n",
1138
+ "30 Z\n",
1139
+ "31 a\n",
1140
+ "32 b\n",
1141
+ "33 c\n",
1142
+ "34 d\n",
1143
+ "35 e\n",
1144
+ "36 f\n",
1145
+ "37 g\n",
1146
+ "38 h\n",
1147
+ "39 i\n",
1148
+ "40 j\n",
1149
+ "41 k\n",
1150
+ "42 l\n",
1151
+ "43 m\n",
1152
+ "44 n\n",
1153
+ "45 o\n",
1154
+ "46 p\n",
1155
+ "47 q\n",
1156
+ "48 r\n",
1157
+ "49 s\n",
1158
+ "50 t\n",
1159
+ "51 u\n",
1160
+ "52 v\n",
1161
+ "53 w\n",
1162
+ "54 x\n",
1163
+ "55 y\n",
1164
+ "56 z\n",
1165
+ "57 ―\n",
1166
+ "58 、\n",
1167
+ "59 。\n",
1168
+ "60 々\n",
1169
+ "61 〇\n",
1170
+ "62 「\n",
1171
+ "63 」\n",
1172
+ "64 『\n",
1173
+ "65 』\n",
1174
+ "66 〜\n",
1175
+ "67 ぁ\n",
1176
+ "68 あ\n",
1177
+ "69 い\n",
1178
+ "70 う\n",
1179
+ "71 ぇ\n",
1180
+ "72 え\n",
1181
+ "73 お\n",
1182
+ "74 か\n",
1183
+ "75 が\n",
1184
+ "76 き\n",
1185
+ "77 ぎ\n",
1186
+ "78 く\n",
1187
+ "79 ぐ\n",
1188
+ "80 け\n",
1189
+ "81 げ\n",
1190
+ "82 こ\n",
1191
+ "83 ご\n",
1192
+ "84 さ\n",
1193
+ "85 ざ\n",
1194
+ "86 し\n",
1195
+ "87 じ\n",
1196
+ "88 す\n",
1197
+ "89 ず\n",
1198
+ "90 せ\n",
1199
+ "91 ぜ\n",
1200
+ "92 そ\n",
1201
+ "93 ぞ\n",
1202
+ "94 た\n",
1203
+ "95 だ\n",
1204
+ "96 ち\n",
1205
+ "97 ぢ\n",
1206
+ "98 っ\n",
1207
+ "99 つ\n",
1208
+ "100 づ\n",
1209
+ "101 て\n",
1210
+ "102 で\n",
1211
+ "103 と\n",
1212
+ "104 ど\n",
1213
+ "105 な\n",
1214
+ "106 に\n",
1215
+ "107 ぬ\n",
1216
+ "108 ね\n",
1217
+ "109 の\n",
1218
+ "110 は\n",
1219
+ "111 ば\n",
1220
+ "112 ぱ\n",
1221
+ "113 ひ\n",
1222
+ "114 び\n",
1223
+ "115 ぴ\n",
1224
+ "116 ふ\n",
1225
+ "117 ぶ\n",
1226
+ "118 ぷ\n",
1227
+ "119 へ\n",
1228
+ "120 べ\n",
1229
+ "121 ぺ\n",
1230
+ "122 ほ\n",
1231
+ "123 ぼ\n",
1232
+ "124 ぽ\n",
1233
+ "125 ま\n",
1234
+ "126 み\n",
1235
+ "127 む\n",
1236
+ "128 め\n",
1237
+ "129 も\n",
1238
+ "130 ゃ\n",
1239
+ "131 や\n",
1240
+ "132 ゅ\n",
1241
+ "133 ゆ\n",
1242
+ "134 ょ\n",
1243
+ "135 よ\n",
1244
+ "136 ら\n",
1245
+ "137 り\n",
1246
+ "138 る\n",
1247
+ "139 れ\n",
1248
+ "140 ろ\n",
1249
+ "141 わ\n",
1250
+ "142 を\n",
1251
+ "143 ん\n",
1252
+ "144 ァ\n",
1253
+ "145 ア\n",
1254
+ "146 ィ\n",
1255
+ "147 イ\n",
1256
+ "148 ゥ\n",
1257
+ "149 ウ\n",
1258
+ "150 ェ\n",
1259
+ "151 エ\n",
1260
+ "152 ォ\n",
1261
+ "153 オ\n",
1262
+ "154 カ\n",
1263
+ "155 ガ\n",
1264
+ "156 キ\n",
1265
+ "157 ギ\n",
1266
+ "158 ク\n",
1267
+ "159 グ\n",
1268
+ "160 ケ\n",
1269
+ "161 ゲ\n",
1270
+ "162 コ\n",
1271
+ "163 ゴ\n",
1272
+ "164 サ\n",
1273
+ "165 ザ\n",
1274
+ "166 シ\n",
1275
+ "167 ジ\n",
1276
+ "168 ス\n",
1277
+ "169 ズ\n",
1278
+ "170 セ\n",
1279
+ "171 ゼ\n",
1280
+ "172 ソ\n",
1281
+ "173 ゾ\n",
1282
+ "174 タ\n",
1283
+ "175 ダ\n",
1284
+ "176 チ\n",
1285
+ "177 ッ\n",
1286
+ "178 ツ\n",
1287
+ "179 ヅ\n",
1288
+ "180 テ\n",
1289
+ "181 デ\n",
1290
+ "182 ト\n",
1291
+ "183 ド\n",
1292
+ "184 ナ\n",
1293
+ "185 ニ\n",
1294
+ "186 ヌ\n",
1295
+ "187 ネ\n",
1296
+ "188 ノ\n",
1297
+ "189 ハ\n",
1298
+ "190 バ\n",
1299
+ "191 パ\n",
1300
+ "192 ヒ\n",
1301
+ "193 ビ\n",
1302
+ "194 ピ\n",
1303
+ "195 フ\n",
1304
+ "196 ブ\n",
1305
+ "197 プ\n",
1306
+ "198 ヘ\n",
1307
+ "199 ベ\n",
1308
+ "200 ペ\n",
1309
+ "201 ホ\n",
1310
+ "202 ボ\n",
1311
+ "203 ポ\n",
1312
+ "204 マ\n",
1313
+ "205 ミ\n",
1314
+ "206 ム\n",
1315
+ "207 メ\n",
1316
+ "208 モ\n",
1317
+ "209 ャ\n",
1318
+ "210 ヤ\n",
1319
+ "211 ュ\n",
1320
+ "212 ユ\n",
1321
+ "213 ョ\n",
1322
+ "214 ヨ\n",
1323
+ "215 ラ\n",
1324
+ "216 リ\n",
1325
+ "217 ル\n",
1326
+ "218 レ\n",
1327
+ "219 ロ\n",
1328
+ "220 ワ\n",
1329
+ "221 ン\n",
1330
+ "222 ヴ\n",
1331
+ "223 ヶ\n",
1332
+ "224 ・\n",
1333
+ "225 ー\n",
1334
+ "226 繫\n",
1335
+ "227 !\n",
1336
+ "228 &\n",
1337
+ "229 )\n",
1338
+ "230 -\n",
1339
+ "231 .\n",
1340
+ "232 :\n",
1341
+ "233 =\n",
1342
+ "234 ?\n",
1343
+ "235 A\n",
1344
+ "236 D\n",
1345
+ "237 F\n",
1346
+ "238 G\n",
1347
+ "239 N\n",
1348
+ "240 O\n",
1349
+ "241 P\n",
1350
+ "242 S\n",
1351
+ "243 U\n",
1352
+ "244 h\n",
1353
+ "245 j\n",
1354
+ "246 「\n",
1355
+ "247 」\n",
1356
+ "248 ・\n"
1357
+ ]
1358
+ }
1359
+ ],
1360
+ "source": [
1361
+ "vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}\n",
1362
+ "for key, value in enumerate(vocab_dict):\n",
1363
+ " print(key, value)"
1364
+ ]
1365
+ },
1366
+ {
1367
+ "cell_type": "code",
1368
+ "execution_count": null,
1369
+ "metadata": {},
1370
+ "outputs": [],
1371
+ "source": [
1372
+ "def create_vocabulary_from_data(\n",
1373
+ " datasets: DatasetDict,\n",
1374
+ " word_delimiter_token: Optional[str] = None,\n",
1375
+ " unk_token: Optional[str] = None,\n",
1376
+ " pad_token: Optional[str] = None,\n",
1377
+ "):\n",
1378
+ " # Given training and test labels create vocabulary\n",
1379
+ " def extract_all_chars(batch):\n",
1380
+ " all_text = \" \".join(batch[\"target_text\"])\n",
1381
+ " vocab = list(set(all_text))\n",
1382
+ " return {\"vocab\": [vocab], \"all_text\": [all_text]}\n",
1383
+ "\n",
1384
+ " vocabs = datasets.map(\n",
1385
+ " extract_all_chars,\n",
1386
+ " batched=True,\n",
1387
+ " batch_size=-1,\n",
1388
+ " keep_in_memory=True,\n",
1389
+ " remove_columns=datasets[\"train\"].column_names,\n",
1390
+ " )\n",
1391
+ "\n",
1392
+ " # take union of all unique characters in each dataset\n",
1393
+ " vocab_set = functools.reduce(\n",
1394
+ " lambda vocab_1, vocab_2: set(vocab_1[\"vocab\"][0]) | set(vocab_2[\"vocab\"][0]), vocabs.values()\n",
1395
+ " )\n",
1396
+ "\n",
1397
+ " vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}\n",
1398
+ "\n",
1399
+ " # replace white space with delimiter token\n",
1400
+ " if word_delimiter_token is not None:\n",
1401
+ " vocab_dict[word_delimiter_token] = vocab_dict[\" \"]\n",
1402
+ " del vocab_dict[\" \"]\n",
1403
+ "\n",
1404
+ " # add unk and pad token\n",
1405
+ " if unk_token is not None:\n",
1406
+ " vocab_dict[unk_token] = len(vocab_dict)\n",
1407
+ "\n",
1408
+ " if pad_token is not None:\n",
1409
+ " vocab_dict[pad_token] = len(vocab_dict)\n",
1410
+ "\n",
1411
+ " return vocab_dict"
1412
+ ]
1413
+ },
1414
+ {
1415
+ "cell_type": "code",
1416
+ "execution_count": null,
1417
+ "metadata": {},
1418
+ "outputs": [],
1419
+ "source": []
1420
+ },
1421
+ {
1422
+ "cell_type": "code",
1423
+ "execution_count": null,
1424
+ "metadata": {},
1425
+ "outputs": [],
1426
+ "source": []
1427
+ },
1428
+ {
1429
+ "cell_type": "code",
1430
+ "execution_count": null,
1431
+ "metadata": {},
1432
+ "outputs": [],
1433
+ "source": []
1434
+ },
1435
+ {
1436
+ "cell_type": "code",
1437
+ "execution_count": null,
1438
+ "metadata": {},
1439
+ "outputs": [],
1440
+ "source": [
1441
+ "# load processor\n",
1442
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)\n",
1443
+ "# feature_extractor = processor_with_lm.feature_extractor\n",
1444
+ "sampling_rate = feature_extractor.sampling_rate\n",
1445
+ "\n",
1446
+ "# resample audio\n",
1447
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=sampling_rate))\n",
1448
+ "\n",
1449
+ "# load eval pipeline\n",
1450
+ "asr = pipeline(\"automatic-speech-recognition\", model=repo_name, feature_extractor=feature_extractor)\n",
1451
+ "\n",
1452
+ "# map function to decode audio\n",
1453
+ "def map_to_pred(batch):\n",
1454
+ " prediction = asr(\n",
1455
+ " batch[\"audio\"][\"array\"])\n",
1456
+ "\n",
1457
+ " batch[\"prediction\"] = prediction[\"text\"]\n",
1458
+ " batch[\"target\"] = batch[\"sentence\"]\n",
1459
+ " return batch\n",
1460
+ "\n",
1461
+ "# run inference on all examples\n",
1462
+ "result = dataset.map(map_to_pred, remove_columns=dataset.column_names)\n",
1463
+ "print(result[\"prediction\"])\n",
1464
+ "\n",
1465
+ "result[0]['target']"
1466
+ ]
1467
+ }
1468
+ ],
1469
+ "metadata": {
1470
+ "kernelspec": {
1471
+ "display_name": "Python 3 (ipykernel)",
1472
+ "language": "python",
1473
+ "name": "python3"
1474
+ },
1475
+ "language_info": {
1476
+ "codemirror_mode": {
1477
+ "name": "ipython",
1478
+ "version": 3
1479
+ },
1480
+ "file_extension": ".py",
1481
+ "mimetype": "text/x-python",
1482
+ "name": "python",
1483
+ "nbconvert_exporter": "python",
1484
+ "pygments_lexer": "ipython3",
1485
+ "version": "3.8.8"
1486
+ }
1487
+ },
1488
+ "nbformat": 4,
1489
+ "nbformat_minor": 4
1490
+ }
.ipynb_checkpoints/vocab-checkpoint.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"&": 1, "'": 2, ".": 3, "/": 4, "A": 5, "B": 6, "C": 7, "D": 8, "E": 9, "F": 10, "G": 11, "H": 12, "I": 13, "J": 14, "K": 15, "L": 16, "M": 17, "N": 18, "O": 19, "P": 20, "Q": 21, "R": 22, "S": 23, "T": 24, "U": 25, "V": 26, "W": 27, "X": 28, "Y": 29, "Z": 30, "a": 31, "b": 32, "c": 33, "d": 34, "e": 35, "f": 36, "g": 37, "h": 38, "i": 39, "j": 40, "k": 41, "l": 42, "m": 43, "n": 44, "o": 45, "p": 46, "q": 47, "r": 48, "s": 49, "t": 50, "u": 51, "v": 52, "w": 53, "x": 54, "y": 55, "z": 56, "\u2015": 57, "\u3001": 58, "\u3002": 59, "\u3005": 60, "\u3007": 61, "\u300c": 62, "\u300d": 63, "\u300e": 64, "\u300f": 65, "\u301c": 66, "\u3041": 67, "\u3042": 68, "\u3044": 69, "\u3046": 70, "\u3047": 71, "\u3048": 72, "\u304a": 73, "\u304b": 74, "\u304c": 75, "\u304d": 76, "\u304e": 77, "\u304f": 78, "\u3050": 79, "\u3051": 80, "\u3052": 81, "\u3053": 82, "\u3054": 83, "\u3055": 84, "\u3056": 85, "\u3057": 86, "\u3058": 87, "\u3059": 88, "\u305a": 89, "\u305b": 90, "\u305c": 91, "\u305d": 92, "\u305e": 93, "\u305f": 94, "\u3060": 95, "\u3061": 96, "\u3062": 97, "\u3063": 98, "\u3064": 99, "\u3065": 100, "\u3066": 101, "\u3067": 102, "\u3068": 103, "\u3069": 104, "\u306a": 105, "\u306b": 106, "\u306c": 107, "\u306d": 108, "\u306e": 109, "\u306f": 110, "\u3070": 111, "\u3071": 112, "\u3072": 113, "\u3073": 114, "\u3074": 115, "\u3075": 116, "\u3076": 117, "\u3077": 118, "\u3078": 119, "\u3079": 120, "\u307a": 121, "\u307b": 122, "\u307c": 123, "\u307d": 124, "\u307e": 125, "\u307f": 126, "\u3080": 127, "\u3081": 128, "\u3082": 129, "\u3083": 130, "\u3084": 131, "\u3085": 132, "\u3086": 133, "\u3087": 134, "\u3088": 135, "\u3089": 136, "\u308a": 137, "\u308b": 138, "\u308c": 139, "\u308d": 140, "\u308f": 141, "\u3092": 142, "\u3093": 143, "\u30a1": 144, "\u30a2": 145, "\u30a3": 146, "\u30a4": 147, "\u30a5": 148, "\u30a6": 149, "\u30a7": 150, "\u30a8": 151, "\u30a9": 152, "\u30aa": 153, "\u30ab": 154, "\u30ac": 155, "\u30ad": 156, "\u30ae": 157, "\u30af": 158, "\u30b0": 159, "\u30b1": 160, "\u30b2": 161, "\u30b3": 162, "\u30b4": 163, "\u30b5": 164, "\u30b6": 165, "\u30b7": 166, "\u30b8": 167, "\u30b9": 168, "\u30ba": 169, "\u30bb": 170, "\u30bc": 171, "\u30bd": 172, "\u30be": 173, "\u30bf": 174, "\u30c0": 175, "\u30c1": 176, "\u30c3": 177, "\u30c4": 178, "\u30c5": 179, "\u30c6": 180, "\u30c7": 181, "\u30c8": 182, "\u30c9": 183, "\u30ca": 184, "\u30cb": 185, "\u30cc": 186, "\u30cd": 187, "\u30ce": 188, "\u30cf": 189, "\u30d0": 190, "\u30d1": 191, "\u30d2": 192, "\u30d3": 193, "\u30d4": 194, "\u30d5": 195, "\u30d6": 196, "\u30d7": 197, "\u30d8": 198, "\u30d9": 199, "\u30da": 200, "\u30db": 201, "\u30dc": 202, "\u30dd": 203, "\u30de": 204, "\u30df": 205, "\u30e0": 206, "\u30e1": 207, "\u30e2": 208, "\u30e3": 209, "\u30e4": 210, "\u30e5": 211, "\u30e6": 212, "\u30e7": 213, "\u30e8": 214, "\u30e9": 215, "\u30ea": 216, "\u30eb": 217, "\u30ec": 218, "\u30ed": 219, "\u30ef": 220, "\u30f3": 221, "\u30f4": 222, "\u30f6": 223, "\u30fb": 224, "\u30fc": 225, "\u7e6b": 226, "\uff06": 227, "\uff09": 228, "\uff0d": 229, "\uff0e": 230, "\uff1a": 231, "\uff1d": 232, "\uff1f": 233, "\uff21": 234, "\uff24": 235, "\uff26": 236, "\uff27": 237, "\uff2e": 238, "\uff2f": 239, "\uff30": 240, "\uff33": 241, "\uff35": 242, "\uff48": 243, "\uff4a": 244, "\uff62": 245, "\uff63": 246, "\uff65": 247, "|": 0, "[UNK]": 248, "[PAD]": 249}
added_tokens.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"<s>": 250, "</s>": 251}
config.json ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-xls-r-300m",
3
+ "activation_dropout": 0.1,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForCTC"
10
+ ],
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 768,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": true,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "mean",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": true,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_dropout": 0.0,
51
+ "feat_extract_norm": "layer",
52
+ "feat_proj_dropout": 0.0,
53
+ "feat_quantizer_dropout": 0.0,
54
+ "final_dropout": 0.0,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.0,
57
+ "hidden_size": 1024,
58
+ "initializer_range": 0.02,
59
+ "intermediate_size": 4096,
60
+ "layer_norm_eps": 1e-05,
61
+ "layerdrop": 0.0,
62
+ "mask_feature_length": 64,
63
+ "mask_feature_min_masks": 0,
64
+ "mask_feature_prob": 0.25,
65
+ "mask_time_length": 10,
66
+ "mask_time_min_masks": 2,
67
+ "mask_time_prob": 0.75,
68
+ "model_type": "wav2vec2",
69
+ "num_adapter_layers": 3,
70
+ "num_attention_heads": 16,
71
+ "num_codevector_groups": 2,
72
+ "num_codevectors_per_group": 320,
73
+ "num_conv_pos_embedding_groups": 16,
74
+ "num_conv_pos_embeddings": 128,
75
+ "num_feat_extract_layers": 7,
76
+ "num_hidden_layers": 24,
77
+ "num_negatives": 100,
78
+ "output_hidden_size": 1024,
79
+ "pad_token_id": 249,
80
+ "proj_codevector_dim": 768,
81
+ "tdnn_dilation": [
82
+ 1,
83
+ 2,
84
+ 3,
85
+ 1,
86
+ 1
87
+ ],
88
+ "tdnn_dim": [
89
+ 512,
90
+ 512,
91
+ 512,
92
+ 512,
93
+ 1500
94
+ ],
95
+ "tdnn_kernel": [
96
+ 5,
97
+ 3,
98
+ 3,
99
+ 1,
100
+ 1
101
+ ],
102
+ "torch_dtype": "float32",
103
+ "transformers_version": "4.17.0.dev0",
104
+ "use_weighted_layer_sum": false,
105
+ "vocab_size": 252,
106
+ "xvector_output_dim": 512
107
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a045ec80cccde6513efafa22639d8feb4ad1eed1931045d55322e78ce00a922
3
+ size 1262956849
run_speech_recognition_ctc_bnb.py ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
17
+
18
+ import functools
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ import warnings
25
+ from dataclasses import dataclass, field
26
+ from typing import Dict, List, Optional, Union
27
+
28
+ import datasets
29
+ import numpy as np
30
+ import torch
31
+ from datasets import DatasetDict, load_dataset, load_metric
32
+
33
+ import bitsandbytes as bnb
34
+ import transformers
35
+ from transformers import (
36
+ AutoConfig,
37
+ AutoFeatureExtractor,
38
+ AutoModelForCTC,
39
+ AutoProcessor,
40
+ AutoTokenizer,
41
+ HfArgumentParser,
42
+ Trainer,
43
+ TrainingArguments,
44
+ Wav2Vec2Processor,
45
+ set_seed,
46
+ )
47
+ from transformers.trainer_pt_utils import get_parameter_names
48
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
49
+ from transformers.utils import check_min_version
50
+ from transformers.utils.versions import require_version
51
+
52
+
53
+
54
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
55
+ check_min_version("4.16.0.dev0")
56
+
57
+ require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
58
+
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ def list_field(default=None, metadata=None):
64
+ return field(default_factory=lambda: default, metadata=metadata)
65
+
66
+
67
+ @dataclass
68
+ class ModelArguments:
69
+ """
70
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
71
+ """
72
+
73
+ model_name_or_path: str = field(
74
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
75
+ )
76
+ tokenizer_name_or_path: Optional[str] = field(
77
+ default=None,
78
+ metadata={"help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models"},
79
+ )
80
+ cache_dir: Optional[str] = field(
81
+ default=None,
82
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
83
+ )
84
+ freeze_feature_encoder: bool = field(
85
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
86
+ )
87
+ attention_dropout: float = field(
88
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
89
+ )
90
+ activation_dropout: float = field(
91
+ default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
92
+ )
93
+ feat_proj_dropout: float = field(default=0.0, metadata={"help": "The dropout ratio for the projected features."})
94
+ hidden_dropout: float = field(
95
+ default=0.0,
96
+ metadata={
97
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
98
+ },
99
+ )
100
+ final_dropout: float = field(
101
+ default=0.0,
102
+ metadata={"help": "The dropout probability for the final projection layer."},
103
+ )
104
+ mask_time_prob: float = field(
105
+ default=0.05,
106
+ metadata={
107
+ "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
108
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
109
+ "vectors will be masked along the time axis."
110
+ },
111
+ )
112
+ mask_time_length: int = field(
113
+ default=10,
114
+ metadata={"help": "Length of vector span to mask along the time axis."},
115
+ )
116
+ mask_feature_prob: float = field(
117
+ default=0.0,
118
+ metadata={
119
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
120
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
121
+ },
122
+ )
123
+ mask_feature_length: int = field(
124
+ default=10,
125
+ metadata={"help": "Length of vector span to mask along the feature axis."},
126
+ )
127
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
128
+ ctc_loss_reduction: Optional[str] = field(
129
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
130
+ )
131
+
132
+
133
+ @dataclass
134
+ class DataTrainingArguments:
135
+ """
136
+ Arguments pertaining to what data we are going to input our model for training and eval.
137
+
138
+ Using `HfArgumentParser` we can turn this class
139
+ into argparse arguments to be able to specify them on
140
+ the command line.
141
+ """
142
+
143
+ dataset_name: str = field(
144
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
145
+ )
146
+ dataset_config_name: str = field(
147
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
148
+ )
149
+ train_split_name: str = field(
150
+ default="train+validation",
151
+ metadata={
152
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
153
+ },
154
+ )
155
+ eval_split_name: str = field(
156
+ default="test",
157
+ metadata={
158
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
159
+ },
160
+ )
161
+ audio_column_name: str = field(
162
+ default="audio",
163
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
164
+ )
165
+ text_column_name: str = field(
166
+ default="text",
167
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
168
+ )
169
+ overwrite_cache: bool = field(
170
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
171
+ )
172
+ preprocessing_num_workers: Optional[int] = field(
173
+ default=None,
174
+ metadata={"help": "The number of processes to use for the preprocessing."},
175
+ )
176
+ max_train_samples: Optional[int] = field(
177
+ default=None,
178
+ metadata={
179
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
180
+ "value if set."
181
+ },
182
+ )
183
+ max_eval_samples: Optional[int] = field(
184
+ default=None,
185
+ metadata={
186
+ "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
187
+ "value if set."
188
+ },
189
+ )
190
+ chars_to_ignore: Optional[List[str]] = list_field(
191
+ default=None,
192
+ metadata={"help": "A list of characters to remove from the transcripts."},
193
+ )
194
+ eval_metrics: List[str] = list_field(
195
+ default=["wer"],
196
+ metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
197
+ )
198
+ max_duration_in_seconds: float = field(
199
+ default=20.0,
200
+ metadata={
201
+ "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
202
+ },
203
+ )
204
+ min_duration_in_seconds: float = field(
205
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
206
+ )
207
+ preprocessing_only: bool = field(
208
+ default=False,
209
+ metadata={
210
+ "help": "Whether to only do data preprocessing and skip training. "
211
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
212
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
213
+ "so that the cached datasets can consequently be loaded in distributed training"
214
+ },
215
+ )
216
+ use_auth_token: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "If :obj:`True`, will use the token generated when running"
220
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
221
+ },
222
+ )
223
+ unk_token: str = field(
224
+ default="[UNK]",
225
+ metadata={"help": "The unk token for the tokenizer"},
226
+ )
227
+ pad_token: str = field(
228
+ default="[PAD]",
229
+ metadata={"help": "The padding token for the tokenizer"},
230
+ )
231
+ word_delimiter_token: str = field(
232
+ default="|",
233
+ metadata={"help": "The word delimiter token for the tokenizer"},
234
+ )
235
+ phoneme_language: Optional[str] = field(
236
+ default=None,
237
+ metadata={
238
+ "help": "The target language that should be used be"
239
+ " passed to the tokenizer for tokenization. Note that"
240
+ " this is only relevant if the model classifies the"
241
+ " input audio to a sequence of phoneme sequences."
242
+ },
243
+ )
244
+
245
+
246
+ @dataclass
247
+ class DataCollatorCTCWithPadding:
248
+ """
249
+ Data collator that will dynamically pad the inputs received.
250
+ Args:
251
+ processor (:class:`~transformers.AutoProcessor`)
252
+ The processor used for proccessing the data.
253
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
254
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
255
+ among:
256
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
257
+ sequence if provided).
258
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
259
+ maximum acceptable input length for the model if that argument is not provided.
260
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
261
+ different lengths).
262
+ max_length (:obj:`int`, `optional`):
263
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
264
+ max_length_labels (:obj:`int`, `optional`):
265
+ Maximum length of the ``labels`` returned list and optionally padding length (see above).
266
+ pad_to_multiple_of (:obj:`int`, `optional`):
267
+ If set will pad the sequence to a multiple of the provided value.
268
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
269
+ 7.5 (Volta).
270
+ """
271
+
272
+ processor: AutoProcessor
273
+ padding: Union[bool, str] = "longest"
274
+ pad_to_multiple_of: Optional[int] = None
275
+ pad_to_multiple_of_labels: Optional[int] = None
276
+
277
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
278
+ # split inputs and labels since they have to be of different lenghts and need
279
+ # different padding methods
280
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
281
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
282
+
283
+ batch = self.processor.pad(
284
+ input_features,
285
+ padding=self.padding,
286
+ pad_to_multiple_of=self.pad_to_multiple_of,
287
+ return_tensors="pt",
288
+ )
289
+
290
+ with self.processor.as_target_processor():
291
+ labels_batch = self.processor.pad(
292
+ label_features,
293
+ padding=self.padding,
294
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
295
+ return_tensors="pt",
296
+ )
297
+
298
+ # replace padding with -100 to ignore loss correctly
299
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
300
+
301
+ batch["labels"] = labels
302
+
303
+ return batch
304
+
305
+
306
+ def create_vocabulary_from_data(
307
+ datasets: DatasetDict,
308
+ word_delimiter_token: Optional[str] = None,
309
+ unk_token: Optional[str] = None,
310
+ pad_token: Optional[str] = None,
311
+ ):
312
+ # Given training and test labels create vocabulary
313
+ def extract_all_chars(batch):
314
+ all_text = " ".join(batch["target_text"])
315
+ vocab = list(set(all_text))
316
+ return {"vocab": [vocab], "all_text": [all_text]}
317
+
318
+ vocabs = datasets.map(
319
+ extract_all_chars,
320
+ batched=True,
321
+ batch_size=-1,
322
+ keep_in_memory=True,
323
+ remove_columns=datasets["train"].column_names,
324
+ )
325
+
326
+ # take union of all unique characters in each dataset
327
+ vocab_set = functools.reduce(
328
+ lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
329
+ )
330
+
331
+ vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
332
+
333
+ # replace white space with delimiter token
334
+ if word_delimiter_token is not None:
335
+ vocab_dict[word_delimiter_token] = vocab_dict[" "]
336
+ del vocab_dict[" "]
337
+
338
+ # add unk and pad token
339
+ if unk_token is not None:
340
+ vocab_dict[unk_token] = len(vocab_dict)
341
+
342
+ if pad_token is not None:
343
+ vocab_dict[pad_token] = len(vocab_dict)
344
+
345
+ return vocab_dict
346
+
347
+
348
+ def main():
349
+ # See all possible arguments in src/transformers/training_args.py
350
+ # or by passing the --help flag to this script.
351
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
352
+
353
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
354
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
355
+ # If we pass only one argument to the script and it's the path to a json file,
356
+ # let's parse it to get our arguments.
357
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
358
+ else:
359
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
360
+
361
+ # Detecting last checkpoint.
362
+ last_checkpoint = None
363
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
364
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
365
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
366
+ raise ValueError(
367
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
368
+ "Use --overwrite_output_dir to overcome."
369
+ )
370
+ elif last_checkpoint is not None:
371
+ logger.info(
372
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
373
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
374
+ )
375
+
376
+ # Setup logging
377
+ logging.basicConfig(
378
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
379
+ datefmt="%m/%d/%Y %H:%M:%S",
380
+ handlers=[logging.StreamHandler(sys.stdout)],
381
+ )
382
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
383
+
384
+ # Log on each process the small summary:
385
+ logger.warning(
386
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
387
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
388
+ )
389
+ # Set the verbosity to info of the Transformers logger (on main process only):
390
+ if is_main_process(training_args.local_rank):
391
+ transformers.utils.logging.set_verbosity_info()
392
+ logger.info("Training/evaluation parameters %s", training_args)
393
+
394
+ # Set seed before initializing model.
395
+ set_seed(training_args.seed)
396
+
397
+ # 1. First, let's load the dataset
398
+ raw_datasets = DatasetDict()
399
+
400
+ if training_args.do_train:
401
+ raw_datasets["train"] = load_dataset(
402
+ data_args.dataset_name,
403
+ data_args.dataset_config_name,
404
+ split=data_args.train_split_name,
405
+ use_auth_token=data_args.use_auth_token,
406
+ )
407
+
408
+ if data_args.audio_column_name not in raw_datasets["train"].column_names:
409
+ raise ValueError(
410
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
411
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
412
+ f"{', '.join(raw_datasets['train'].column_names)}."
413
+ )
414
+
415
+ if data_args.text_column_name not in raw_datasets["train"].column_names:
416
+ raise ValueError(
417
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
418
+ "Make sure to set `--text_column_name` to the correct text column - one of "
419
+ f"{', '.join(raw_datasets['train'].column_names)}."
420
+ )
421
+
422
+ if data_args.max_train_samples is not None:
423
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
424
+
425
+ if training_args.do_eval:
426
+ raw_datasets["eval"] = load_dataset(
427
+ data_args.dataset_name,
428
+ data_args.dataset_config_name,
429
+ split=data_args.eval_split_name,
430
+ use_auth_token=data_args.use_auth_token,
431
+ )
432
+
433
+ if data_args.max_eval_samples is not None:
434
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
435
+
436
+ # 2. We remove some special characters from the datasets
437
+ # that make training complicated and do not help in transcribing the speech
438
+ # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
439
+ # that could be easily picked up by the model
440
+ from pykakasi import kakasi
441
+
442
+ kakasi = kakasi()
443
+ kakasi.setMode('J', 'H') #Convert from kanji to hiragana
444
+ # kakasi.setMode("K", "H") #Convert from katakana to hiragana
445
+ conv = kakasi.getConverter()
446
+
447
+ chars_to_ignore_regex = (
448
+ f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else '[\,\?\!\-\;\:\"\“\%\‘\”\�\—\’\…\–\(\,\[\]\)\(\!]'
449
+ )
450
+ text_column_name = data_args.text_column_name
451
+
452
+
453
+
454
+ def remove_special_characters(batch):
455
+ if chars_to_ignore_regex is not None:
456
+ batch["target_text"] = conv.do(re.sub(chars_to_ignore_regex, "", batch[text_column_name])) + " "
457
+ else:
458
+ batch["target_text"] = batch[text_column_name].lower() + " "
459
+ return batch
460
+
461
+ with training_args.main_process_first(desc="dataset map special characters removal"):
462
+ raw_datasets = raw_datasets.map(
463
+ remove_special_characters,
464
+ remove_columns=[text_column_name],
465
+ desc="remove special characters from datasets",
466
+ )
467
+
468
+ # save special tokens for tokenizer
469
+ word_delimiter_token = data_args.word_delimiter_token
470
+ unk_token = data_args.unk_token
471
+ pad_token = data_args.pad_token
472
+
473
+ # 3. Next, let's load the config as we might need it to create
474
+ # the tokenizer
475
+ # load config
476
+ config = AutoConfig.from_pretrained(
477
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
478
+ )
479
+
480
+ # 4. Next, if no tokenizer file is defined,
481
+ # we create the vocabulary of the model by extracting all unique characters from
482
+ # the training and evaluation datasets
483
+ # We need to make sure that only first rank saves vocabulary
484
+ # make sure all processes wait until vocab is created
485
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
486
+ tokenizer_kwargs = {}
487
+ if tokenizer_name_or_path is None:
488
+ # save vocab in training output dir
489
+ tokenizer_name_or_path = training_args.output_dir
490
+
491
+ vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")
492
+
493
+ with training_args.main_process_first():
494
+ if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
495
+ os.remove(vocab_file)
496
+
497
+ with training_args.main_process_first(desc="dataset map vocabulary creation"):
498
+ if not os.path.isfile(vocab_file):
499
+ os.makedirs(tokenizer_name_or_path, exist_ok=True)
500
+ vocab_dict = create_vocabulary_from_data(
501
+ raw_datasets,
502
+ word_delimiter_token=word_delimiter_token,
503
+ unk_token=unk_token,
504
+ pad_token=pad_token,
505
+ )
506
+
507
+ # save vocab dict to be loaded into tokenizer
508
+ with open(vocab_file, "w") as file:
509
+ json.dump(vocab_dict, file)
510
+
511
+ # if tokenizer has just been created
512
+ # it is defined by `tokenizer_class` if present in config else by `model_type`
513
+ tokenizer_kwargs = {
514
+ "config": config if config.tokenizer_class is not None else None,
515
+ "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
516
+ "unk_token": unk_token,
517
+ "pad_token": pad_token,
518
+ "word_delimiter_token": word_delimiter_token,
519
+ }
520
+
521
+ # 5. Now we can instantiate the feature extractor, tokenizer and model
522
+ # Note for distributed training, the .from_pretrained methods guarantee that only
523
+ # one local process can concurrently download model & vocab.
524
+
525
+ # load feature_extractor and tokenizer
526
+ tokenizer = AutoTokenizer.from_pretrained(
527
+ tokenizer_name_or_path,
528
+ use_auth_token=data_args.use_auth_token,
529
+ **tokenizer_kwargs,
530
+ )
531
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
532
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
533
+ )
534
+
535
+ # adapt config
536
+ config.update(
537
+ {
538
+ "feat_proj_dropout": model_args.feat_proj_dropout,
539
+ "attention_dropout": model_args.attention_dropout,
540
+ "hidden_dropout": model_args.hidden_dropout,
541
+ "final_dropout": model_args.final_dropout,
542
+ "mask_time_prob": model_args.mask_time_prob,
543
+ "mask_time_length": model_args.mask_time_length,
544
+ "mask_feature_prob": model_args.mask_feature_prob,
545
+ "mask_feature_length": model_args.mask_feature_length,
546
+ "gradient_checkpointing": training_args.gradient_checkpointing,
547
+ "layerdrop": model_args.layerdrop,
548
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
549
+ "pad_token_id": tokenizer.pad_token_id,
550
+ "vocab_size": len(tokenizer),
551
+ "activation_dropout": model_args.activation_dropout,
552
+ }
553
+ )
554
+
555
+ # create model
556
+ model = AutoModelForCTC.from_pretrained(
557
+ model_args.model_name_or_path,
558
+ cache_dir=model_args.cache_dir,
559
+ config=config,
560
+ use_auth_token=data_args.use_auth_token,
561
+ )
562
+
563
+ # freeze encoder
564
+ if model_args.freeze_feature_encoder:
565
+ model.freeze_feature_encoder()
566
+
567
+ # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
568
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
569
+ # so that we just need to set the correct target sampling rate and normalize the input
570
+ # via the `feature_extractor`
571
+
572
+ # make sure that dataset decodes audio with correct sampling rate
573
+ dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
574
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
575
+ raw_datasets = raw_datasets.cast_column(
576
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
577
+ )
578
+
579
+ # derive max & min input length for sample rate & max duration
580
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
581
+ min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
582
+ audio_column_name = data_args.audio_column_name
583
+ num_workers = data_args.preprocessing_num_workers
584
+
585
+ # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
586
+ phoneme_language = data_args.phoneme_language
587
+
588
+ # Preprocessing the datasets.
589
+ # We need to read the audio files as arrays and tokenize the targets.
590
+ def prepare_dataset(batch):
591
+ # load audio
592
+ sample = batch[audio_column_name]
593
+
594
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
595
+ batch["input_values"] = inputs.input_values[0]
596
+ batch["input_length"] = len(batch["input_values"])
597
+
598
+ # encode targets
599
+ additional_kwargs = {}
600
+ if phoneme_language is not None:
601
+ additional_kwargs["phonemizer_lang"] = phoneme_language
602
+
603
+ batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
604
+ return batch
605
+
606
+ with training_args.main_process_first(desc="dataset map preprocessing"):
607
+ vectorized_datasets = raw_datasets.map(
608
+ prepare_dataset,
609
+ remove_columns=next(iter(raw_datasets.values())).column_names,
610
+ num_proc=num_workers,
611
+ desc="preprocess datasets",
612
+ )
613
+
614
+ def is_audio_in_length_range(length):
615
+ return length > min_input_length and length < max_input_length
616
+
617
+ # filter data that is shorter than min_input_length
618
+ vectorized_datasets = vectorized_datasets.filter(
619
+ is_audio_in_length_range,
620
+ num_proc=num_workers,
621
+ input_columns=["input_length"],
622
+ )
623
+
624
+ # 7. Next, we can prepare the training.
625
+ # Let's use word error rate (WER) as our evaluation metric,
626
+ # instantiate a data collator and the trainer
627
+
628
+ # Define evaluation metrics during training, *i.e.* word error rate, character error rate
629
+ eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
630
+
631
+ # for large datasets it is advised to run the preprocessing on a
632
+ # single machine first with ``args.preprocessing_only`` since there will mostly likely
633
+ # be a timeout when running the script in distributed mode.
634
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
635
+ # cached dataset
636
+ if data_args.preprocessing_only:
637
+ logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
638
+ return
639
+
640
+ def compute_metrics(pred):
641
+ pred_logits = pred.predictions
642
+ pred_ids = np.argmax(pred_logits, axis=-1)
643
+
644
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
645
+
646
+ pred_str = tokenizer.batch_decode(pred_ids)
647
+ # we do not want to group tokens when computing the metrics
648
+ label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
649
+
650
+ metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
651
+
652
+ return metrics
653
+
654
+ # Now save everything to be able to create a single processor later
655
+ if is_main_process(training_args.local_rank):
656
+ # save feature extractor, tokenizer and config
657
+ feature_extractor.save_pretrained(training_args.output_dir)
658
+ tokenizer.save_pretrained(training_args.output_dir)
659
+ config.save_pretrained(training_args.output_dir)
660
+
661
+ try:
662
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
663
+ except (OSError, KeyError):
664
+ warnings.warn(
665
+ "Loading a processor from a feature extractor config that does not"
666
+ " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
667
+ " attribute to your `preprocessor_config.json` file to suppress this warning: "
668
+ " `'processor_class': 'Wav2Vec2Processor'`",
669
+ FutureWarning,
670
+ )
671
+ processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
672
+
673
+ # Instantiate custom data collator
674
+ data_collator = DataCollatorCTCWithPadding(processor=processor)
675
+
676
+ decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
677
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
678
+ optimizer_grouped_parameters = [
679
+ {
680
+ "params": [p for n, p in model.named_parameters() if n in decay_parameters],
681
+ "weight_decay": training_args.weight_decay,
682
+ },
683
+ {
684
+ "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
685
+ "weight_decay": 0.0,
686
+ },
687
+ ]
688
+ optimizer = bnb.optim.Adam8bit(
689
+ params=optimizer_grouped_parameters,
690
+ lr=training_args.learning_rate,
691
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
692
+ eps=training_args.adam_epsilon,
693
+ )
694
+
695
+ optimizers = (optimizer, None)
696
+
697
+ # Initialize Trainer
698
+ trainer = Trainer(
699
+ model=model,
700
+ data_collator=data_collator,
701
+ args=training_args,
702
+ compute_metrics=compute_metrics,
703
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
704
+ eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
705
+ tokenizer=feature_extractor,
706
+ optimizers=optimizers,
707
+ )
708
+
709
+ # 8. Finally, we can start training
710
+
711
+ # Training
712
+ if training_args.do_train:
713
+
714
+ # use last checkpoint if exist
715
+ if last_checkpoint is not None:
716
+ checkpoint = last_checkpoint
717
+ elif os.path.isdir(model_args.model_name_or_path):
718
+ checkpoint = model_args.model_name_or_path
719
+ else:
720
+ checkpoint = None
721
+
722
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
723
+ trainer.save_model()
724
+
725
+ metrics = train_result.metrics
726
+ max_train_samples = (
727
+ data_args.max_train_samples
728
+ if data_args.max_train_samples is not None
729
+ else len(vectorized_datasets["train"])
730
+ )
731
+ metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
732
+
733
+ trainer.log_metrics("train", metrics)
734
+ trainer.save_metrics("train", metrics)
735
+ trainer.save_state()
736
+
737
+ # Evaluation
738
+ results = {}
739
+ if training_args.do_eval:
740
+ logger.info("*** Evaluate ***")
741
+ metrics = trainer.evaluate()
742
+ max_eval_samples = (
743
+ data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
744
+ )
745
+ metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
746
+
747
+ trainer.log_metrics("eval", metrics)
748
+ trainer.save_metrics("eval", metrics)
749
+
750
+ # Write model card and (optionally) push to hub
751
+ config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
752
+ kwargs = {
753
+ "finetuned_from": model_args.model_name_or_path,
754
+ "tasks": "speech-recognition",
755
+ "tags": ["automatic-speech-recognition", data_args.dataset_name],
756
+ "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
757
+ "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
758
+ }
759
+ if "common_voice" in data_args.dataset_name:
760
+ kwargs["language"] = config_name
761
+
762
+ if training_args.push_to_hub:
763
+ trainer.push_to_hub(**kwargs)
764
+ else:
765
+ trainer.create_model_card(**kwargs)
766
+
767
+ return results
768
+
769
+
770
+ if __name__ == "__main__":
771
+ main()
run_training.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python run_speech_recognition_ctc_bnb.py \
2
+ --dataset_name="mozilla-foundation/common_voice_8_0" \
3
+ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
4
+ --dataset_config_name="ja" \
5
+ --output_dir="./" \
6
+ --overwrite_output_dir \
7
+ --num_train_epochs="10" \
8
+ --per_device_train_batch_size="48" \
9
+ --per_device_eval_batch_size="8" \
10
+ --learning_rate="7.5e-5" \
11
+ --warmup_steps="2000" \
12
+ --length_column_name="input_length" \
13
+ --evaluation_strategy="steps" \
14
+ --text_column_name="sentence" \
15
+ --save_steps="1000" \
16
+ --eval_steps="1000" \
17
+ --logging_steps="100" \
18
+ --layerdrop="0.0" \
19
+ --activation_dropout="0.1" \
20
+ --save_total_limit="4" \
21
+ --freeze_feature_encoder \
22
+ --feat_proj_dropout="0.0" \
23
+ --mask_time_prob="0.75" \
24
+ --mask_time_length="10" \
25
+ --mask_feature_prob="0.25" \
26
+ --mask_feature_length="64" \
27
+ --gradient_checkpointing \
28
+ --use_auth_token \
29
+ --fp16 \
30
+ --group_by_length \
31
+ --do_train --do_eval \
32
+ --push_to_hub
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
speech_training_notebook.ipynb ADDED
@@ -0,0 +1,1490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "application/vnd.jupyter.widget-view+json": {
11
+ "model_id": "b7523cd66cf343f98fd3006be918a3b6",
12
+ "version_major": 2,
13
+ "version_minor": 0
14
+ },
15
+ "text/plain": [
16
+ "Downloading: 0%| | 0.00/10.1k [00:00<?, ?B/s]"
17
+ ]
18
+ },
19
+ "metadata": {},
20
+ "output_type": "display_data"
21
+ },
22
+ {
23
+ "data": {
24
+ "application/vnd.jupyter.widget-view+json": {
25
+ "model_id": "251cac7b8968405eafd54e2d29165b40",
26
+ "version_major": 2,
27
+ "version_minor": 0
28
+ },
29
+ "text/plain": [
30
+ "Downloading: 0%| | 0.00/2.98k [00:00<?, ?B/s]"
31
+ ]
32
+ },
33
+ "metadata": {},
34
+ "output_type": "display_data"
35
+ },
36
+ {
37
+ "data": {
38
+ "application/vnd.jupyter.widget-view+json": {
39
+ "model_id": "528c6a67efea4512b04b06a32156d5b7",
40
+ "version_major": 2,
41
+ "version_minor": 0
42
+ },
43
+ "text/plain": [
44
+ "Downloading: 0%| | 0.00/53.1k [00:00<?, ?B/s]"
45
+ ]
46
+ },
47
+ "metadata": {},
48
+ "output_type": "display_data"
49
+ },
50
+ {
51
+ "name": "stdout",
52
+ "output_type": "stream",
53
+ "text": [
54
+ "Downloading and preparing dataset common_voice/ja to /workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/ja/8.0.0/b8bc4d453193c06a43269b46cd87f075c70f152ac963b7f28f7a2760c45ec3e8...\n"
55
+ ]
56
+ },
57
+ {
58
+ "data": {
59
+ "application/vnd.jupyter.widget-view+json": {
60
+ "model_id": "6c21c5f782734b3bb3f545cef5b59ee0",
61
+ "version_major": 2,
62
+ "version_minor": 0
63
+ },
64
+ "text/plain": [
65
+ "Downloading: 0%| | 0.00/958M [00:00<?, ?B/s]"
66
+ ]
67
+ },
68
+ "metadata": {},
69
+ "output_type": "display_data"
70
+ },
71
+ {
72
+ "data": {
73
+ "application/vnd.jupyter.widget-view+json": {
74
+ "model_id": "",
75
+ "version_major": 2,
76
+ "version_minor": 0
77
+ },
78
+ "text/plain": [
79
+ "0 examples [00:00, ? examples/s]"
80
+ ]
81
+ },
82
+ "metadata": {},
83
+ "output_type": "display_data"
84
+ },
85
+ {
86
+ "data": {
87
+ "application/vnd.jupyter.widget-view+json": {
88
+ "model_id": "",
89
+ "version_major": 2,
90
+ "version_minor": 0
91
+ },
92
+ "text/plain": [
93
+ "0 examples [00:00, ? examples/s]"
94
+ ]
95
+ },
96
+ "metadata": {},
97
+ "output_type": "display_data"
98
+ },
99
+ {
100
+ "data": {
101
+ "application/vnd.jupyter.widget-view+json": {
102
+ "model_id": "",
103
+ "version_major": 2,
104
+ "version_minor": 0
105
+ },
106
+ "text/plain": [
107
+ "0 examples [00:00, ? examples/s]"
108
+ ]
109
+ },
110
+ "metadata": {},
111
+ "output_type": "display_data"
112
+ },
113
+ {
114
+ "data": {
115
+ "application/vnd.jupyter.widget-view+json": {
116
+ "model_id": "",
117
+ "version_major": 2,
118
+ "version_minor": 0
119
+ },
120
+ "text/plain": [
121
+ "0 examples [00:00, ? examples/s]"
122
+ ]
123
+ },
124
+ "metadata": {},
125
+ "output_type": "display_data"
126
+ },
127
+ {
128
+ "data": {
129
+ "application/vnd.jupyter.widget-view+json": {
130
+ "model_id": "",
131
+ "version_major": 2,
132
+ "version_minor": 0
133
+ },
134
+ "text/plain": [
135
+ "0 examples [00:00, ? examples/s]"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "name": "stdout",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "Dataset common_voice downloaded and prepared to /workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/ja/8.0.0/b8bc4d453193c06a43269b46cd87f075c70f152ac963b7f28f7a2760c45ec3e8. Subsequent calls will reuse this data.\n"
146
+ ]
147
+ },
148
+ {
149
+ "name": "stderr",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/ja/8.0.0/b8bc4d453193c06a43269b46cd87f075c70f152ac963b7f28f7a2760c45ec3e8)\n"
153
+ ]
154
+ },
155
+ {
156
+ "name": "stdout",
157
+ "output_type": "stream",
158
+ "text": [
159
+ "10623\n"
160
+ ]
161
+ }
162
+ ],
163
+ "source": [
164
+ "from datasets import Audio, Dataset, load_dataset, load_metric\n",
165
+ "from transformers import AutoFeatureExtractor, pipeline\n",
166
+ "\n",
167
+ "language_code = \"ja\"\n",
168
+ "dataset_name = \"mozilla-foundation/common_voice_8_0\"\n",
169
+ "\n",
170
+ "common_voice_train = load_dataset(dataset_name, language_code, use_auth_token=True, split=\"train+validation\")\n",
171
+ "common_voice_test = load_dataset(dataset_name, language_code, use_auth_token=True, split=\"test\")\n",
172
+ "\n",
173
+ "\n",
174
+ "print(len(common_voice_train))\n",
175
+ "\n",
176
+ "# # for testing: only process the first two examples as a test\n",
177
+ "# dataset = dataset.select(range(10))\n",
178
+ "\n",
179
+ "\n"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": 1,
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stdout",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "Collecting pykakasi\n",
192
+ " Downloading pykakasi-2.2.1-py3-none-any.whl (2.4 MB)\n",
193
+ " |████████████████████████████████| 2.4 MB 9.9 MB/s \n",
194
+ "\u001b[?25hCollecting jaconv\n",
195
+ " Downloading jaconv-0.3.tar.gz (15 kB)\n",
196
+ " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
197
+ "\u001b[?25hCollecting deprecated\n",
198
+ " Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)\n",
199
+ "Collecting wrapt<2,>=1.10\n",
200
+ " Downloading wrapt-1.13.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (84 kB)\n",
201
+ " |████████████████████████████████| 84 kB 12.8 MB/s \n",
202
+ "\u001b[?25hBuilding wheels for collected packages: jaconv\n",
203
+ " Building wheel for jaconv (setup.py) ... \u001b[?25ldone\n",
204
+ "\u001b[?25h Created wheel for jaconv: filename=jaconv-0.3-py3-none-any.whl size=15553 sha256=fd764f215e4d567cb60062a7052497b66729e9e2190e2e00153e0d19734088e7\n",
205
+ " Stored in directory: /workspace/.cache/pip/wheels/73/e8/fb/b4ad8117719f79ac73bc05406d1768f845688cdbeed7aad87e\n",
206
+ "Successfully built jaconv\n",
207
+ "Installing collected packages: wrapt, jaconv, deprecated, pykakasi\n",
208
+ "Successfully installed deprecated-1.2.13 jaconv-0.3 pykakasi-2.2.1 wrapt-1.13.3\n",
209
+ "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.2 is available.\n",
210
+ "You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "!pip install pykakasi"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": 4,
221
+ "metadata": {},
222
+ "outputs": [
223
+ {
224
+ "name": "stdout",
225
+ "output_type": "stream",
226
+ "text": [
227
+ "にんじゃ ひらがな kana\n"
228
+ ]
229
+ },
230
+ {
231
+ "name": "stderr",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "/tmp/ipykernel_2159/3076271513.py:4: DeprecationWarning: Call to deprecated method setMode. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
235
+ " kakasi.setMode('J', 'H') #Convert from kanji to hiragana\n",
236
+ "/tmp/ipykernel_2159/3076271513.py:6: DeprecationWarning: Call to deprecated method getConverter. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
237
+ " conv = kakasi.getConverter()\n",
238
+ "/tmp/ipykernel_2159/3076271513.py:10: DeprecationWarning: Call to deprecated method do. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
239
+ " print(conv.do(str))\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "from pykakasi import kakasi\n",
245
+ "\n",
246
+ "kakasi = kakasi()\n",
247
+ "kakasi.setMode('J', 'H') #Convert from kanji to hiragana\n",
248
+ "# kakasi.setMode(\"K\", \"H\") #Convert from katakana to hiragana\n",
249
+ "conv = kakasi.getConverter()\n",
250
+ "\n",
251
+ "str = 'にんじゃ 平仮名 kana'\n",
252
+ "\n",
253
+ "print(conv.do(str))"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": 3,
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "repo_name = 'https://huggingface.co/AndrewMcDowell/wav2vec2-xls-r-1B-german'\n"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 4,
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "common_voice_train = common_voice_train.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
272
+ "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
273
+ "\n"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": 15,
279
+ "metadata": {},
280
+ "outputs": [
281
+ {
282
+ "data": {
283
+ "application/vnd.jupyter.widget-view+json": {
284
+ "model_id": "ad26c4d7d02948a3bc30d86a0f3527c8",
285
+ "version_major": 2,
286
+ "version_minor": 0
287
+ },
288
+ "text/plain": [
289
+ "0ex [00:00, ?ex/s]"
290
+ ]
291
+ },
292
+ "metadata": {},
293
+ "output_type": "display_data"
294
+ },
295
+ {
296
+ "name": "stderr",
297
+ "output_type": "stream",
298
+ "text": [
299
+ "/tmp/ipykernel_2159/322450745.py:5: DeprecationWarning: Call to deprecated method do. (Old API will be removed in v3.0.) -- Deprecated since version 2.1.\n",
300
+ " batch[\"sentence\"] = conv.do(re.sub(chars_to_remove_regex, '', batch[\"sentence\"]))\n"
301
+ ]
302
+ },
303
+ {
304
+ "data": {
305
+ "application/vnd.jupyter.widget-view+json": {
306
+ "model_id": "93295f1cd50f4557a96ff1bf139c9a37",
307
+ "version_major": 2,
308
+ "version_minor": 0
309
+ },
310
+ "text/plain": [
311
+ "0ex [00:00, ?ex/s]"
312
+ ]
313
+ },
314
+ "metadata": {},
315
+ "output_type": "display_data"
316
+ }
317
+ ],
318
+ "source": [
319
+ "import re\n",
320
+ "chars_to_remove_regex = '[\\,\\?\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�\\—\\’\\…\\–\\(\\,\\[\\]\\)\\(\\!]'\n",
321
+ "# \\.\n",
322
+ "def remove_special_characters(batch):\n",
323
+ " batch[\"sentence\"] = conv.do(re.sub(chars_to_remove_regex, '', batch[\"sentence\"]))\n",
324
+ " return batch\n",
325
+ "\n",
326
+ "common_voice_train = common_voice_train.map(remove_special_characters)\n",
327
+ "common_voice_test = common_voice_test.map(remove_special_characters)"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 6,
333
+ "metadata": {},
334
+ "outputs": [
335
+ {
336
+ "name": "stdout",
337
+ "output_type": "stream",
338
+ "text": [
339
+ "Collecting num2words\n",
340
+ " Downloading num2words-0.5.10-py3-none-any.whl (101 kB)\n",
341
+ " |████████████████████████████████| 101 kB 7.9 MB/s \n",
342
+ "\u001b[?25hCollecting docopt>=0.6.2\n",
343
+ " Downloading docopt-0.6.2.tar.gz (25 kB)\n",
344
+ " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
345
+ "\u001b[?25hBuilding wheels for collected packages: docopt\n",
346
+ " Building wheel for docopt (setup.py) ... \u001b[?25ldone\n",
347
+ "\u001b[?25h Created wheel for docopt: filename=docopt-0.6.2-py2.py3-none-any.whl size=13704 sha256=7cda85e4b3980668714aad8f5d706fb5b189c2804ce1d99ca2380537fdc78031\n",
348
+ " Stored in directory: /workspace/.cache/pip/wheels/56/ea/58/ead137b087d9e326852a851351d1debf4ada529b6ac0ec4e8c\n",
349
+ "Successfully built docopt\n",
350
+ "Installing collected packages: docopt, num2words\n",
351
+ "Successfully installed docopt-0.6.2 num2words-0.5.10\n",
352
+ "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.2 is available.\n",
353
+ "You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
354
+ ]
355
+ }
356
+ ],
357
+ "source": [
358
+ "!pip install num2words"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": 7,
364
+ "metadata": {},
365
+ "outputs": [
366
+ {
367
+ "data": {
368
+ "application/vnd.jupyter.widget-view+json": {
369
+ "model_id": "0da8fd9cdae64c1fa80fbcfc412bcf9c",
370
+ "version_major": 2,
371
+ "version_minor": 0
372
+ },
373
+ "text/plain": [
374
+ "0ex [00:00, ?ex/s]"
375
+ ]
376
+ },
377
+ "metadata": {},
378
+ "output_type": "display_data"
379
+ }
380
+ ],
381
+ "source": [
382
+ "\n",
383
+ "from num2words import num2words\n",
384
+ "import regex as re\n",
385
+ "matches = []\n",
386
+ "\n",
387
+ "def replace_numbers(match):\n",
388
+ " match = match.group()\n",
389
+ " matches.append(match)\n",
390
+ " return num2words(match, lang='de')\n",
391
+ "\n",
392
+ "def replace_numbers_in_batch(batch):\n",
393
+ " batch[\"sentence\"] = re.sub(r'\\d+(?:,\\d+)?', replace_numbers, batch[\"sentence\"])\n",
394
+ " return batch\n",
395
+ "\n",
396
+ "common_voice_test_2 = common_voice_test.map(replace_numbers_in_batch)"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": 10,
402
+ "metadata": {},
403
+ "outputs": [
404
+ {
405
+ "data": {
406
+ "application/vnd.jupyter.widget-view+json": {
407
+ "model_id": "54d62ea7a0214b6abc5de1f106b330dc",
408
+ "version_major": 2,
409
+ "version_minor": 0
410
+ },
411
+ "text/plain": [
412
+ "0ex [00:00, ?ex/s]"
413
+ ]
414
+ },
415
+ "metadata": {},
416
+ "output_type": "display_data"
417
+ }
418
+ ],
419
+ "source": [
420
+ "common_voice_train_2 = common_voice_train.map(replace_numbers_in_batch)"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": 11,
426
+ "metadata": {},
427
+ "outputs": [
428
+ {
429
+ "data": {
430
+ "text/plain": [
431
+ "0"
432
+ ]
433
+ },
434
+ "execution_count": 11,
435
+ "metadata": {},
436
+ "output_type": "execute_result"
437
+ }
438
+ ],
439
+ "source": [
440
+ "len(matches)"
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "code",
445
+ "execution_count": null,
446
+ "metadata": {},
447
+ "outputs": [],
448
+ "source": [
449
+ "# def replace_accented_characters(batch):\n",
450
+ "# accented_string = u'Málaga'\n",
451
+ "# # accented_string is of type 'unicode'\n",
452
+ "# import unidecode\n",
453
+ "# unaccented_string = unidecode.unidecode(accented_string)\n",
454
+ "# batch[\"sentence\"] = re.sub('[â]', 'a', batch[\"sentence\"])\n",
455
+ "# batch[\"sentence\"] = re.sub('[î]', 'i', batch[\"sentence\"])\n",
456
+ "# batch[\"sentence\"] = re.sub('[ô]', 'o', batch[\"sentence\"])\n",
457
+ "# batch[\"sentence\"] = re.sub('[û]', 'u', batch[\"sentence\"])\n",
458
+ "# return batch\n",
459
+ "\n",
460
+ "def strip_accents(batch):\n",
461
+ " return ''.join(c for c in unicodedata.normalize('NFD', batch[\"sentence\"]) if unicodedata.category(c) != 'Mn')\n",
462
+ "\n",
463
+ "common_voice_train = common_voice_train.map(replace_accented_characters)\n",
464
+ "common_voice_test = common_voice_test.map(replace_accented_characters)"
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "execution_count": null,
470
+ "metadata": {},
471
+ "outputs": [],
472
+ "source": []
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": 6,
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "def extract_all_chars(batch):\n",
481
+ " all_text = \" \".join(batch[\"sentence\"])\n",
482
+ " vocab = list(set(all_text))\n",
483
+ " return {\"vocab\": [vocab], \"all_text\": [all_text]}"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": null,
489
+ "metadata": {},
490
+ "outputs": [],
491
+ "source": []
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": 16,
496
+ "metadata": {},
497
+ "outputs": [
498
+ {
499
+ "data": {
500
+ "application/vnd.jupyter.widget-view+json": {
501
+ "model_id": "c40f4d6b6bb74a56b2c570a3a53d7f4b",
502
+ "version_major": 2,
503
+ "version_minor": 0
504
+ },
505
+ "text/plain": [
506
+ " 0%| | 0/1 [00:00<?, ?ba/s]"
507
+ ]
508
+ },
509
+ "metadata": {},
510
+ "output_type": "display_data"
511
+ },
512
+ {
513
+ "data": {
514
+ "application/vnd.jupyter.widget-view+json": {
515
+ "model_id": "f69b6a3c0b54477ea15c56b02464bacd",
516
+ "version_major": 2,
517
+ "version_minor": 0
518
+ },
519
+ "text/plain": [
520
+ " 0%| | 0/1 [00:00<?, ?ba/s]"
521
+ ]
522
+ },
523
+ "metadata": {},
524
+ "output_type": "display_data"
525
+ }
526
+ ],
527
+ "source": [
528
+ "vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)\n",
529
+ "vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": 17,
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": [
538
+ "vocab_list = list(set(vocab_train[\"vocab\"][0]) | set(vocab_test[\"vocab\"][0]))"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "code",
543
+ "execution_count": 19,
544
+ "metadata": {
545
+ "collapsed": true,
546
+ "jupyter": {
547
+ "outputs_hidden": true
548
+ }
549
+ },
550
+ "outputs": [
551
+ {
552
+ "data": {
553
+ "text/plain": [
554
+ "[['ん',\n",
555
+ " 'ン',\n",
556
+ " 'ダ',\n",
557
+ " 'S',\n",
558
+ " 'う',\n",
559
+ " 'た',\n",
560
+ " 'ぽ',\n",
561
+ " 'P',\n",
562
+ " ':',\n",
563
+ " '々',\n",
564
+ " 'か',\n",
565
+ " 'ぞ',\n",
566
+ " 'よ',\n",
567
+ " 'や',\n",
568
+ " 'ヨ',\n",
569
+ " 'ゃ',\n",
570
+ " 'Q',\n",
571
+ " 'N',\n",
572
+ " 'だ',\n",
573
+ " 'を',\n",
574
+ " 'L',\n",
575
+ " 'h',\n",
576
+ " 'F',\n",
577
+ " 'E',\n",
578
+ " 'ピ',\n",
579
+ " 'ち',\n",
580
+ " 'ボ',\n",
581
+ " 'w',\n",
582
+ " 'リ',\n",
583
+ " 'ゲ',\n",
584
+ " 'フ',\n",
585
+ " 'あ',\n",
586
+ " 'ウ',\n",
587
+ " 'め',\n",
588
+ " 'タ',\n",
589
+ " 'ぬ',\n",
590
+ " 'せ',\n",
591
+ " 'て',\n",
592
+ " 'b',\n",
593
+ " '」',\n",
594
+ " 'す',\n",
595
+ " 'び',\n",
596
+ " 'ば',\n",
597
+ " 'ア',\n",
598
+ " 'A',\n",
599
+ " 'r',\n",
600
+ " 'ャ',\n",
601
+ " 'イ',\n",
602
+ " 'へ',\n",
603
+ " 'ぶ',\n",
604
+ " 'は',\n",
605
+ " 'u',\n",
606
+ " 'と',\n",
607
+ " '繫',\n",
608
+ " 'ぎ',\n",
609
+ " 'バ',\n",
610
+ " 'ノ',\n",
611
+ " 'I',\n",
612
+ " 'ざ',\n",
613
+ " 'R',\n",
614
+ " 'チ',\n",
615
+ " 'A',\n",
616
+ " '「',\n",
617
+ " 'G',\n",
618
+ " 'ェ',\n",
619
+ " 'く',\n",
620
+ " 'け',\n",
621
+ " 'ぇ',\n",
622
+ " '?',\n",
623
+ " '〜',\n",
624
+ " 'つ',\n",
625
+ " 'わ',\n",
626
+ " 'こ',\n",
627
+ " 'ス',\n",
628
+ " 'ズ',\n",
629
+ " 'p',\n",
630
+ " 'y',\n",
631
+ " 'ぼ',\n",
632
+ " 'し',\n",
633
+ " '、',\n",
634
+ " '!',\n",
635
+ " 'ゼ',\n",
636
+ " 's',\n",
637
+ " 'U',\n",
638
+ " 'き',\n",
639
+ " 'ゥ',\n",
640
+ " '・',\n",
641
+ " 'が',\n",
642
+ " 'も',\n",
643
+ " 'エ',\n",
644
+ " 'ク',\n",
645
+ " 'づ',\n",
646
+ " 'O',\n",
647
+ " 'グ',\n",
648
+ " 'ブ',\n",
649
+ " 'ゅ',\n",
650
+ " 'ィ',\n",
651
+ " 'ぁ',\n",
652
+ " 'd',\n",
653
+ " 't',\n",
654
+ " 'j',\n",
655
+ " 'n',\n",
656
+ " 'ロ',\n",
657
+ " 'g',\n",
658
+ " 'ー',\n",
659
+ " '/',\n",
660
+ " 'ナ',\n",
661
+ " 'ヅ',\n",
662
+ " 'の',\n",
663
+ " 'ケ',\n",
664
+ " 'ほ',\n",
665
+ " '・',\n",
666
+ " ')',\n",
667
+ " 'J',\n",
668
+ " 'D',\n",
669
+ " 'ネ',\n",
670
+ " 'お',\n",
671
+ " 'パ',\n",
672
+ " 'ム',\n",
673
+ " 'む',\n",
674
+ " 'ラ',\n",
675
+ " 'ミ',\n",
676
+ " 'い',\n",
677
+ " 'ろ',\n",
678
+ " 'c',\n",
679
+ " '=',\n",
680
+ " 'z',\n",
681
+ " 'ベ',\n",
682
+ " 'O',\n",
683
+ " 'h',\n",
684
+ " 'プ',\n",
685
+ " 'o',\n",
686
+ " 'ザ',\n",
687
+ " '&',\n",
688
+ " '『',\n",
689
+ " 'ソ',\n",
690
+ " '.',\n",
691
+ " 'ヴ',\n",
692
+ " 'l',\n",
693
+ " 'ド',\n",
694
+ " 'み',\n",
695
+ " 'v',\n",
696
+ " 'x',\n",
697
+ " 'Y',\n",
698
+ " 'ガ',\n",
699
+ " 'に',\n",
700
+ " 'ヌ',\n",
701
+ " 'ら',\n",
702
+ " 'ヘ',\n",
703
+ " 'ょ',\n",
704
+ " 'カ',\n",
705
+ " '。',\n",
706
+ " 'ギ',\n",
707
+ " 'C',\n",
708
+ " 'ぜ',\n",
709
+ " 'モ',\n",
710
+ " 'キ',\n",
711
+ " 'i',\n",
712
+ " 'j',\n",
713
+ " '.',\n",
714
+ " \"'\",\n",
715
+ " 'M',\n",
716
+ " 'ご',\n",
717
+ " 'ど',\n",
718
+ " 'ハ',\n",
719
+ " 'ね',\n",
720
+ " 'で',\n",
721
+ " 'W',\n",
722
+ " 'ぴ',\n",
723
+ " 'T',\n",
724
+ " 'ぷ',\n",
725
+ " ' ',\n",
726
+ " 'マ',\n",
727
+ " '―',\n",
728
+ " 'ビ',\n",
729
+ " 'H',\n",
730
+ " 'デ',\n",
731
+ " 'f',\n",
732
+ " 'ゾ',\n",
733
+ " '-',\n",
734
+ " 'ポ',\n",
735
+ " 'K',\n",
736
+ " 'ヤ',\n",
737
+ " 'ユ',\n",
738
+ " 'シ',\n",
739
+ " 'ペ',\n",
740
+ " 'Z',\n",
741
+ " 'ぱ',\n",
742
+ " 'ふ',\n",
743
+ " 'る',\n",
744
+ " 'べ',\n",
745
+ " 'ヒ',\n",
746
+ " 'e',\n",
747
+ " 'そ',\n",
748
+ " 'テ',\n",
749
+ " 'サ',\n",
750
+ " 'V',\n",
751
+ " 'れ',\n",
752
+ " '」',\n",
753
+ " 'じ',\n",
754
+ " 'ワ',\n",
755
+ " 'レ',\n",
756
+ " 'X',\n",
757
+ " 'ォ',\n",
758
+ " 'ュ',\n",
759
+ " 'ジ',\n",
760
+ " 'k',\n",
761
+ " 'な',\n",
762
+ " 'ニ',\n",
763
+ " 'り',\n",
764
+ " 'q',\n",
765
+ " 'U',\n",
766
+ " 'ひ',\n",
767
+ " 'げ',\n",
768
+ " '&',\n",
769
+ " 'ゆ',\n",
770
+ " 'っ',\n",
771
+ " 'ず',\n",
772
+ " 'ゴ',\n",
773
+ " '「',\n",
774
+ " 'a',\n",
775
+ " 'ぢ',\n",
776
+ " 'ル',\n",
777
+ " 'さ',\n",
778
+ " 'ぺ',\n",
779
+ " 'm',\n",
780
+ " 'ョ',\n",
781
+ " 'ト',\n",
782
+ " 'ツ',\n",
783
+ " 'ホ',\n",
784
+ " 'コ',\n",
785
+ " 'オ',\n",
786
+ " 'セ',\n",
787
+ " 'え',\n",
788
+ " 'ま',\n",
789
+ " 'メ',\n",
790
+ " 'ァ',\n",
791
+ " 'F',\n",
792
+ " 'ぐ',\n",
793
+ " 'B',\n",
794
+ " '』',\n",
795
+ " 'ッ']]"
796
+ ]
797
+ },
798
+ "execution_count": 19,
799
+ "metadata": {},
800
+ "output_type": "execute_result"
801
+ }
802
+ ],
803
+ "source": [
804
+ "# vocab_train[\"vocab\"]"
805
+ ]
806
+ },
807
+ {
808
+ "cell_type": "code",
809
+ "execution_count": 18,
810
+ "metadata": {},
811
+ "outputs": [
812
+ {
813
+ "name": "stdout",
814
+ "output_type": "stream",
815
+ "text": [
816
+ "249\n",
817
+ "['ダ', 'た', 'P', 'か', 'よ', 'や', 'Q', 'を', 'F', 'h', 'E', 'ち', 'リ', 'ゲ', 'フ', 'め', 'タ', 'せ', 'b', '」', 'ば', 'ア', 'A', 'ャ', 'イ', 'ぶ', 'は', 'u', 'と', 'ノ', 'I', 'R', '「', 'G', 'ェ', 'く', '?', '〜', 'つ', 'こ', 'S', 'ぼ', 'ゼ', 's', 'U', 'き', 'ゥ', 'が', 'も', 'エ', 'ク', 'づ', 'グ', 'ブ', 'ゅ', 'ィ', 't', 'n', 'ロ', 'ー', '/', 'の', 'ケ', '・', 'J', 'お', 'む', 'P', 'ベ', 'h', 'プ', 'o', '&', '『', 'ソ', '.', 'ヴ', 'ド', 'み', 'Y', 'ガ', 'ょ', 'カ', 'C', 'ぜ', 'j', '.', 'ご', 'ど', 'ハ', 'ね', 'W', 'j', 'T', ' ', 'マ', '―', '-', 'デ', 'ゾ', 'ポ', 'K', 'ペ', 'ぱ', 'ふ', 'べ', 'ヒ', 'e', 'サ', 'N', 'X', 'ュ', 'k', 'り', 'U', 'ひ', 'げ', 'ゆ', 'ず', 'ゴ', 'a', 'ョ', 'ツ', '〇', 'え', 'F', 'B', '』', 'ッ', 'ん', 'ン', 'S', 'う', 'ぽ', ':', '々', 'ぞ', 'N', 'ヨ', 'ゃ', 'だ', 'L', 'ピ', 'ボ', 'w', 'ウ', 'あ', 'ヶ', 'ぬ', 'て', 'す', 'び', 'r', 'へ', '繫', 'バ', 'ぎ', 'ざ', 'A', 'チ', 'け', 'ぇ', 'わ', 'ス', 'p', 'ズ', 'y', 'し', '、', '!', 'G', '・', 'O', 'ぁ', 'd', 'g', 'ナ', 'ヅ', 'ほ', ')', 'D', 'ネ', 'パ', 'ム', 'ミ', '=', 'z', 'い', 'ろ', 'c', 'O', 'ザ', 'l', 'v', 'x', 'ヌ', 'に', 'ら', 'ヘ', '。', 'ギ', 'モ', 'D', 'キ', 'i', \"'\", 'M', 'で', 'ぴ', 'ぷ', 'ビ', 'H', 'f', 'ヤ', 'ユ', 'シ', 'Z', 'る', 'そ', 'テ', 'V', 'れ', '」', 'じ', 'ワ', 'レ', 'ォ', 'ジ', 'な', 'ニ', 'q', '&', 'っ', '「', 'ぢ', 'ル', 'さ', 'ぺ', 'm', 'ト', 'ホ', 'コ', 'オ', 'セ', 'ま', 'メ', 'ァ', 'ぐ', 'ラ']\n"
818
+ ]
819
+ }
820
+ ],
821
+ "source": [
822
+ "print(len(vocab_list))\n",
823
+ "print(vocab_list)"
824
+ ]
825
+ },
826
+ {
827
+ "cell_type": "code",
828
+ "execution_count": 26,
829
+ "metadata": {},
830
+ "outputs": [],
831
+ "source": [
832
+ "j_vocab = {\"<pad>\": 0, \"<s>\": 1, \"</s>\": 2, \"<unk>\": 3, \"|\": 4, \"'\": 5, \"-\": 6, \"A\": 7, \"B\": 8, \"C\": 9, \"D\": 10, \"E\": 11, \"F\": 12, \"G\": 13, \"H\": 14, \"I\": 15, \"J\": 16, \"K\": 17, \"L\": 18, \"M\": 19, \"N\": 20, \"O\": 21, \"P\": 22, \"Q\": 23, \"R\": 24, \"S\": 25, \"T\": 26, \"U\": 27, \"V\": 28, \"W\": 29, \"X\": 30, \"Y\": 31, \"Z\": 32, \"Ä\": 33, \"Í\": 34, \"Ó\": 35, \"Ö\": 36, \"Ü\": 37}\n"
833
+ ]
834
+ },
835
+ {
836
+ "cell_type": "code",
837
+ "execution_count": 48,
838
+ "metadata": {},
839
+ "outputs": [],
840
+ "source": [
841
+ "manually_kept_values = ['ß', 'ä', 'ö', 'ü']\n",
842
+ "\n",
843
+ "punctuation = ['.', ]"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": 50,
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "name": "stdout",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "['$', '&', '(', ')', '*', '+', '.', '/', '=', '@', '[', ']', '_', '`', '¡', '§', '«', '°', '´', 'µ', '·', '»', '×', 'à', 'á', 'â', 'ã', 'å', 'æ', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï', 'ð', 'ñ', 'ò', 'ó', 'ô', 'õ', 'ø', 'ù', 'ú', 'û', 'ý', 'þ', 'ā', 'ă', 'ą', 'ć', 'č', 'ď', 'đ', 'ē', 'ė', 'ę', 'ě', 'ğ', 'ġ', 'ħ', 'ī', 'ı', 'ł', 'ń', 'ņ', 'ň', 'ō', 'ŏ', 'ő', 'œ', 'ř', 'ś', 'ş', 'š', 'ť', 'ū', 'ů', 'ź', 'ż', 'ž', 'ơ', 'ǐ', 'ǔ', 'ș', 'ț', 'ə', 'ʻ', 'ʾ', 'ʿ', '̆', '̇', '̥', 'а', 'в', 'е', 'и', 'к', 'м', 'о', 'р', 'с', 'ф', 'ч', 'ш', 'ѹ', 'א', 'ב', 'נ', 'ע', 'ש', '་', 'ན', 'ḫ', 'ṟ', 'ṣ', 'ṭ', 'ạ', 'ả', 'ắ', 'ằ', 'ế', 'ễ', 'ệ', 'ọ', 'ồ', 'ộ', 'ụ', 'ứ', '‑', '‚', '„', '‟', '′', '″', '‹', '›', '→', '−', '≡', '⟨', '⟩', 'カ', '东', '临', '乡', '关', '合', '城', '孙', '尣', '幺', '支', '比', '毛', '泽', '無', '生', '臣', '辶', '道', '镇', '黃']\n"
856
+ ]
857
+ }
858
+ ],
859
+ "source": [
860
+ "odd_values = []\n",
861
+ "for index, value in enumerate(sorted(vocab_list)):\n",
862
+ "# if :\n",
863
+ " if value not in j_vocab and not (16 <= index <= 41 or value == ' ') and value not in manually_kept_values:\n",
864
+ " odd_values.append(value)\n",
865
+ "# print(index, value)\n",
866
+ " \n",
867
+ "print(odd_values)"
868
+ ]
869
+ },
870
+ {
871
+ "cell_type": "code",
872
+ "execution_count": 63,
873
+ "metadata": {},
874
+ "outputs": [
875
+ {
876
+ "name": "stdout",
877
+ "output_type": "stream",
878
+ "text": [
879
+ "$ & ( ) * + . / = @ [ ] _ ` ¡ § « ° ´ µ · » × à á â ã å æ ç è é ê ë ì í î ï ð ñ ò ó ô õ ø ù ú û ý þ ā ă ą ć č ď đ ē ė ę ě ğ ġ ħ ī ı ł ń ņ ň ō ŏ ő œ ř ś ş š ť ū ů ź ż ž ơ ǐ ǔ ș ț ə ʻ ʾ ʿ ̆ ̇ ̥ а в е и к м о р с ф ч ш ѹ א ב נ ע ש ་ ན ḫ ṟ ṣ ṭ ạ ả ắ ằ ế ễ ệ ọ ồ ộ ụ ứ ‑ ‚ „ ‟ ′ ″ ‹ › → − ≡ ⟨ ⟩ カ 东 临 乡 关 合 城 孙 尣 幺 支 比 毛 泽 無 生 臣 辶 道 镇 黃\n"
880
+ ]
881
+ }
882
+ ],
883
+ "source": [
884
+ "print(\" \".join(odd_values))\n",
885
+ "\n",
886
+ "# for value in odd_values:\n",
887
+ "# if value not in manually_kept_values:\n",
888
+ "# print(value)"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": null,
894
+ "metadata": {},
895
+ "outputs": [],
896
+ "source": [
897
+ "$ & ( ) * + = @ [ ] _ ` ¡ § « ° ´ µ · » × à á â ã å æ ç è é ê ë ì í î ï ð ñ ò ó ô õ ø ù ú û ý þ ā ă ą ć č ď đ ē ė ę ě ğ ġ ħ ī ı ł ń ņ ň ō ŏ ő œ ř ś ş š ť ū ů ź ż ž ơ ǐ ǔ ș ț ə ʻ ʾ ʿ ̆ ̇ ̥ а в е и к м о р с ф ч ш ѹ א ב נ ע ש ་ ན ḫ ṟ ṣ ṭ ạ ả ắ ằ ế ễ ệ ọ ồ ộ ụ ứ ‑ ‚ „ ‟ ′ ″ ‹ › → − ≡ ⟨ ⟩ カ 东 临 乡 关 合 城 孙 尣 幺 支 比 毛 泽 無 生 臣 辶 道 镇 黃"
898
+ ]
899
+ },
900
+ {
901
+ "cell_type": "code",
902
+ "execution_count": 54,
903
+ "metadata": {},
904
+ "outputs": [],
905
+ "source": [
906
+ "filtered_vocab_list = [value for value in vocab_list if value not in odd_values]"
907
+ ]
908
+ },
909
+ {
910
+ "cell_type": "code",
911
+ "execution_count": 55,
912
+ "metadata": {},
913
+ "outputs": [
914
+ {
915
+ "data": {
916
+ "text/plain": [
917
+ "['ß',\n",
918
+ " 'j',\n",
919
+ " 'r',\n",
920
+ " 'h',\n",
921
+ " 'd',\n",
922
+ " 'l',\n",
923
+ " 'z',\n",
924
+ " 'n',\n",
925
+ " 'm',\n",
926
+ " 'c',\n",
927
+ " 'ä',\n",
928
+ " \"'\",\n",
929
+ " 'g',\n",
930
+ " 'e',\n",
931
+ " 'w',\n",
932
+ " 's',\n",
933
+ " 'u',\n",
934
+ " 'k',\n",
935
+ " 'o',\n",
936
+ " 'f',\n",
937
+ " ' ',\n",
938
+ " 'y',\n",
939
+ " 'v',\n",
940
+ " 'ö',\n",
941
+ " 'ü',\n",
942
+ " 'p',\n",
943
+ " 'a',\n",
944
+ " 'x',\n",
945
+ " 'b',\n",
946
+ " 'q',\n",
947
+ " 't',\n",
948
+ " 'i']"
949
+ ]
950
+ },
951
+ "execution_count": 55,
952
+ "metadata": {},
953
+ "output_type": "execute_result"
954
+ }
955
+ ],
956
+ "source": [
957
+ "filtered_vocab_list"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": 21,
963
+ "metadata": {},
964
+ "outputs": [
965
+ {
966
+ "ename": "NameError",
967
+ "evalue": "name 'word_delimiter_token' is not defined",
968
+ "output_type": "error",
969
+ "traceback": [
970
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
971
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
972
+ "Input \u001b[0;32mIn [21]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m vocab_dict \u001b[38;5;241m=\u001b[39m {v: k \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28msorted\u001b[39m(vocab_list))}\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# replace white space with delimiter token\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mword_delimiter_token\u001b[49m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 5\u001b[0m vocab_dict[word_delimiter_token] \u001b[38;5;241m=\u001b[39m vocab_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m vocab_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
973
+ "\u001b[0;31mNameError\u001b[0m: name 'word_delimiter_token' is not defined"
974
+ ]
975
+ }
976
+ ],
977
+ "source": [
978
+ "vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}\n",
979
+ "\n",
980
+ "# replace white space with delimiter token\n",
981
+ "if word_delimiter_token is not None:\n",
982
+ " vocab_dict[word_delimiter_token] = vocab_dict[\" \"]\n",
983
+ " del vocab_dict[\" \"]\n",
984
+ "\n",
985
+ "# add unk and pad token\n",
986
+ "if unk_token is not None:\n",
987
+ " vocab_dict[unk_token] = len(vocab_dict)\n",
988
+ "\n",
989
+ "if pad_token is not None:\n",
990
+ " vocab_dict[pad_token] = len(vocab_dict)"
991
+ ]
992
+ },
993
+ {
994
+ "cell_type": "code",
995
+ "execution_count": 58,
996
+ "metadata": {},
997
+ "outputs": [
998
+ {
999
+ "data": {
1000
+ "application/vnd.jupyter.widget-view+json": {
1001
+ "model_id": "59e89471ea85449ebbc709d0a9d7325c",
1002
+ "version_major": 2,
1003
+ "version_minor": 0
1004
+ },
1005
+ "text/plain": [
1006
+ " 0%| | 0/437 [00:00<?, ?ba/s]"
1007
+ ]
1008
+ },
1009
+ "metadata": {},
1010
+ "output_type": "display_data"
1011
+ },
1012
+ {
1013
+ "name": "stdout",
1014
+ "output_type": "stream",
1015
+ "text": [
1016
+ "OOV found in 421223 samples, and they were removed from training set\n",
1017
+ "The final training set size is 14947\n"
1018
+ ]
1019
+ }
1020
+ ],
1021
+ "source": [
1022
+ "vocab_set = set(filtered_vocab_list)\n",
1023
+ "train_dataset_size = len(common_voice_train)\n",
1024
+ "common_voice_train_2 = common_voice_train.filter(\n",
1025
+ " lambda example: vocab_set.issuperset(example[\"sentence\"].replace(\" \", \"\"))\n",
1026
+ ")\n",
1027
+ "print(f\"OOV found in {train_dataset_size - len(common_voice_train_2)} samples, and they were removed from training set\")\n",
1028
+ "print(f\"The final training set size is {len(common_voice_train_2)}\")"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "cell_type": "code",
1033
+ "execution_count": 38,
1034
+ "metadata": {
1035
+ "collapsed": true,
1036
+ "jupyter": {
1037
+ "outputs_hidden": true
1038
+ }
1039
+ },
1040
+ "outputs": [
1041
+ {
1042
+ "ename": "KeyboardInterrupt",
1043
+ "evalue": "",
1044
+ "output_type": "error",
1045
+ "traceback": [
1046
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1047
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1048
+ "Input \u001b[0;32mIn [38]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m odd_example_texts \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m common_voice_train:\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m odd_values:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m row[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msentence\u001b[39m\u001b[38;5;124m\"\u001b[39m]: \n",
1049
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/arrow_dataset.py:1664\u001b[0m, in \u001b[0;36mDataset._iter\u001b[0;34m(self, decoded)\u001b[0m\n\u001b[1;32m 1658\u001b[0m \u001b[38;5;124;03m\"\"\"Iterate through the examples.\u001b[39;00m\n\u001b[1;32m 1659\u001b[0m \n\u001b[1;32m 1660\u001b[0m \u001b[38;5;124;03mIf a formatting is set with :meth:`Dataset.set_format` rows will be returned with the\u001b[39;00m\n\u001b[1;32m 1661\u001b[0m \u001b[38;5;124;03mselected format.\u001b[39;00m\n\u001b[1;32m 1662\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1663\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m index \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_rows):\n\u001b[0;32m-> 1664\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1665\u001b[0m \u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1666\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoded\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoded\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1667\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
1050
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/arrow_dataset.py:1915\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, decoded, **kwargs)\u001b[0m\n\u001b[1;32m 1913\u001b[0m formatter \u001b[38;5;241m=\u001b[39m get_formatter(format_type, features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures, decoded\u001b[38;5;241m=\u001b[39mdecoded, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mformat_kwargs)\n\u001b[1;32m 1914\u001b[0m pa_subtable \u001b[38;5;241m=\u001b[39m query_table(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data, key, indices\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_indices \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m-> 1915\u001b[0m formatted_output \u001b[38;5;241m=\u001b[39m \u001b[43mformat_table\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1916\u001b[0m \u001b[43m \u001b[49m\u001b[43mpa_subtable\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mformatter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformatter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mformat_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformat_columns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_all_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_all_columns\u001b[49m\n\u001b[1;32m 1917\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1918\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m formatted_output\n",
1051
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:533\u001b[0m, in \u001b[0;36mformat_table\u001b[0;34m(table, key, formatter, format_columns, output_all_columns)\u001b[0m\n\u001b[1;32m 531\u001b[0m python_formatter \u001b[38;5;241m=\u001b[39m PythonFormatter(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m format_columns \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 533\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mformatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery_type\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 534\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m query_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumn\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 535\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m format_columns:\n",
1052
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:282\u001b[0m, in \u001b[0;36mFormatter.__call__\u001b[0;34m(self, pa_table, query_type)\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, pa_table: pa\u001b[38;5;241m.\u001b[39mTable, query_type: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[RowFormat, ColumnFormat, BatchFormat]:\n\u001b[1;32m 281\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m query_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrow\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mformat_row\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 283\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m query_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumn\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 284\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mformat_column(pa_table)\n",
1053
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:313\u001b[0m, in \u001b[0;36mPythonFormatter.format_row\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m 311\u001b[0m row \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpython_arrow_extractor()\u001b[38;5;241m.\u001b[39mextract_row(pa_table)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoded:\n\u001b[0;32m--> 313\u001b[0m row \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpython_features_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode_row\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m row\n",
1054
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/formatting/formatting.py:222\u001b[0m, in \u001b[0;36mPythonFeaturesDecoder.decode_row\u001b[0;34m(self, row)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode_row\u001b[39m(\u001b[38;5;28mself\u001b[39m, row: \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m:\n\u001b[0;32m--> 222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures \u001b[38;5;28;01melse\u001b[39;00m row\n",
1055
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/features.py:1318\u001b[0m, in \u001b[0;36mFeatures.decode_example\u001b[0;34m(self, example)\u001b[0m\n\u001b[1;32m 1308\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode_example\u001b[39m(\u001b[38;5;28mself\u001b[39m, example: \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 1309\u001b[0m \u001b[38;5;124;03m\"\"\"Decode example with custom feature decoding.\u001b[39;00m\n\u001b[1;32m 1310\u001b[0m \n\u001b[1;32m 1311\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1315\u001b[0m \u001b[38;5;124;03m :obj:`dict[str, Any]`\u001b[39;00m\n\u001b[1;32m 1316\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1318\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m 1319\u001b[0m column_name: decode_nested_example(feature, value)\n\u001b[1;32m 1320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_column_requires_decoding[column_name]\n\u001b[1;32m 1321\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m value\n\u001b[1;32m 1322\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m column_name, (feature, value) \u001b[38;5;129;01min\u001b[39;00m utils\u001b[38;5;241m.\u001b[39mzip_dict(\n\u001b[1;32m 1323\u001b[0m {key: value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m example}, example\n\u001b[1;32m 1324\u001b[0m )\n\u001b[1;32m 1325\u001b[0m }\n",
1056
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/features.py:1319\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 1308\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode_example\u001b[39m(\u001b[38;5;28mself\u001b[39m, example: \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 1309\u001b[0m \u001b[38;5;124;03m\"\"\"Decode example with custom feature decoding.\u001b[39;00m\n\u001b[1;32m 1310\u001b[0m \n\u001b[1;32m 1311\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1315\u001b[0m \u001b[38;5;124;03m :obj:`dict[str, Any]`\u001b[39;00m\n\u001b[1;32m 1316\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m 1318\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[0;32m-> 1319\u001b[0m column_name: \u001b[43mdecode_nested_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_column_requires_decoding[column_name]\n\u001b[1;32m 1321\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m value\n\u001b[1;32m 1322\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m column_name, (feature, value) \u001b[38;5;129;01min\u001b[39;00m utils\u001b[38;5;241m.\u001b[39mzip_dict(\n\u001b[1;32m 1323\u001b[0m {key: value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m example}, example\n\u001b[1;32m 1324\u001b[0m )\n\u001b[1;32m 1325\u001b[0m }\n",
1057
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/features.py:1056\u001b[0m, in \u001b[0;36mdecode_nested_example\u001b[0;34m(schema, obj)\u001b[0m\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;66;03m# Object with special decoding:\u001b[39;00m\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(schema, (Audio, Image)):\n\u001b[0;32m-> 1056\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mschema\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m obj \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj\n",
1058
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/audio.py:97\u001b[0m, in \u001b[0;36mAudio.decode_example\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAn audio sample should have one of \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpath\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbytes\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m but both are None in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvalue\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m path \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m path\u001b[38;5;241m.\u001b[39mendswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmp3\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m---> 97\u001b[0m array, sampling_rate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_decode_mp3\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfile\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 98\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file:\n",
1059
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/datasets/features/audio.py:183\u001b[0m, in \u001b[0;36mAudio._decode_mp3\u001b[0;34m(self, path_or_file)\u001b[0m\n\u001b[1;32m 181\u001b[0m array \u001b[38;5;241m=\u001b[39m array\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmono:\n\u001b[0;32m--> 183\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43marray\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m array, sampling_rate\n",
1060
+ "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/numpy/core/_methods.py:154\u001b[0m, in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;66;03m# Cast bool, unsigned int, and int to float64 by default\u001b[39;00m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(arr\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mtype, (\u001b[43mnt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minteger\u001b[49m, nt\u001b[38;5;241m.\u001b[39mbool_)):\n\u001b[1;32m 155\u001b[0m dtype \u001b[38;5;241m=\u001b[39m mu\u001b[38;5;241m.\u001b[39mdtype(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf8\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(arr\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mtype, nt\u001b[38;5;241m.\u001b[39mfloat16):\n",
1061
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1062
+ ]
1063
+ }
1064
+ ],
1065
+ "source": []
1066
+ },
1067
+ {
1068
+ "cell_type": "code",
1069
+ "execution_count": null,
1070
+ "metadata": {},
1071
+ "outputs": [],
1072
+ "source": []
1073
+ },
1074
+ {
1075
+ "cell_type": "code",
1076
+ "execution_count": null,
1077
+ "metadata": {},
1078
+ "outputs": [],
1079
+ "source": []
1080
+ },
1081
+ {
1082
+ "cell_type": "code",
1083
+ "execution_count": 43,
1084
+ "metadata": {},
1085
+ "outputs": [
1086
+ {
1087
+ "data": {
1088
+ "text/plain": [
1089
+ "0"
1090
+ ]
1091
+ },
1092
+ "execution_count": 43,
1093
+ "metadata": {},
1094
+ "output_type": "execute_result"
1095
+ }
1096
+ ],
1097
+ "source": []
1098
+ },
1099
+ {
1100
+ "cell_type": "code",
1101
+ "execution_count": 22,
1102
+ "metadata": {},
1103
+ "outputs": [
1104
+ {
1105
+ "name": "stdout",
1106
+ "output_type": "stream",
1107
+ "text": [
1108
+ "0 \n",
1109
+ "1 &\n",
1110
+ "2 '\n",
1111
+ "3 .\n",
1112
+ "4 /\n",
1113
+ "5 A\n",
1114
+ "6 B\n",
1115
+ "7 C\n",
1116
+ "8 D\n",
1117
+ "9 E\n",
1118
+ "10 F\n",
1119
+ "11 G\n",
1120
+ "12 H\n",
1121
+ "13 I\n",
1122
+ "14 J\n",
1123
+ "15 K\n",
1124
+ "16 L\n",
1125
+ "17 M\n",
1126
+ "18 N\n",
1127
+ "19 O\n",
1128
+ "20 P\n",
1129
+ "21 Q\n",
1130
+ "22 R\n",
1131
+ "23 S\n",
1132
+ "24 T\n",
1133
+ "25 U\n",
1134
+ "26 V\n",
1135
+ "27 W\n",
1136
+ "28 X\n",
1137
+ "29 Y\n",
1138
+ "30 Z\n",
1139
+ "31 a\n",
1140
+ "32 b\n",
1141
+ "33 c\n",
1142
+ "34 d\n",
1143
+ "35 e\n",
1144
+ "36 f\n",
1145
+ "37 g\n",
1146
+ "38 h\n",
1147
+ "39 i\n",
1148
+ "40 j\n",
1149
+ "41 k\n",
1150
+ "42 l\n",
1151
+ "43 m\n",
1152
+ "44 n\n",
1153
+ "45 o\n",
1154
+ "46 p\n",
1155
+ "47 q\n",
1156
+ "48 r\n",
1157
+ "49 s\n",
1158
+ "50 t\n",
1159
+ "51 u\n",
1160
+ "52 v\n",
1161
+ "53 w\n",
1162
+ "54 x\n",
1163
+ "55 y\n",
1164
+ "56 z\n",
1165
+ "57 ―\n",
1166
+ "58 、\n",
1167
+ "59 。\n",
1168
+ "60 々\n",
1169
+ "61 〇\n",
1170
+ "62 「\n",
1171
+ "63 」\n",
1172
+ "64 『\n",
1173
+ "65 』\n",
1174
+ "66 〜\n",
1175
+ "67 ぁ\n",
1176
+ "68 あ\n",
1177
+ "69 い\n",
1178
+ "70 う\n",
1179
+ "71 ぇ\n",
1180
+ "72 え\n",
1181
+ "73 お\n",
1182
+ "74 か\n",
1183
+ "75 が\n",
1184
+ "76 き\n",
1185
+ "77 ぎ\n",
1186
+ "78 く\n",
1187
+ "79 ぐ\n",
1188
+ "80 け\n",
1189
+ "81 げ\n",
1190
+ "82 こ\n",
1191
+ "83 ご\n",
1192
+ "84 さ\n",
1193
+ "85 ざ\n",
1194
+ "86 し\n",
1195
+ "87 じ\n",
1196
+ "88 す\n",
1197
+ "89 ず\n",
1198
+ "90 せ\n",
1199
+ "91 ぜ\n",
1200
+ "92 そ\n",
1201
+ "93 ぞ\n",
1202
+ "94 た\n",
1203
+ "95 だ\n",
1204
+ "96 ち\n",
1205
+ "97 ぢ\n",
1206
+ "98 っ\n",
1207
+ "99 つ\n",
1208
+ "100 づ\n",
1209
+ "101 て\n",
1210
+ "102 で\n",
1211
+ "103 と\n",
1212
+ "104 ど\n",
1213
+ "105 な\n",
1214
+ "106 に\n",
1215
+ "107 ぬ\n",
1216
+ "108 ね\n",
1217
+ "109 の\n",
1218
+ "110 は\n",
1219
+ "111 ば\n",
1220
+ "112 ぱ\n",
1221
+ "113 ひ\n",
1222
+ "114 び\n",
1223
+ "115 ぴ\n",
1224
+ "116 ふ\n",
1225
+ "117 ぶ\n",
1226
+ "118 ぷ\n",
1227
+ "119 へ\n",
1228
+ "120 べ\n",
1229
+ "121 ぺ\n",
1230
+ "122 ほ\n",
1231
+ "123 ぼ\n",
1232
+ "124 ぽ\n",
1233
+ "125 ま\n",
1234
+ "126 み\n",
1235
+ "127 む\n",
1236
+ "128 め\n",
1237
+ "129 も\n",
1238
+ "130 ゃ\n",
1239
+ "131 や\n",
1240
+ "132 ゅ\n",
1241
+ "133 ゆ\n",
1242
+ "134 ょ\n",
1243
+ "135 よ\n",
1244
+ "136 ら\n",
1245
+ "137 り\n",
1246
+ "138 る\n",
1247
+ "139 れ\n",
1248
+ "140 ろ\n",
1249
+ "141 わ\n",
1250
+ "142 を\n",
1251
+ "143 ん\n",
1252
+ "144 ァ\n",
1253
+ "145 ア\n",
1254
+ "146 ィ\n",
1255
+ "147 イ\n",
1256
+ "148 ゥ\n",
1257
+ "149 ウ\n",
1258
+ "150 ェ\n",
1259
+ "151 エ\n",
1260
+ "152 ォ\n",
1261
+ "153 オ\n",
1262
+ "154 カ\n",
1263
+ "155 ガ\n",
1264
+ "156 キ\n",
1265
+ "157 ギ\n",
1266
+ "158 ク\n",
1267
+ "159 グ\n",
1268
+ "160 ケ\n",
1269
+ "161 ゲ\n",
1270
+ "162 コ\n",
1271
+ "163 ゴ\n",
1272
+ "164 サ\n",
1273
+ "165 ザ\n",
1274
+ "166 シ\n",
1275
+ "167 ジ\n",
1276
+ "168 ス\n",
1277
+ "169 ズ\n",
1278
+ "170 セ\n",
1279
+ "171 ゼ\n",
1280
+ "172 ソ\n",
1281
+ "173 ゾ\n",
1282
+ "174 タ\n",
1283
+ "175 ダ\n",
1284
+ "176 チ\n",
1285
+ "177 ッ\n",
1286
+ "178 ツ\n",
1287
+ "179 ヅ\n",
1288
+ "180 テ\n",
1289
+ "181 デ\n",
1290
+ "182 ト\n",
1291
+ "183 ド\n",
1292
+ "184 ナ\n",
1293
+ "185 ニ\n",
1294
+ "186 ヌ\n",
1295
+ "187 ネ\n",
1296
+ "188 ノ\n",
1297
+ "189 ハ\n",
1298
+ "190 バ\n",
1299
+ "191 パ\n",
1300
+ "192 ヒ\n",
1301
+ "193 ビ\n",
1302
+ "194 ピ\n",
1303
+ "195 フ\n",
1304
+ "196 ブ\n",
1305
+ "197 プ\n",
1306
+ "198 ヘ\n",
1307
+ "199 ベ\n",
1308
+ "200 ペ\n",
1309
+ "201 ホ\n",
1310
+ "202 ボ\n",
1311
+ "203 ポ\n",
1312
+ "204 マ\n",
1313
+ "205 ミ\n",
1314
+ "206 ム\n",
1315
+ "207 メ\n",
1316
+ "208 モ\n",
1317
+ "209 ャ\n",
1318
+ "210 ヤ\n",
1319
+ "211 ュ\n",
1320
+ "212 ユ\n",
1321
+ "213 ョ\n",
1322
+ "214 ヨ\n",
1323
+ "215 ラ\n",
1324
+ "216 リ\n",
1325
+ "217 ル\n",
1326
+ "218 レ\n",
1327
+ "219 ロ\n",
1328
+ "220 ワ\n",
1329
+ "221 ン\n",
1330
+ "222 ヴ\n",
1331
+ "223 ヶ\n",
1332
+ "224 ・\n",
1333
+ "225 ー\n",
1334
+ "226 繫\n",
1335
+ "227 !\n",
1336
+ "228 &\n",
1337
+ "229 )\n",
1338
+ "230 -\n",
1339
+ "231 .\n",
1340
+ "232 :\n",
1341
+ "233 =\n",
1342
+ "234 ?\n",
1343
+ "235 A\n",
1344
+ "236 D\n",
1345
+ "237 F\n",
1346
+ "238 G\n",
1347
+ "239 N\n",
1348
+ "240 O\n",
1349
+ "241 P\n",
1350
+ "242 S\n",
1351
+ "243 U\n",
1352
+ "244 h\n",
1353
+ "245 j\n",
1354
+ "246 「\n",
1355
+ "247 」\n",
1356
+ "248 ・\n"
1357
+ ]
1358
+ }
1359
+ ],
1360
+ "source": [
1361
+ "vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}\n",
1362
+ "for key, value in enumerate(vocab_dict):\n",
1363
+ " print(key, value)"
1364
+ ]
1365
+ },
1366
+ {
1367
+ "cell_type": "code",
1368
+ "execution_count": null,
1369
+ "metadata": {},
1370
+ "outputs": [],
1371
+ "source": [
1372
+ "def create_vocabulary_from_data(\n",
1373
+ " datasets: DatasetDict,\n",
1374
+ " word_delimiter_token: Optional[str] = None,\n",
1375
+ " unk_token: Optional[str] = None,\n",
1376
+ " pad_token: Optional[str] = None,\n",
1377
+ "):\n",
1378
+ " # Given training and test labels create vocabulary\n",
1379
+ " def extract_all_chars(batch):\n",
1380
+ " all_text = \" \".join(batch[\"target_text\"])\n",
1381
+ " vocab = list(set(all_text))\n",
1382
+ " return {\"vocab\": [vocab], \"all_text\": [all_text]}\n",
1383
+ "\n",
1384
+ " vocabs = datasets.map(\n",
1385
+ " extract_all_chars,\n",
1386
+ " batched=True,\n",
1387
+ " batch_size=-1,\n",
1388
+ " keep_in_memory=True,\n",
1389
+ " remove_columns=datasets[\"train\"].column_names,\n",
1390
+ " )\n",
1391
+ "\n",
1392
+ " # take union of all unique characters in each dataset\n",
1393
+ " vocab_set = functools.reduce(\n",
1394
+ " lambda vocab_1, vocab_2: set(vocab_1[\"vocab\"][0]) | set(vocab_2[\"vocab\"][0]), vocabs.values()\n",
1395
+ " )\n",
1396
+ "\n",
1397
+ " vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}\n",
1398
+ "\n",
1399
+ " # replace white space with delimiter token\n",
1400
+ " if word_delimiter_token is not None:\n",
1401
+ " vocab_dict[word_delimiter_token] = vocab_dict[\" \"]\n",
1402
+ " del vocab_dict[\" \"]\n",
1403
+ "\n",
1404
+ " # add unk and pad token\n",
1405
+ " if unk_token is not None:\n",
1406
+ " vocab_dict[unk_token] = len(vocab_dict)\n",
1407
+ "\n",
1408
+ " if pad_token is not None:\n",
1409
+ " vocab_dict[pad_token] = len(vocab_dict)\n",
1410
+ "\n",
1411
+ " return vocab_dict"
1412
+ ]
1413
+ },
1414
+ {
1415
+ "cell_type": "code",
1416
+ "execution_count": null,
1417
+ "metadata": {},
1418
+ "outputs": [],
1419
+ "source": []
1420
+ },
1421
+ {
1422
+ "cell_type": "code",
1423
+ "execution_count": null,
1424
+ "metadata": {},
1425
+ "outputs": [],
1426
+ "source": []
1427
+ },
1428
+ {
1429
+ "cell_type": "code",
1430
+ "execution_count": null,
1431
+ "metadata": {},
1432
+ "outputs": [],
1433
+ "source": []
1434
+ },
1435
+ {
1436
+ "cell_type": "code",
1437
+ "execution_count": null,
1438
+ "metadata": {},
1439
+ "outputs": [],
1440
+ "source": [
1441
+ "# load processor\n",
1442
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)\n",
1443
+ "# feature_extractor = processor_with_lm.feature_extractor\n",
1444
+ "sampling_rate = feature_extractor.sampling_rate\n",
1445
+ "\n",
1446
+ "# resample audio\n",
1447
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=sampling_rate))\n",
1448
+ "\n",
1449
+ "# load eval pipeline\n",
1450
+ "asr = pipeline(\"automatic-speech-recognition\", model=repo_name, feature_extractor=feature_extractor)\n",
1451
+ "\n",
1452
+ "# map function to decode audio\n",
1453
+ "def map_to_pred(batch):\n",
1454
+ " prediction = asr(\n",
1455
+ " batch[\"audio\"][\"array\"])\n",
1456
+ "\n",
1457
+ " batch[\"prediction\"] = prediction[\"text\"]\n",
1458
+ " batch[\"target\"] = batch[\"sentence\"]\n",
1459
+ " return batch\n",
1460
+ "\n",
1461
+ "# run inference on all examples\n",
1462
+ "result = dataset.map(map_to_pred, remove_columns=dataset.column_names)\n",
1463
+ "print(result[\"prediction\"])\n",
1464
+ "\n",
1465
+ "result[0]['target']"
1466
+ ]
1467
+ }
1468
+ ],
1469
+ "metadata": {
1470
+ "kernelspec": {
1471
+ "display_name": "Python 3 (ipykernel)",
1472
+ "language": "python",
1473
+ "name": "python3"
1474
+ },
1475
+ "language_info": {
1476
+ "codemirror_mode": {
1477
+ "name": "ipython",
1478
+ "version": 3
1479
+ },
1480
+ "file_extension": ".py",
1481
+ "mimetype": "text/x-python",
1482
+ "name": "python",
1483
+ "nbconvert_exporter": "python",
1484
+ "pygments_lexer": "ipython3",
1485
+ "version": "3.8.8"
1486
+ }
1487
+ },
1488
+ "nbformat": 4,
1489
+ "nbformat_minor": 4
1490
+ }
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": null, "tokenizer_file": null, "name_or_path": "./", "tokenizer_class": "Wav2Vec2CTCTokenizer"}
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70e27818f3fc71ffdfcc80419d1967fd61208e9dc6b1b3d61fd6629f0946734b
3
+ size 2991
vocab.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"&": 1, "'": 2, ".": 3, "/": 4, "A": 5, "B": 6, "C": 7, "D": 8, "E": 9, "F": 10, "G": 11, "H": 12, "I": 13, "J": 14, "K": 15, "L": 16, "M": 17, "N": 18, "O": 19, "P": 20, "Q": 21, "R": 22, "S": 23, "T": 24, "U": 25, "V": 26, "W": 27, "X": 28, "Y": 29, "Z": 30, "a": 31, "b": 32, "c": 33, "d": 34, "e": 35, "f": 36, "g": 37, "h": 38, "i": 39, "j": 40, "k": 41, "l": 42, "m": 43, "n": 44, "o": 45, "p": 46, "q": 47, "r": 48, "s": 49, "t": 50, "u": 51, "v": 52, "w": 53, "x": 54, "y": 55, "z": 56, "―": 57, "、": 58, "。": 59, "々": 60, "〇": 61, "「": 62, "」": 63, "『": 64, "』": 65, "〜": 66, "ぁ": 67, "あ": 68, "い": 69, "う": 70, "ぇ": 71, "え": 72, "お": 73, "か": 74, "が": 75, "き": 76, "ぎ": 77, "く": 78, "ぐ": 79, "け": 80, "げ": 81, "こ": 82, "ご": 83, "さ": 84, "ざ": 85, "し": 86, "じ": 87, "す": 88, "ず": 89, "せ": 90, "ぜ": 91, "そ": 92, "ぞ": 93, "た": 94, "だ": 95, "ち": 96, "ぢ": 97, "っ": 98, "つ": 99, "づ": 100, "て": 101, "で": 102, "と": 103, "ど": 104, "な": 105, "に": 106, "ぬ": 107, "ね": 108, "の": 109, "は": 110, "ば": 111, "ぱ": 112, "ひ": 113, "び": 114, "ぴ": 115, "ふ": 116, "ぶ": 117, "ぷ": 118, "へ": 119, "べ": 120, "ぺ": 121, "ほ": 122, "ぼ": 123, "ぽ": 124, "ま": 125, "み": 126, "む": 127, "め": 128, "も": 129, "ゃ": 130, "や": 131, "ゅ": 132, "ゆ": 133, "ょ": 134, "よ": 135, "ら": 136, "り": 137, "る": 138, "れ": 139, "ろ": 140, "わ": 141, "を": 142, "ん": 143, "ァ": 144, "ア": 145, "ィ": 146, "イ": 147, "ゥ": 148, "ウ": 149, "ェ": 150, "エ": 151, "ォ": 152, "オ": 153, "カ": 154, "ガ": 155, "キ": 156, "ギ": 157, "ク": 158, "グ": 159, "ケ": 160, "ゲ": 161, "コ": 162, "ゴ": 163, "サ": 164, "ザ": 165, "シ": 166, "ジ": 167, "ス": 168, "ズ": 169, "セ": 170, "ゼ": 171, "ソ": 172, "ゾ": 173, "タ": 174, "ダ": 175, "チ": 176, "ッ": 177, "ツ": 178, "ヅ": 179, "テ": 180, "デ": 181, "ト": 182, "ド": 183, "ナ": 184, "ニ": 185, "ヌ": 186, "ネ": 187, "ノ": 188, "ハ": 189, "バ": 190, "パ": 191, "ヒ": 192, "ビ": 193, "ピ": 194, "フ": 195, "ブ": 196, "プ": 197, "ヘ": 198, "ベ": 199, "ペ": 200, "ホ": 201, "ボ": 202, "ポ": 203, "マ": 204, "ミ": 205, "ム": 206, "メ": 207, "モ": 208, "ャ": 209, "ヤ": 210, "ュ": 211, "ユ": 212, "ョ": 213, "ヨ": 214, "ラ": 215, "リ": 216, "ル": 217, "レ": 218, "ロ": 219, "ワ": 220, "ン": 221, "ヴ": 222, "ヶ": 223, "・": 224, "ー": 225, "繫": 226, "&": 227, ")": 228, "-": 229, ".": 230, ":": 231, "=": 232, "?": 233, "A": 234, "D": 235, "F": 236, "G": 237, "N": 238, "O": 239, "P": 240, "S": 241, "U": 242, "h": 243, "j": 244, "「": 245, "」": 246, "・": 247, "|": 0, "[UNK]": 248, "[PAD]": 249}