Commit fcff61b by Rolv-Arild
1 Parent(s): eec3f65

Add bandaid for empty strings

Files changed (1)
  1. run_speech_recognition_seq2seq.py +12 -9
run_speech_recognition_seq2seq.py CHANGED
@@ -46,7 +46,6 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
-
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.17.0.dev0")
 
@@ -89,7 +88,7 @@ class ModelArguments:
         default=False,
         metadata={
             "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
-            "with private models)."
+                    "with private models)."
         },
     )
     freeze_feature_encoder: bool = field(
@@ -124,14 +123,14 @@ class DataTrainingArguments:
         default=None,
         metadata={
             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
+                    "value if set."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-            "value if set."
+                    "value if set."
         },
     )
     audio_column_name: str = field(
@@ -155,9 +154,9 @@ class DataTrainingArguments:
         default=False,
         metadata={
             "help": "Whether to only do data preprocessing and skip training. "
-            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
-            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
-            "so that the cached datasets can consequently be loaded in distributed training"
+                    "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
+                    "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
+                    "so that the cached datasets can consequently be loaded in distributed training"
         },
     )
     train_split_name: str = field(
@@ -283,12 +282,14 @@ def main():
 
     if training_args.do_train:
         raw_datasets["train"] = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name, cache_dir=data_args.data_cache_dir
+            data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name,
+            cache_dir=data_args.data_cache_dir
         )
 
     if training_args.do_eval:
         raw_datasets["eval"] = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name, cache_dir=data_args.data_cache_dir
+            data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name,
+            cache_dir=data_args.data_cache_dir
         )
 
     if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
@@ -378,6 +379,8 @@ def main():
         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
 
         input_str = re.sub(r"<\*?(ee|qq|mm|inaudible)>", "", input_str, re.IGNORECASE)
+        if len(input_str) == 0:
+            input_str = "."  # bandaid
 
         batch["labels"] = tokenizer(input_str).input_ids
         return batch
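
For context, a minimal standalone sketch of the situation the bandaid handles. The tag pattern and the "." placeholder come from the hunk above; the clean_transcript helper name and the sample transcripts are hypothetical, and the sketch passes the regex flags by keyword rather than positionally.

import re

# Non-speech markers stripped before tokenization, e.g. <ee>, <qq>, <mm>, <*inaudible>.
TAG_PATTERN = r"<\*?(ee|qq|mm|inaudible)>"

def clean_transcript(text: str) -> str:
    # flags is passed by keyword here; the fourth positional argument of
    # re.sub is count, not flags.
    cleaned = re.sub(TAG_PATTERN, "", text, flags=re.IGNORECASE)
    if len(cleaned) == 0:
        # An utterance consisting only of markers becomes an empty string,
        # leaving the tokenizer no text to build labels from, so a
        # placeholder character is substituted instead.
        cleaned = "."  # bandaid, as in the commit
    return cleaned

print(clean_transcript("<inaudible>"))       # "."
print(clean_transcript("hello <ee> world"))  # "hello  world"

Substituting a placeholder rather than dropping such examples keeps the dataset size unchanged, which is presumably why the commit message calls it a bandaid rather than a fix.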