Upload run_summarization_flax.py
Browse files- run_summarization_flax.py +19 -17
run_summarization_flax.py
CHANGED
@@ -431,23 +431,25 @@ def main():
|
|
431 |
return
|
432 |
|
433 |
# Get the column names for input/target.
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
|
|
451 |
|
452 |
# Temporarily set max_target_length for training.
|
453 |
max_target_length = data_args.max_target_length
|
|
|
431 |
return
|
432 |
|
433 |
# Get the column names for input/target.
|
434 |
+
if not data_args.pretokenized:
|
435 |
+
|
436 |
+
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
|
437 |
+
if data_args.text_column is None:
|
438 |
+
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
439 |
+
else:
|
440 |
+
text_column = data_args.text_column
|
441 |
+
if text_column not in column_names:
|
442 |
+
raise ValueError(
|
443 |
+
f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
|
444 |
+
)
|
445 |
+
if data_args.summary_column is None:
|
446 |
+
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
447 |
+
else:
|
448 |
+
summary_column = data_args.summary_column
|
449 |
+
if summary_column not in column_names:
|
450 |
+
raise ValueError(
|
451 |
+
f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
|
452 |
+
)
|
453 |
|
454 |
# Temporarily set max_target_length for training.
|
455 |
max_target_length = data_args.max_target_length
|