Upload sft_train.py with huggingface_hub
Browse files - sft_train.py +7 -3
sft_train.py
CHANGED
|
@@ -378,6 +378,7 @@ def load_sft_config(config_path):
|
|
| 378 |
cfg = {
|
| 379 |
"auto_config": raw.get("auto_config", True),
|
| 380 |
"hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
|
|
|
|
| 381 |
"hf_dataset_repo": raw.get("hf_dataset_repo", "ASTERIZER/Luna_Dataset"),
|
| 382 |
"pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"),
|
| 383 |
"train_json": raw.get("train_json", "Base/Datasets/sft_clean/train.json"),
|
|
@@ -467,15 +468,18 @@ def sft_train(cfg):
|
|
| 467 |
if not ckpt_path.exists() and cfg.get("hf_model_repo"):
|
| 468 |
# Auto-download from HuggingFace model repo
|
| 469 |
print(f"\n Pretrained checkpoint not found locally.")
|
| 470 |
-
print(f" Downloading from HuggingFace: {cfg['hf_model_repo']}")
|
| 471 |
from huggingface_hub import hf_hub_download
|
| 472 |
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
| 473 |
-
hf_hub_download(
|
| 474 |
repo_id=cfg["hf_model_repo"],
|
| 475 |
-
filename="
|
| 476 |
local_dir=str(ckpt_path.parent),
|
| 477 |
token=os.environ.get("HF_TOKEN"),
|
| 478 |
)
|
|
|
|
|
|
|
|
|
|
| 479 |
print(f" Downloaded to: {ckpt_path}")
|
| 480 |
|
| 481 |
if ckpt_path.exists():
|
|
|
|
| 378 |
cfg = {
|
| 379 |
"auto_config": raw.get("auto_config", True),
|
| 380 |
"hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
|
| 381 |
+
"hf_model_file": raw.get("hf_model_file", "latest.pt"),
|
| 382 |
"hf_dataset_repo": raw.get("hf_dataset_repo", "ASTERIZER/Luna_Dataset"),
|
| 383 |
"pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"),
|
| 384 |
"train_json": raw.get("train_json", "Base/Datasets/sft_clean/train.json"),
|
|
|
|
| 468 |
if not ckpt_path.exists() and cfg.get("hf_model_repo"):
|
| 469 |
# Auto-download from HuggingFace model repo
|
| 470 |
print(f"\n Pretrained checkpoint not found locally.")
|
| 471 |
+
print(f" Downloading from HuggingFace: {cfg['hf_model_repo']} ({cfg['hf_model_file']})")
|
| 472 |
from huggingface_hub import hf_hub_download
|
| 473 |
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
| 474 |
+
downloaded = hf_hub_download(
|
| 475 |
repo_id=cfg["hf_model_repo"],
|
| 476 |
+
filename=cfg["hf_model_file"],
|
| 477 |
local_dir=str(ckpt_path.parent),
|
| 478 |
token=os.environ.get("HF_TOKEN"),
|
| 479 |
)
|
| 480 |
+
downloaded_path = Path(downloaded)
|
| 481 |
+
if not ckpt_path.exists() and downloaded_path.exists():
|
| 482 |
+
ckpt_path = downloaded_path
|
| 483 |
print(f" Downloaded to: {ckpt_path}")
|
| 484 |
|
| 485 |
if ckpt_path.exists():
|