ASTERIZER commited on
Commit
0122e75
·
verified ·
1 Parent(s): 097c451

Upload sft_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sft_train.py +7 -3
sft_train.py CHANGED
@@ -378,6 +378,7 @@ def load_sft_config(config_path):
378
  cfg = {
379
  "auto_config": raw.get("auto_config", True),
380
  "hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
 
381
  "hf_dataset_repo": raw.get("hf_dataset_repo", "ASTERIZER/Luna_Dataset"),
382
  "pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"),
383
  "train_json": raw.get("train_json", "Base/Datasets/sft_clean/train.json"),
@@ -467,15 +468,18 @@ def sft_train(cfg):
467
  if not ckpt_path.exists() and cfg.get("hf_model_repo"):
468
  # Auto-download from HuggingFace model repo
469
  print(f"\n Pretrained checkpoint not found locally.")
470
- print(f" Downloading from HuggingFace: {cfg['hf_model_repo']}")
471
  from huggingface_hub import hf_hub_download
472
  ckpt_path.parent.mkdir(parents=True, exist_ok=True)
473
- hf_hub_download(
474
  repo_id=cfg["hf_model_repo"],
475
- filename="latest.pt",
476
  local_dir=str(ckpt_path.parent),
477
  token=os.environ.get("HF_TOKEN"),
478
  )
 
 
 
479
  print(f" Downloaded to: {ckpt_path}")
480
 
481
  if ckpt_path.exists():
 
378
  cfg = {
379
  "auto_config": raw.get("auto_config", True),
380
  "hf_model_repo": raw.get("hf_model_repo", "ASTERIZER/LUNA-100M"),
381
+ "hf_model_file": raw.get("hf_model_file", "latest.pt"),
382
  "hf_dataset_repo": raw.get("hf_dataset_repo", "ASTERIZER/Luna_Dataset"),
383
  "pretrained_ckpt": raw.get("pretrained_ckpt", "Base/out/pretrain/luna_100m/latest.pt"),
384
  "train_json": raw.get("train_json", "Base/Datasets/sft_clean/train.json"),
 
468
  if not ckpt_path.exists() and cfg.get("hf_model_repo"):
469
  # Auto-download from HuggingFace model repo
470
  print(f"\n Pretrained checkpoint not found locally.")
471
+ print(f" Downloading from HuggingFace: {cfg['hf_model_repo']} ({cfg['hf_model_file']})")
472
  from huggingface_hub import hf_hub_download
473
  ckpt_path.parent.mkdir(parents=True, exist_ok=True)
474
+ downloaded = hf_hub_download(
475
  repo_id=cfg["hf_model_repo"],
476
+ filename=cfg["hf_model_file"],
477
  local_dir=str(ckpt_path.parent),
478
  token=os.environ.get("HF_TOKEN"),
479
  )
480
+ downloaded_path = Path(downloaded)
481
+ if not ckpt_path.exists() and downloaded_path.exists():
482
+ ckpt_path = downloaded_path
483
  print(f" Downloaded to: {ckpt_path}")
484
 
485
  if ckpt_path.exists():