boris committed on
Commit
93c5ac8
1 Parent(s): bc78bfd

feat: remove hardcoded values

Browse files

Former-commit-id: aa6ade578bf4f7e28f9e529eda4efc2a34b86c89

Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +2 -6
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -376,9 +376,6 @@ def main():
376
  else:
377
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
378
 
379
- logger.warning(f"WARNING: eval_steps has been manually hardcoded") # TODO: remove it later, convenient for now
380
- training_args.eval_steps = 400
381
-
382
  if (
383
  os.path.exists(training_args.output_dir)
384
  and os.listdir(training_args.output_dir)
@@ -425,11 +422,10 @@ def main():
425
  # (the dataset will be downloaded automatically from the datasets Hub).
426
  #
427
  data_files = {}
428
- logger.warning(f"WARNING: Datasets path have been manually hardcoded") # TODO: remove it later, convenient for now
429
  if data_args.train_file is not None:
430
- data_files["train"] = ["/data/CC3M/training-encoded.tsv", "/data/CC12M/encoded-train.tsv", "/data/YFCC/metadata_encoded.tsv"]
431
  if data_args.validation_file is not None:
432
- data_files["validation"] = ["/data/CC3M/validation-encoded.tsv"]
433
  if data_args.test_file is not None:
434
  data_files["test"] = data_args.test_file
435
  dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
 
376
  else:
377
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
378
 
 
 
 
379
  if (
380
  os.path.exists(training_args.output_dir)
381
  and os.listdir(training_args.output_dir)
 
422
  # (the dataset will be downloaded automatically from the datasets Hub).
423
  #
424
  data_files = {}
 
425
  if data_args.train_file is not None:
426
+ data_files["train"] = data_args.train_file
427
  if data_args.validation_file is not None:
428
+ data_files["validation"] = data_args.validation_file
429
  if data_args.test_file is not None:
430
  data_files["test"] = data_args.test_file
431
  dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")