diff --git a/sample_factory/huggingface/huggingface_utils.py b/sample_factory/huggingface/huggingface_utils.py index b6b10fc2..88ebd2c4 100644 --- a/sample_factory/huggingface/huggingface_utils.py +++ b/sample_factory/huggingface/huggingface_utils.py @@ -117,27 +117,40 @@ def push_to_hf(dir_path: str, repo_name: str, num_policies: int = 1): exist_ok=True, ) - # Upload folders - folders = [".summary"] - for policy_id in range(num_policies): - folders.append(f"checkpoint_p{policy_id}") - for f in folders: - if os.path.exists(os.path.join(dir_path, f)): - upload_folder( - repo_id=repo_name, - folder_path=os.path.join(dir_path, f), - path_in_repo=f, - ) - - # Upload files - files = ["config.json", "README.md", "replay.mp4"] - for f in files: - if os.path.exists(os.path.join(dir_path, f)): - upload_file( - repo_id=repo_name, - path_or_fileobj=os.path.join(dir_path, f), - path_in_repo=f, - ) + upload_folder( + repo_id=repo_name, + folder_path=dir_path, + path_in_repo=f, + allow_patterns=[ + ".summary/*", + "config.json", + "README.md", + "replay.mp4", + ] + + [f"checkpoint_p{policy_id}/*" for policy_id in range(num_policies)], + ) + + # # Upload folders + # folders = [".summary"] + # for policy_id in range(num_policies): + # folders.append(f"checkpoint_p{policy_id}") + # for f in folders: + # if os.path.exists(os.path.join(dir_path, f)): + # upload_folder( + # repo_id=repo_name, + # folder_path=os.path.join(dir_path, f), + # path_in_repo=f, + # ) + + # # Upload files + # files = ["config.json", "README.md", "replay.mp4"] + # for f in files: + # if os.path.exists(os.path.join(dir_path, f)): + # upload_file( + # repo_id=repo_name, + # path_or_fileobj=os.path.join(dir_path, f), + # path_in_repo=f, + # ) log.info(f"The model has been pushed to {repo_url}")