marinone94 commited on
Commit
27673c0
1 Parent(s): 5920347

fix casting single dataset

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -363,6 +363,41 @@ def notify_me(recipient, message=None):
363
  smtp_obj.quit()
364
 
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  def load_maybe_streaming_dataset(
367
  dataset_names,
368
  dataset_config_names,
@@ -393,34 +428,16 @@ def load_maybe_streaming_dataset(
393
  dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
394
  else:
395
  dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
396
- raw_datasets_features = list(dataset.features.keys())
397
- logger.info(f"Dataset {dataset_name} - Features: {raw_datasets_features}")
398
- if text_col_name_ref not in raw_datasets_features:
399
- if len(text_column_names) == 1:
400
- raise ValueError("None of the text column names provided found in dataset."
401
- f"Text columns: {text_column_names}"
402
- f"Dataset columns: {raw_datasets_features}")
403
- flag = False
404
- for text_column_name in text_column_names:
405
- if text_column_name in raw_datasets_features:
406
- logger.info(f"Renaming text column {text_column_name} to {text_col_name_ref}")
407
- dataset = dataset.rename_column(text_column_name, text_col_name_ref)
408
- flag = True
409
- break
410
- if flag is False:
411
- raise ValueError("None of the text column names provided found in dataset."
412
- f"Text columns: {text_column_names}"
413
- f"Dataset columns: {raw_datasets_features}")
414
- if audio_column_name is not None and sampling_rate is not None:
415
- ds_sr = int(dataset.features[audio_column_name].sampling_rate)
416
- if ds_sr != sampling_rate:
417
- dataset = dataset.cast_column(
418
- audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
419
- )
420
- raw_datasets_features = list(dataset.features.keys())
421
- raw_datasets_features.remove(audio_column_name)
422
- raw_datasets_features.remove(text_col_name_ref)
423
- dataset = dataset.remove_columns(column_names=raw_datasets_features)
424
  dataset_splits.append(dataset)
425
 
426
  # interleave multiple splits to form one dataset
@@ -428,7 +445,16 @@ def load_maybe_streaming_dataset(
428
  return interleaved_dataset
429
  else:
430
  # load a single split *with* streaming mode
 
431
  dataset = load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
 
 
 
 
 
 
 
 
432
  return dataset
433
 
434
 
 
363
  smtp_obj.quit()
364
 
365
 
366
+ def rename_col_and_resample(dataset, dataset_name, text_column_names, text_col_name_ref, audio_column_name, sampling_rate):
367
+ raw_datasets_features = list(dataset.features.keys())
368
+ logger.info(f"Dataset {dataset_name} - Features: {raw_datasets_features}")
369
+
370
+ if text_col_name_ref not in raw_datasets_features:
371
+ if len(text_column_names) == 1:
372
+ raise ValueError("None of the text column names provided found in dataset."
373
+ f"Text columns: {text_column_names}"
374
+ f"Dataset columns: {raw_datasets_features}")
375
+ flag = False
376
+ for text_column_name in text_column_names:
377
+ if text_column_name in raw_datasets_features:
378
+ logger.info(f"Renaming text column {text_column_name} to {text_col_name_ref}")
379
+ dataset = dataset.rename_column(text_column_name, text_col_name_ref)
380
+ flag = True
381
+ break
382
+ if flag is False:
383
+ raise ValueError("None of the text column names provided found in dataset."
384
+ f"Text columns: {text_column_names}"
385
+ f"Dataset columns: {raw_datasets_features}")
386
+ if audio_column_name is not None and sampling_rate is not None:
387
+ ds_sr = int(dataset.features[audio_column_name].sampling_rate)
388
+ if ds_sr != sampling_rate:
389
+ dataset = dataset.cast_column(
390
+ audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
391
+ )
392
+
393
+ raw_datasets_features = list(dataset.features.keys())
394
+ raw_datasets_features.remove(audio_column_name)
395
+ raw_datasets_features.remove(text_col_name_ref)
396
+ # Keep only audio and sentence
397
+ dataset = dataset.remove_columns(column_names=raw_datasets_features)
398
+ return dataset
399
+
400
+
401
  def load_maybe_streaming_dataset(
402
  dataset_names,
403
  dataset_config_names,
 
428
  dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
429
  else:
430
  dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
431
+
432
+ dataset = rename_col_and_resample(
433
+ dataset,
434
+ dataset_name,
435
+ text_column_names,
436
+ text_col_name_ref,
437
+ audio_column_name,
438
+ sampling_rate
439
+ )
440
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  dataset_splits.append(dataset)
442
 
443
  # interleave multiple splits to form one dataset
 
445
  return interleaved_dataset
446
  else:
447
  # load a single split *with* streaming mode
448
+
449
  dataset = load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
450
+ dataset = rename_col_and_resample(
451
+ dataset,
452
+ dataset_names,
453
+ text_column_names,
454
+ text_col_name_ref,
455
+ audio_column_name,
456
+ sampling_rate
457
+ )
458
  return dataset
459
 
460