winglian committed on
Commit
120e7df
1 Parent(s): 87e073d
Files changed (1) hide show
  1. scripts/finetune.py +3 -2
scripts/finetune.py CHANGED
@@ -427,9 +427,10 @@ def train(
427
  max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
428
  ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
429
  prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
 
430
  if any(prepared_ds_path.glob("*")):
431
  logging.info("Loading prepared dataset from disk...")
432
- dataset = load_from_disk(cfg.dataset_prepared_path)
433
  logging.info("Prepared dataset loaded from disk...")
434
  else:
435
  logging.info("Loading raw datasets...")
@@ -437,7 +438,7 @@ def train(
437
  for d in cfg.datasets:
438
  ds_from_hub = False
439
  try:
440
- ds = load_dataset(d.path, streaming=True)
441
  ds_from_hub = True
442
  except FileNotFoundError:
443
  pass
 
427
  max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
428
  ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
429
  prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
430
+
431
  if any(prepared_ds_path.glob("*")):
432
  logging.info("Loading prepared dataset from disk...")
433
+ dataset = load_from_disk(str(prepared_ds_path))
434
  logging.info("Prepared dataset loaded from disk...")
435
  else:
436
  logging.info("Loading raw datasets...")
 
438
  for d in cfg.datasets:
439
  ds_from_hub = False
440
  try:
441
+ load_dataset(d.path, streaming=True)
442
  ds_from_hub = True
443
  except FileNotFoundError:
444
  pass