nbroad HF staff commited on
Commit
ab87e85
1 Parent(s): 1c58869

Upload run_summarization_flax.py

Browse files
Files changed (1) hide show
  1. run_summarization_flax.py +19 -17
run_summarization_flax.py CHANGED
@@ -431,23 +431,25 @@ def main():
431
  return
432
 
433
  # Get the column names for input/target.
434
- dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
435
- if data_args.text_column is None:
436
- text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
437
- else:
438
- text_column = data_args.text_column
439
- if text_column not in column_names:
440
- raise ValueError(
441
- f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
442
- )
443
- if data_args.summary_column is None:
444
- summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
445
- else:
446
- summary_column = data_args.summary_column
447
- if summary_column not in column_names:
448
- raise ValueError(
449
- f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
450
- )
 
 
451
 
452
  # Temporarily set max_target_length for training.
453
  max_target_length = data_args.max_target_length
 
431
  return
432
 
433
  # Get the column names for input/target.
434
+ if not data_args.pretokenized:
435
+
436
+ dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
437
+ if data_args.text_column is None:
438
+ text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
439
+ else:
440
+ text_column = data_args.text_column
441
+ if text_column not in column_names:
442
+ raise ValueError(
443
+ f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
444
+ )
445
+ if data_args.summary_column is None:
446
+ summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
447
+ else:
448
+ summary_column = data_args.summary_column
449
+ if summary_column not in column_names:
450
+ raise ValueError(
451
+ f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
452
+ )
453
 
454
  # Temporarily set max_target_length for training.
455
  max_target_length = data_args.max_target_length