leafspark committed
Commit e1cc7f1
1 Parent(s): d7bb05c

Add safetensors merge and split helper files

Files changed (2)
  1. merge.py +29 -0
  2. split.py +43 -0
merge.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ import json
+ import shutil
+
+ from split import split  # fall back to re-splitting if the chunk list is missing
+
+ OUTPUT_FILE_NAME = "consolidated.safetensors"  # Merge output file name
+ CHUNK_PATHS_FILE = "chunk_paths.json"
+
+ def merge(chunk_paths):
+     output_path = os.path.join(os.path.dirname(chunk_paths[0]), OUTPUT_FILE_NAME)
+
+     with open(output_path, "wb") as f_out:
+         for filepath in chunk_paths:
+             with open(filepath, "rb") as f_in:
+                 # Stream each chunk instead of loading it fully into memory
+                 shutil.copyfileobj(f_in, f_out)
+
+     print(f"Merged file saved to {output_path}")
+
+ if __name__ == "__main__":
+     if os.path.exists(CHUNK_PATHS_FILE):
+         with open(CHUNK_PATHS_FILE) as f:
+             chunk_paths = json.load(f)
+     else:
+         # No saved chunk list: re-split the consolidated file first
+         chunk_paths = split(OUTPUT_FILE_NAME)
+
+     merge(chunk_paths)
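
Since merging is plain byte concatenation, the round trip can be verified by hashing the reassembled file against a copy of the original. A minimal sketch, assuming a copy was kept as backup.safetensors (a hypothetical name; sha256_of is likewise a helper written here, not part of this commit):

import hashlib

def sha256_of(path, block_size=1024 * 1024):
    # Hash the file in 1 MiB blocks so large safetensors files
    # are never held in memory all at once.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(block_size):
            h.update(block)
    return h.hexdigest()

# The digests should match if the split/merge round trip was lossless.
assert sha256_of("consolidated.safetensors") == sha256_of("backup.safetensors")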
split.py ADDED
@@ -0,0 +1,43 @@
+ import os
+ import math
+ import json
+
+ CHUNK_SIZE = 2 * 1024**3  # 2 GiB per chunk
+
+ CHUNK_PATHS_FILE = "chunk_paths.json"
+
+ def split(filepath, chunk_size=CHUNK_SIZE):
+     dirname = os.path.dirname(filepath)
+     basename = os.path.basename(filepath)
+     # splitext keeps the leading dot on the extension and tolerates
+     # filenames that contain multiple dots
+     filename_no_ext, extension = os.path.splitext(basename)
+
+     file_size = os.path.getsize(filepath)
+     num_chunks = math.ceil(file_size / chunk_size)
+     digit_count = len(str(num_chunks))
+
+     chunk_paths = []
+
+     with open(filepath, "rb") as f_in:
+         for i in range(1, num_chunks + 1):
+             # Chunks are read sequentially, so no seeking is needed
+             chunk = f_in.read(chunk_size)
+
+             chunk_filename = f"{filename_no_ext}-{str(i).zfill(digit_count)}-of-{str(num_chunks).zfill(digit_count)}{extension}"
+             split_path = os.path.join(dirname, chunk_filename)
+
+             with open(split_path, "wb") as f_out:
+                 f_out.write(chunk)
+
+             chunk_paths.append(split_path)
+
+     # Record the chunk paths so merge.py can find them later
+     with open(CHUNK_PATHS_FILE, "w") as f:
+         json.dump(chunk_paths, f)
+
+     return chunk_paths
+
+ if __name__ == "__main__":
+     main_filepath = "consolidated.safetensors"  # File to be split
+     chunk_paths = split(main_filepath)
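
Since both scripts guard their module-level calls with __main__ checks, they can also be imported and driven together. A round-trip sketch, assuming merge.py and split.py sit in the working directory alongside the model file:

from split import split
from merge import merge

# Split the consolidated file into 2 GiB chunks; the chunk paths are also
# recorded in chunk_paths.json as a side effect.
chunk_paths = split("consolidated.safetensors")

# Reassemble the chunks. merge() writes its output next to the first chunk,
# so in this flat layout it recreates consolidated.safetensors in place.
merge(chunk_paths)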