bugfixes
scripts/finetune.py (+3 -2)
@@ -427,9 +427,10 @@ def train(
     max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
     ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
     prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+
     if any(prepared_ds_path.glob("*")):
         logging.info("Loading prepared dataset from disk...")
-        dataset = load_from_disk(prepared_ds_path)
+        dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
         logging.info("Loading raw datasets...")
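The first hunk makes two fixes: it inserts a separating blank line, and it casts prepared_ds_path to str before handing it to load_from_disk, presumably because the installed datasets release only accepts a plain string path. A minimal sketch of the cache round-trip this code performs, with a made-up hash input, dataset rows, and cache directory (only the str() cast mirrors the actual fix):

    from hashlib import md5
    from pathlib import Path

    from datasets import Dataset, load_from_disk

    # Cache key: hash of the packing length plus the dataset config, so any
    # config change lands in a fresh cache directory.
    ds_hash = md5(b"2048@tatsu-lab/alpaca:alpaca").hexdigest()
    prepared_ds_path = Path("last_run_prepared") / ds_hash

    # First run: persist the prepared dataset for reuse.
    ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]})
    ds.save_to_disk(str(prepared_ds_path))

    # Later runs: the glob("*") check sees the cached files and reloads them;
    # the str() cast sidesteps datasets versions that reject Path objects.
    if any(prepared_ds_path.glob("*")):
        dataset = load_from_disk(str(prepared_ds_path))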
@@ -437,7 +438,7 @@ def train(
     for d in cfg.datasets:
         ds_from_hub = False
         try:
-
+            load_dataset(d.path, streaming=True)
             ds_from_hub = True
         except FileNotFoundError:
             pass
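The second hunk adds the hub-existence probe inside the try block: calling load_dataset with streaming=True resolves d.path against the Hugging Face Hub without downloading the actual data, and a missing dataset raises FileNotFoundError, so ds_from_hub stays False and the path is presumably handled as a local dataset later in the loop. A sketch of that probe as a standalone helper; the function name is hypothetical, since the script inlines this logic:

    from datasets import load_dataset

    def dataset_on_hub(path: str) -> bool:
        """Return True if path resolves to a dataset on the Hugging Face Hub."""
        try:
            # streaming=True fetches only metadata, so the probe stays cheap
            # even for large datasets.
            load_dataset(path, streaming=True)
            return True
        except FileNotFoundError:
            # Raised when the path is neither on the Hub nor a local dataset
            # script, letting the caller fall back to local files.
            return False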