pere committed on
Commit
d234eb9
1 Parent(s): bdb2ded
Files changed (2) hide show
  1. run_test.sh +1 -1
  2. run_whisper_finetuning.py +30 -30
run_test.sh CHANGED
@@ -4,7 +4,7 @@
4
 
5
  python run_whisper_finetuning.py \
6
  --model_name_or_path="openai/whisper-small" \
7
- --output_dir="../whisper-test-delete2" \
8
  --overwrite_output_dir=True \
9
  --language="Norwegian" \
10
  --task="transcribe" \
 
4
 
5
  python run_whisper_finetuning.py \
6
  --model_name_or_path="openai/whisper-small" \
7
+ --output_dir="../whisper-test-delete3" \
8
  --overwrite_output_dir=True \
9
  --language="Norwegian" \
10
  --task="transcribe" \
run_whisper_finetuning.py CHANGED
@@ -92,7 +92,7 @@ class Seq2SeqTrainingArguments(TrainingArguments):
92
  )
93
  },
94
  )
95
- xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
96
 
97
  @dataclass
98
  class ModelArguments:
@@ -340,10 +340,6 @@ def main():
340
  parser = HfArgumentParser(
341
  (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
342
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
343
-
344
- #Debug
345
- import torch_xla.debug.metrics as met
346
- print(met.metrics_report())
347
 
348
  # Metrics
349
 
@@ -390,14 +386,14 @@ def main():
390
  feats[new_name] = feats.pop(old_name)
391
  ds.info.features = feats
392
  return ds
393
-
394
  def remove_columns(ds, column_name):
395
- feats = ds.info.features
396
- ds = ds.remove_columns(column_name)
397
- feats.pop(column_name)
398
- ds.info.features = feats
399
- return ds
400
-
401
  # Print training arguments
402
  if data_args.print_training_arguments:
403
  print_training_arguments(model_args, data_args, training_args)
@@ -409,12 +405,12 @@ def main():
409
 
410
  # Rename columns
411
  if data_args.audio_column_name != "audio":
412
- train_dataset = rename_column(train_dataset,data_args.audio_column_name, "audio")
413
- eval_dataset = rename_column(eval_dataset,data_args.audio_column_name, "audio")
414
 
415
  if data_args.text_column_name != "sentence":
416
- train_dataset = rename_column(train_dataset,data_args.text_column_name, "sentence")
417
- eval_dataset = rename_column(eval_dataset,data_args.text_column_name, "sentence")
418
 
419
 
420
  # Initialise
@@ -429,23 +425,27 @@ def main():
429
  # Saving the processor and the tokenizer
430
  processor.save_pretrained(training_args.output_dir)
431
  tokenizer.save_pretrained(training_args.output_dir)
 
 
 
 
 
 
 
432
 
433
 
434
- # Prepare data
435
- train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
436
- eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=16000))
 
 
 
 
437
 
438
-
439
- # Remove non needed columns
440
- column_names=[x for x in train_dataset.info.features]
441
 
442
- for c in column_names:
443
- if c not in ["audio", "sentence"]:
444
- print(f"removing {c}")
445
- train_dataset = remove_columns(train_dataset, c)
446
- eval_dataset = remove_columns(eval_dataset, c)
447
-
448
- # Prepare dataset
449
  train_dataset = train_dataset.map(prepare_dataset)
450
  eval_dataset = eval_dataset.map(prepare_dataset)
451
 
@@ -502,7 +502,7 @@ def main():
502
  # Instantaneous batch size per device = 48
503
 
504
 
505
- # TODO Add option for constant learning rate
506
  trainer = Seq2SeqTrainer(
507
  args=training_args,
508
  model=model,
 
92
  )
93
  },
94
  )
95
+
96
 
97
  @dataclass
98
  class ModelArguments:
 
340
  parser = HfArgumentParser(
341
  (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
342
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
 
 
 
343
 
344
  # Metrics
345
 
 
386
  feats[new_name] = feats.pop(old_name)
387
  ds.info.features = feats
388
  return ds
389
+
390
  def remove_columns(ds, column_name):
391
+ feats = ds.info.features
392
+ ds = ds.remove_columns(column_name)
393
+ feats.pop(column_name)
394
+ ds.info.features = feats
395
+ return ds
396
+
397
  # Print training arguments
398
  if data_args.print_training_arguments:
399
  print_training_arguments(model_args, data_args, training_args)
 
405
 
406
  # Rename columns
407
  if data_args.audio_column_name != "audio":
408
+ train_dataset = train_dataset.rename_column(data_args.audio_column_name, "audio")
409
+ eval_dataset = eval_dataset.rename_column(data_args.audio_column_name, "audio")
410
 
411
  if data_args.text_column_name != "sentence":
412
+ train_dataset = train_dataset.rename_column(data_args.text_column_name, "sentence")
413
+ eval_dataset = eval_dataset.rename_column(data_args.text_column_name, "sentence")
414
 
415
 
416
  # Initialise
 
425
  # Saving the processor and the tokenizer
426
  processor.save_pretrained(training_args.output_dir)
427
  tokenizer.save_pretrained(training_args.output_dir)
428
+
429
+ # Prepare data
430
+ # TODO The casting of the audio is not working on the NPSC in 48K. It seems to be working for Common Voice
431
+ # The issue is that the dataset features returns None. But to me they seem to have been set correctly
432
+ # In our case this is not needed, since the dataset is already available as 16K. But it would be great to solve this bug
433
+ # train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
434
+ # eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=16000))
435
 
436
 
437
+ # Remove non needed columns
438
+ #column_names=[x for x in train_dataset.info.features]
439
+
440
+ #for c in column_names:
441
+ # if c not in ["audio", "text"]:
442
+ # train_dataset = remove_columns(train_dataset, c)
443
+ # eval_dataset = remove_columns(eval_dataset, c)
444
 
445
+ # TODO I would really like to remove the non needed columns here. At least this cleans up the output.
446
+ # I am unable to figure out how to do this in streaming mode. Cannot find a way to list columns.
447
+ # train_data = train_data.map(prepare_dataset, remove_columns=train_data.column_names, num_proc=1)
448
 
 
 
 
 
 
 
 
449
  train_dataset = train_dataset.map(prepare_dataset)
450
  eval_dataset = eval_dataset.map(prepare_dataset)
451
 
 
502
  # Instantaneous batch size per device = 48
503
 
504
 
505
+
506
  trainer = Seq2SeqTrainer(
507
  args=training_args,
508
  model=model,