File size: 2,246 Bytes
206942c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
diff --git a/sample_factory/huggingface/huggingface_utils.py b/sample_factory/huggingface/huggingface_utils.py
index b6b10fc2..88ebd2c4 100644
--- a/sample_factory/huggingface/huggingface_utils.py
+++ b/sample_factory/huggingface/huggingface_utils.py
@@ -117,27 +117,40 @@ def push_to_hf(dir_path: str, repo_name: str, num_policies: int = 1):
         exist_ok=True,
     )
 
-    # Upload folders
-    folders = [".summary"]
-    for policy_id in range(num_policies):
-        folders.append(f"checkpoint_p{policy_id}")
-    for f in folders:
-        if os.path.exists(os.path.join(dir_path, f)):
-            upload_folder(
-                repo_id=repo_name,
-                folder_path=os.path.join(dir_path, f),
-                path_in_repo=f,
-            )
-
-    # Upload files
-    files = ["config.json", "README.md", "replay.mp4"]
-    for f in files:
-        if os.path.exists(os.path.join(dir_path, f)):
-            upload_file(
-                repo_id=repo_name,
-                path_or_fileobj=os.path.join(dir_path, f),
-                path_in_repo=f,
-            )
+    upload_folder(
+        repo_id=repo_name,
+        folder_path=dir_path,
+        path_in_repo=f,
+        allow_patterns=[
+            ".summary/*",
+            "config.json",
+            "README.md",
+            "replay.mp4",
+        ]
+        + [f"checkpoint_p{policy_id}/*" for policy_id in range(num_policies)],
+    )
+
+    # # Upload folders
+    # folders = [".summary"]
+    # for policy_id in range(num_policies):
+    #     folders.append(f"checkpoint_p{policy_id}")
+    # for f in folders:
+    #     if os.path.exists(os.path.join(dir_path, f)):
+    #         upload_folder(
+    #             repo_id=repo_name,
+    #             folder_path=os.path.join(dir_path, f),
+    #             path_in_repo=f,
+    #         )
+
+    # # Upload files
+    # files = ["config.json", "README.md", "replay.mp4"]
+    # for f in files:
+    #     if os.path.exists(os.path.join(dir_path, f)):
+    #         upload_file(
+    #             repo_id=repo_name,
+    #             path_or_fileobj=os.path.join(dir_path, f),
+    #             path_in_repo=f,
+    #         )
 
     log.info(f"The model has been pushed to {repo_url}")